Skip to content

fix: array IndexError when center_agent_idx is not set as sdc#50

Open
SS47816 wants to merge 1 commit intowaymo-research:mainfrom
SS47816:main
Open

fix: array IndexError when center_agent_idx is not set as sdc#50
SS47816 wants to merge 1 commit intowaymo-research:mainfrom
SS47816:main

Conversation

@SS47816
Copy link

@SS47816 SS47816 commented Feb 1, 2024

Fixed the behavior of the function plot_simulator_state() when the parameter center_agent_idx in viz_config is not -1 # sdc.

Original Issue

When the parameter center_agent_idx in viz_config is set to a user-specified index, the following issue will arise:

File [~/anaconda3/envs/waymax/lib/python3.10/site-packages/waymax/visualization/viz.py:287], in plot_simulator_state(state, use_log_traj, viz_config, batch_idx, highlight_obj)
    285 else:
    286   xy = current_xy[viz_config.center_agent_idx]
--> 287 origin_x, origin_y = xy[0, :2]
    288 ax.axis((
    289     origin_x - viz_config.back_x,
    290     origin_x + viz_config.front_x,
    291     origin_y - viz_config.back_y,
    292     origin_y + viz_config.front_y,
    293 ))
    295 return utils.img_from_fig(fig)

File [~/anaconda3/envs/waymax/lib/python3.10/site-packages/jax/_src/array.py:314], in ArrayImpl.__getitem__(self, idx)
    312   num_idx = sum(e is not None and e is not Ellipsis for e in idx)
    313   if num_idx > self.ndim:
--> 314     raise IndexError(
    315         f"Too many indices for array: array has ndim of {self.ndim}, but "
...
    316         f"was indexed with {num_idx} non-None[/Ellipsis](https://file+.vscode-resource.vscode-cdn.net/Ellipsis) indices.")
    318 if isinstance(self.sharding, PmapSharding):
    319   if not isinstance(idx, tuple):

IndexError: Too many indices for array: array has ndim of 1, but was indexed with 2 non-None[/Ellipsis](https://file+.vscode-resource.vscode-cdn.net/Ellipsis) indices.

This issue occurred because the shape of the xy = current_xy[state.object_metadata.is_sdc] is [1, 2] (2-dimension) whereas the shape of the xy = current_xy[viz_config.center_agent_idx] is [2] (1-dimension).

Fix

The following code in visualization/viz.py, from the line 280 to 294:

  # 3. Gets np img, centered on selected agent's current location.
  # [A, 2]
  current_xy = traj.xy[:, state.timestep, :]
  if viz_config.center_agent_idx == -1:
    xy = current_xy[state.object_metadata.is_sdc]
  else:
    xy = current_xy[viz_config.center_agent_idx]
  origin_x, origin_y = xy[0, :2]
  ax.axis((
      origin_x - viz_config.back_x,
      origin_x + viz_config.front_x,
      origin_y - viz_config.back_y,
      origin_y + viz_config.front_y,
  ))

has been changed to:

  # 3. Gets np img, centered on selected agent's current location.
  # [A, 2]
  current_xy = traj.xy[:, state.timestep, :]
  if viz_config.center_agent_idx == -1:
    xy = current_xy[state.object_metadata.is_sdc]
    origin_x, origin_y = xy[0, :2]
  else:
    xy = current_xy[viz_config.center_agent_idx]
    origin_x, origin_y = xy[:2]
  ax.axis((
      origin_x - viz_config.back_x,
      origin_x + viz_config.front_x,
      origin_y - viz_config.back_y,
      origin_y + viz_config.front_y,
  ))

so that the shapes of the origin_x, origin_y, and xy in the if-else statement are now properly aligned.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant