diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 02cde129..dbb2af61 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -89,10 +89,12 @@ def __init__( folder_path: Path, key_map: dict | None = None, transform_list: list | None = None, + norm_stats: dict | None = None, ): self.folder_path = Path(folder_path) self.key_map = key_map self.transform_list = transform_list + self.norm_stats = norm_stats def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): """ @@ -121,7 +123,10 @@ def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): continue try: ds_obj = ZarrDataset( - p, key_map=self.key_map, transform_list=self.transform_list + p, + key_map=self.key_map, + transform_list=self.transform_list, + norm_stats=self.norm_stats, ) datasets[name] = ds_obj except Exception as e: @@ -149,12 +154,18 @@ def __init__( main_prefix: str = "processed_v3", key_map: dict | None = None, transform_list: list | None = None, + norm_stats: dict | None = None, debug: bool = False, ): self.bucket_name = bucket_name self.main_prefix = main_prefix self.debug = debug - super().__init__(folder_path, key_map=key_map, transform_list=transform_list) + super().__init__( + folder_path, + key_map=key_map, + transform_list=transform_list, + norm_stats=norm_stats, + ) def resolve( self, @@ -378,9 +389,10 @@ def __init__( folder_path: Path, key_map: dict | None = None, transform_list: list | None = None, + norm_stats: dict | None = None, debug=False, ): - super().__init__(folder_path, key_map, transform_list) + super().__init__(folder_path, key_map, transform_list, norm_stats=norm_stats) self.debug = debug @staticmethod @@ -543,6 +555,22 @@ def __getitem__(self, idx): return data + def set_norm_stats(self, norm_stats: dict) -> None: + """ + Propagate norm_stats to all child ZarrDatasets. 
+ Note: have to add recursive functionality to this, multidatset doesn't self propagate norm yet + If we stack 2 multi or more this won't work + Args: + norm_stats: dict mapping key names to {"quantile_1": tensor, "quantile_99": tensor, ...} + (typically data_schematic.norm_stats[embodiment_id]) + """ + for ds in self.datasets.values(): + ds.norm_stats = norm_stats + logger.info( + f"Set norm_stats with {len(norm_stats)} keys on " + f"{len(self.datasets)} ZarrDatasets" + ) + @classmethod def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): """ @@ -583,12 +611,17 @@ def __init__( Episode_path: Path, key_map: dict, transform_list: list | None = None, + norm_stats: dict | None = None, ): """ Args: episode_path: just a path to the designated zarr episode key_map: dict mapping from dataset keys to zarr keys and horizon info, e.g. {"obs/image/front": {"zarr_key": "observations.images.front", "horizon": 4}, ...} transform_list: list of Transform objects to apply to the data after loading, e.g. for action chunk transformations. Should be in order of application. + norm_stats: optional dict mapping dataset key names (same keys as key_map) to + {"quantile_1": tensor, "quantile_99": tensor} bounds. When provided, any + loaded sample whose values fall outside [quantile_1, quantile_99] for any + tracked key triggers the random index fallback. """ self.episode_path = Episode_path self.metadata = None @@ -599,6 +632,8 @@ def __init__( self.key_map = key_map self.transform = transform_list + self.norm_stats = norm_stats or {} + self._warned_violations: set[str] = set() super().__init__() def init_episode(self): @@ -697,7 +732,106 @@ def _pad_sequences(self, data, horizon: int | None) -> dict: return data + def _check_bounds(self, data: dict, idx: int) -> str | None: + """ + Check whether any tracked key's values fall outside [quantile_1, quantile_99]. + + Logs detailed violation info and returns a log-prefix string if a violation + is found, else None. 
+ """ + for k, stats in self.norm_stats.items(): + if k not in data: + continue + v = data[k] + if isinstance(v, torch.Tensor): + arr = v.float() + elif isinstance(v, np.ndarray): + arr = torch.from_numpy(v).float() + else: + continue + q1 = stats["quantile_1"] + q99 = stats["quantile_99"] + if isinstance(q1, np.ndarray): + q1 = torch.from_numpy(q1).float() + if isinstance(q99, np.ndarray): + q99 = torch.from_numpy(q99).float() + + # sanity nan/inf check, don't think this is needed but going to keep just in case + has_nan = torch.any(torch.isnan(arr)) + has_inf = torch.any(torch.isinf(arr)) + if has_nan or has_inf: + nan_mask = torch.isnan(arr) + inf_mask = torch.isinf(arr) + n_nan = nan_mask.sum().item() + n_inf = inf_mask.sum().item() + bad_mask = nan_mask | inf_mask + bad_indices = bad_mask.nonzero(as_tuple=False).tolist() + bad_values = arr[bad_mask].tolist() + prefix = ( + f"NaN/Inf violation ep={Path(self.episode_path).name} " + f"frame={idx} key={k}" + ) + warn_key = f"nan_inf:{Path(self.episode_path).name}:{k}" + if warn_key not in self._warned_violations: + self._warned_violations.add(warn_key) + logger.warning( + f"{prefix} | n_nan={int(n_nan)} n_inf={int(n_inf)} " + f"indices={bad_indices[:10]} values={[f'{v:.4f}' for v in bad_values[:10]]}" + ) + return prefix + + # regular bounds violation, if above q99 or below q1 log it and trigger fallback + below = arr < q1 + above = arr > q99 + if torch.any(below) or torch.any(above): + n_below = below.sum().item() + n_above = above.sum().item() + below_vals = arr[below].tolist() + above_vals = arr[above].tolist() + below_bounds = q1[below].tolist() + above_bounds = q99[above].tolist() + prefix = ( + f"Bounds violation ep={Path(self.episode_path).name} " + f"frame={idx} key={k}" + ) + warn_key = f"bounds:{Path(self.episode_path).name}:{k}" + if warn_key not in self._warned_violations: + self._warned_violations.add(warn_key) + logger.warning( + f"{prefix} | " + f"n_below={int(n_below)} below_vals={[f'{v:.4f}' for v 
in below_vals[:5]]} below_q1={[f'{b:.4f}' for b in below_bounds[:5]]} " + f"n_above={int(n_above)} above_vals={[f'{v:.4f}' for v in above_vals[:5]]} above_q99={[f'{b:.4f}' for b in above_bounds[:5]]}" + ) + return prefix + + return None + def _get_fallback_idx( + self, + idx: int, + _fallback_origin: int | None, + _attempts: int | None, + log_prefix: str, + ) -> tuple[int, int, int]: + """ + Compute next frame index for fallback when decode/transform fails. + Strategy: randomly sample a different index from the episode. + Returns (next_idx, origin, attempts). Raises RuntimeError if attempts exceed episode length. + """ + origin = _fallback_origin if _fallback_origin is not None else idx + attempts = (_attempts or 0) + 1 + if attempts >= self.total_frames: + raise RuntimeError( + f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}" + ) + candidates = list(range(0, idx)) + list(range(idx + 1, self.total_frames)) + next_idx = random.choice(candidates) + logger.warning( + f"{log_prefix} | attempt {attempts}, trying random idx {next_idx}" + ) + return (next_idx, origin, attempts) + + def _get_image_fallback_idx( self, idx: int, _fallback_origin: int | None, @@ -705,7 +839,7 @@ def _get_fallback_idx( log_prefix: str, ) -> tuple[int, int, str]: """ - Compute next frame index for fallback when decode/transform fails. + Compute next frame index for fallback when image decode fails. Strategy: try left from origin (idx-5, idx-10, ...), then right (origin+5, ...). Returns (next_idx, origin, direction). Raises RuntimeError if entire episode is bad. 
""" @@ -727,10 +861,18 @@ def _get_fallback_idx( raise RuntimeError( f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}" ) - logger.warning(f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}") + logger.warning( + f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}" + ) return (next_idx, origin, "right") - def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: str | None = None) -> dict[str, torch.Tensor]: + def __getitem__( + self, + idx: int, + _fallback_origin: int | None = None, + _attempts: int | None = None, + _direction: str | None = None, + ) -> dict[str, torch.Tensor]: # Build keys_dict with ranges based on whether action chunking is enabled data = {} for k in self.key_map: @@ -759,11 +901,18 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: try: decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB") except Exception: - next_idx, origin, direction = self._get_fallback_idx( - idx, _fallback_origin, _direction, + next_idx, origin, direction = self._get_image_fallback_idx( + idx, + _fallback_origin, + _direction, f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}", ) - result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction) + result = self.__getitem__( + next_idx, + _fallback_origin=origin, + _attempts=_attempts, + _direction=direction, + ) return result data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0 elif zarr_key in self._json_keys: @@ -780,13 +929,33 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: try: data = transform.transform(data) except Exception as e: - next_idx, origin, direction = self._get_fallback_idx( - idx, _fallback_origin, _direction, + next_idx, origin, attempts = self._get_fallback_idx( + idx, + _fallback_origin, + _attempts, f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})", ) - 
result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction) + result = self.__getitem__( + next_idx, + _fallback_origin=origin, + _attempts=attempts, + _direction=_direction, + ) return result + if self.norm_stats: + violation = self._check_bounds(data, idx) + if violation is not None: + next_idx, origin, attempts = self._get_fallback_idx( + idx, _fallback_origin, _attempts, violation + ) + return self.__getitem__( + next_idx, + _fallback_origin=origin, + _attempts=attempts, + _direction=_direction, + ) + for k, v in data.items(): if isinstance(v, np.ndarray): data[k] = torch.from_numpy(v).to(torch.float32) @@ -798,6 +967,7 @@ def get_item_keys( idx: int, keys, _fallback_origin: int | None = None, + _attempts: int | None = None, _direction: str | None = None, ) -> dict[str, torch.Tensor]: requested = self._normalize_keys_arg(keys) @@ -823,12 +993,21 @@ def get_item_keys( val = raw[zarr_key] if zarr_key in self._image_keys: + def _jpeg_fallback() -> dict[str, torch.Tensor]: - next_idx, origin, direction = self._get_fallback_idx( - idx, _fallback_origin, _direction, + next_idx, origin, direction = self._get_image_fallback_idx( + idx, + _fallback_origin, + _direction, f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}", ) - result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction) + result = self.get_item_keys( + next_idx, + keys, + _fallback_origin=origin, + _attempts=_attempts, + _direction=direction, + ) return result if ( @@ -858,13 +1037,35 @@ def _jpeg_fallback() -> dict[str, torch.Tensor]: try: out = transform.transform(out) except Exception as e: - next_idx, origin, direction = self._get_fallback_idx( - idx, _fallback_origin, _direction, + next_idx, origin, attempts = self._get_fallback_idx( + idx, + _fallback_origin, + _attempts, f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})", ) - result = self.get_item_keys(next_idx, keys, 
_fallback_origin=origin, _direction=direction) + result = self.get_item_keys( + next_idx, + keys, + _fallback_origin=origin, + _attempts=attempts, + _direction=_direction, + ) return result + if self.norm_stats: + violation = self._check_bounds(out, idx) + if violation is not None: + next_idx, origin, attempts = self._get_fallback_idx( + idx, _fallback_origin, _attempts, violation + ) + return self.get_item_keys( + next_idx, + keys, + _fallback_origin=origin, + _attempts=attempts, + _direction=_direction, + ) + for k, v in out.items(): if isinstance(v, np.ndarray): out[k] = torch.from_numpy(v).to(torch.float32) diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index 8e4540e1..9ce889ae 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf from tabulate import tabulate +from egomimic.rldb.embodiment.embodiment import get_embodiment_id from egomimic.rldb.zarr.utils import DataSchematic, set_global_seed from egomimic.scripts.evaluation.eval import Eval from egomimic.utils.aws.aws_data_utils import load_env @@ -114,12 +115,41 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: data_schematic.infer_norm_from_dataset( norm_dataset, dataset_name, - sample_frac=0.005, + sample_frac=0.02, benchmark_dir=os.path.join( cfg.trainer.default_root_dir, "benchmark_stats.json" ), ) + # Propagate norm stats to all zarrdatasets out of bounds check in getitem + # Have to remap keys to zarr keys for continuity + for split_name, split_datasets in [ + ("train", train_datasets), + ("valid", valid_datasets), + ]: + for dataset_name, dataset in split_datasets.items(): + embodiment_id = get_embodiment_id(dataset_name) + if ( + embodiment_id in data_schematic.norm_stats + and data_schematic.norm_stats[embodiment_id] + ): + remapped = {} + for key_name, stats in data_schematic.norm_stats[embodiment_id].items(): + zarr_key = data_schematic.keyname_to_zarr_key( + key_name, embodiment_id + ) + if 
zarr_key is not None: + remapped[zarr_key] = stats + log.info( + f"Passing norm_stats to {split_name} dataset <{dataset_name}> (embodiment_id={embodiment_id}): " + f"{list(remapped.keys())}" + ) + dataset.set_norm_stats(remapped) + else: + log.warning( + f"No norm_stats found for {split_name} dataset <{dataset_name}> (embodiment_id={embodiment_id}), skipping bounds filtering" + ) + # NOTE: We also pass the data_schematic_dict into the robomimic model's instantiation now that we've initialized the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see. log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(