Patch summary: add `_get_fallback_idx` and use it in `__getitem__` and
`get_item_keys`. When a JPEG decode or a transform raises, the sample is
retried at a nearby frame: 5-frame steps to the left until index 0 has been
tried, then 5-frame steps to the right starting at the original index + 5.
Previously a transform failure jumped straight to frame 0 and a JPEG decode
failure was not handled. A RuntimeError is raised once the right-hand search
reaches `self.total_frames`.

NOTE(review): leading whitespace and blank context lines below were
reconstructed from the hunk line counts and Python block structure; confirm
against the target file, or apply with `git apply --ignore-whitespace`.

diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py
index bb7b08b7..02cde129 100644
--- a/egomimic/rldb/zarr/zarr_dataset_multi.py
+++ b/egomimic/rldb/zarr/zarr_dataset_multi.py
@@ -697,7 +697,40 @@ def _pad_sequences(self, data, horizon: int | None) -> dict:
 
         return data
 
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+    def _get_fallback_idx(
+        self,
+        idx: int,
+        _fallback_origin: int | None,
+        _direction: str | None,
+        log_prefix: str,
+    ) -> tuple[int, int, str]:
+        """
+        Compute next frame index for fallback when decode/transform fails.
+        Strategy: try left from origin (idx-5, idx-10, ...), then right (origin+5, ...).
+        Returns (next_idx, origin, direction). Raises RuntimeError if entire episode is bad.
+        """
+        origin = _fallback_origin if _fallback_origin is not None else idx
+        if _direction == "right":
+            next_idx = idx + 5
+            if next_idx >= self.total_frames:
+                raise RuntimeError(
+                    f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}"
+                )
+            logger.warning(f"{log_prefix} | trying right {next_idx}")
+            return (next_idx, origin, "right")
+        if idx > 0:
+            next_idx = max(0, idx - 5)
+            logger.warning(f"{log_prefix} | trying left idx {next_idx}")
+            return (next_idx, origin, "left")
+        next_idx = origin + 5
+        if next_idx >= self.total_frames:
+            raise RuntimeError(
+                f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}"
+            )
+        logger.warning(f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}")
+        return (next_idx, origin, "right")
+
+    def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: str | None = None) -> dict[str, torch.Tensor]:
         # Build keys_dict with ranges based on whether action chunking is enabled
         data = {}
         for k in self.key_map:
@@ -723,8 +756,15 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
             if zarr_key in self._image_keys:
                 jpeg_bytes = data[k]
                 # Decode JPEG bytes to numpy array (H, W, 3)
-                decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB")
-                # data[k] = torch.from_numpy(np.transpose(decoded, (2, 0, 1))).to(torch.float32) / 255.0
+                try:
+                    decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB")
+                except Exception:
+                    next_idx, origin, direction = self._get_fallback_idx(
+                        idx, _fallback_origin, _direction,
+                        f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}",
+                    )
+                    result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction)
+                    return result
                 data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0
             elif zarr_key in self._json_keys:
                 if isinstance(data[k], np.ndarray):
@@ -740,15 +780,12 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
                 try:
                     data = transform.transform(data)
                 except Exception as e:
-                    logger.error(f"Error transforming data: {e}")
-                    logger.error(f"Data: {data}")
-                    logger.error(f"Transform: {transform}")
-                    logger.error(f"Error: {e}")
-                    if idx == 0:
-                        logger.error("Error in first frame")
-                        raise e
-                    else:
-                        return self.__getitem__(0)
+                    next_idx, origin, direction = self._get_fallback_idx(
+                        idx, _fallback_origin, _direction,
+                        f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})",
+                    )
+                    result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction)
+                    return result
 
         for k, v in data.items():
             if isinstance(v, np.ndarray):
@@ -756,7 +793,13 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
 
         return data
 
-    def get_item_keys(self, idx: int, keys) -> dict[str, torch.Tensor]:
+    def get_item_keys(
+        self,
+        idx: int,
+        keys,
+        _fallback_origin: int | None = None,
+        _direction: str | None = None,
+    ) -> dict[str, torch.Tensor]:
         requested = self._normalize_keys_arg(keys)
         out = {}
 
@@ -780,6 +823,14 @@ def get_item_keys(self, idx: int, keys) -> dict[str, torch.Tensor]:
             val = raw[zarr_key]
 
             if zarr_key in self._image_keys:
+                def _jpeg_fallback() -> dict[str, torch.Tensor]:
+                    next_idx, origin, direction = self._get_fallback_idx(
+                        idx, _fallback_origin, _direction,
+                        f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}",
+                    )
+                    result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction)
+                    return result
+
                 if (
                     isinstance(val, np.ndarray)
                     and val.dtype == object
@@ -787,11 +838,17 @@ def get_item_keys(self, idx: int, keys) -> dict[str, torch.Tensor]:
                 ):
                     decoded_seq = []
                     for jpeg_bytes in val:
-                        img = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB")
+                        try:
+                            img = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB")
+                        except Exception:
+                            return _jpeg_fallback()
                         decoded_seq.append(np.transpose(img, (2, 0, 1)) / 255.0)
                     val = np.stack(decoded_seq, axis=0)
                 else:
-                    img = simplejpeg.decode_jpeg(val, colorspace="RGB")
+                    try:
+                        img = simplejpeg.decode_jpeg(val, colorspace="RGB")
+                    except Exception:
+                        return _jpeg_fallback()
                     val = np.transpose(img, (2, 0, 1)) / 255.0
 
             out[k] = val
@@ -801,16 +858,12 @@ def get_item_keys(self, idx: int, keys) -> dict[str, torch.Tensor]:
                 try:
                     out = transform.transform(out)
                 except Exception as e:
-                    logger.error(f"Error transforming data: {e}")
-                    # NOTE: avoid dumping full arrays into logs
-                    logger.error(f"Data keys: {list(out.keys())}")
-                    logger.error(f"Transform: {transform}")
-                    logger.error(f"Error: {e}")
-                    if idx == 0:
-                        logger.error("Error in first frame")
-                        raise e
-                    else:
-                        return self.get_item_keys(0, keys)
+                    next_idx, origin, direction = self._get_fallback_idx(
+                        idx, _fallback_origin, _direction,
+                        f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})",
+                    )
+                    result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction)
+                    return result
 
         for k, v in out.items():
             if isinstance(v, np.ndarray):