-
Notifications
You must be signed in to change notification settings - Fork 1
norm stats bound check #285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: aniketh/bad-idx-fallback
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,10 +89,12 @@ def __init__( | |
| folder_path: Path, | ||
| key_map: dict | None = None, | ||
| transform_list: list | None = None, | ||
| norm_stats: dict | None = None, | ||
| ): | ||
| self.folder_path = Path(folder_path) | ||
| self.key_map = key_map | ||
| self.transform_list = transform_list | ||
| self.norm_stats = norm_stats | ||
|
|
||
| def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): | ||
| """ | ||
|
|
@@ -121,7 +123,10 @@ def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): | |
| continue | ||
| try: | ||
| ds_obj = ZarrDataset( | ||
| p, key_map=self.key_map, transform_list=self.transform_list | ||
| p, | ||
| key_map=self.key_map, | ||
| transform_list=self.transform_list, | ||
| norm_stats=self.norm_stats, | ||
| ) | ||
| datasets[name] = ds_obj | ||
| except Exception as e: | ||
|
|
@@ -149,12 +154,18 @@ def __init__( | |
| main_prefix: str = "processed_v3", | ||
| key_map: dict | None = None, | ||
| transform_list: list | None = None, | ||
| norm_stats: dict | None = None, | ||
| debug: bool = False, | ||
| ): | ||
| self.bucket_name = bucket_name | ||
| self.main_prefix = main_prefix | ||
| self.debug = debug | ||
| super().__init__(folder_path, key_map=key_map, transform_list=transform_list) | ||
| super().__init__( | ||
| folder_path, | ||
| key_map=key_map, | ||
| transform_list=transform_list, | ||
| norm_stats=norm_stats, | ||
| ) | ||
|
|
||
| def resolve( | ||
| self, | ||
|
|
@@ -378,9 +389,10 @@ def __init__( | |
| folder_path: Path, | ||
| key_map: dict | None = None, | ||
| transform_list: list | None = None, | ||
| norm_stats: dict | None = None, | ||
| debug=False, | ||
| ): | ||
| super().__init__(folder_path, key_map, transform_list) | ||
| super().__init__(folder_path, key_map, transform_list, norm_stats=norm_stats) | ||
| self.debug = debug | ||
|
|
||
| @staticmethod | ||
|
|
@@ -543,6 +555,22 @@ def __getitem__(self, idx): | |
|
|
||
| return data | ||
|
|
||
| def set_norm_stats(self, norm_stats: dict) -> None: | ||
| """ | ||
| Propagate norm_stats to all child ZarrDatasets. | ||
| Note: recursive propagation still needs to be added — a multi-dataset doesn't self-propagate norm stats yet. | ||
| If we stack two or more multi-datasets this won't work. | ||
| Args: | ||
| norm_stats: dict mapping key names to {"quantile_1": tensor, "quantile_99": tensor, ...} | ||
| (typically data_schematic.norm_stats[embodiment_id]) | ||
| """ | ||
| for ds in self.datasets.values(): | ||
| ds.norm_stats = norm_stats | ||
| logger.info( | ||
| f"Set norm_stats with {len(norm_stats)} keys on " | ||
| f"{len(self.datasets)} ZarrDatasets" | ||
| ) | ||
|
|
||
| @classmethod | ||
| def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): | ||
| """ | ||
|
|
@@ -583,12 +611,17 @@ def __init__( | |
| Episode_path: Path, | ||
| key_map: dict, | ||
| transform_list: list | None = None, | ||
| norm_stats: dict | None = None, | ||
| ): | ||
| """ | ||
| Args: | ||
| episode_path: just a path to the designated zarr episode | ||
| key_map: dict mapping from dataset keys to zarr keys and horizon info, e.g. {"obs/image/front": {"zarr_key": "observations.images.front", "horizon": 4}, ...} | ||
| transform_list: list of Transform objects to apply to the data after loading, e.g. for action chunk transformations. Should be in order of application. | ||
| norm_stats: optional dict mapping dataset key names (same keys as key_map) to | ||
| {"quantile_1": tensor, "quantile_99": tensor} bounds. When provided, any | ||
| loaded sample whose values fall outside [quantile_1, quantile_99] for any | ||
| tracked key triggers the random index fallback. | ||
| """ | ||
| self.episode_path = Episode_path | ||
| self.metadata = None | ||
|
|
@@ -599,6 +632,8 @@ def __init__( | |
|
|
||
| self.key_map = key_map | ||
| self.transform = transform_list | ||
| self.norm_stats = norm_stats or {} | ||
| self._warned_violations: set[str] = set() | ||
| super().__init__() | ||
|
|
||
| def init_episode(self): | ||
|
|
@@ -697,15 +732,114 @@ def _pad_sequences(self, data, horizon: int | None) -> dict: | |
|
|
||
| return data | ||
|
|
||
| def _check_bounds(self, data: dict, idx: int) -> str | None: | ||
| """ | ||
| Check whether any tracked key's values fall outside [quantile_1, quantile_99]. | ||
|
|
||
| Logs detailed violation info and returns a log-prefix string if a violation | ||
| is found, else None. | ||
| """ | ||
| for k, stats in self.norm_stats.items(): | ||
| if k not in data: | ||
| continue | ||
| v = data[k] | ||
| if isinstance(v, torch.Tensor): | ||
| arr = v.float() | ||
| elif isinstance(v, np.ndarray): | ||
| arr = torch.from_numpy(v).float() | ||
| else: | ||
| continue | ||
| q1 = stats["quantile_1"] | ||
| q99 = stats["quantile_99"] | ||
| if isinstance(q1, np.ndarray): | ||
| q1 = torch.from_numpy(q1).float() | ||
| if isinstance(q99, np.ndarray): | ||
| q99 = torch.from_numpy(q99).float() | ||
|
|
||
| # Sanity NaN/Inf check; likely unnecessary, but kept just in case. | ||
| has_nan = torch.any(torch.isnan(arr)) | ||
| has_inf = torch.any(torch.isinf(arr)) | ||
| if has_nan or has_inf: | ||
| nan_mask = torch.isnan(arr) | ||
| inf_mask = torch.isinf(arr) | ||
| n_nan = nan_mask.sum().item() | ||
| n_inf = inf_mask.sum().item() | ||
| bad_mask = nan_mask | inf_mask | ||
| bad_indices = bad_mask.nonzero(as_tuple=False).tolist() | ||
| bad_values = arr[bad_mask].tolist() | ||
| prefix = ( | ||
| f"NaN/Inf violation ep={Path(self.episode_path).name} " | ||
| f"frame={idx} key={k}" | ||
| ) | ||
| warn_key = f"nan_inf:{Path(self.episode_path).name}:{k}" | ||
| if warn_key not in self._warned_violations: | ||
| self._warned_violations.add(warn_key) | ||
| logger.warning( | ||
| f"{prefix} | n_nan={int(n_nan)} n_inf={int(n_inf)} " | ||
| f"indices={bad_indices[:10]} values={[f'{v:.4f}' for v in bad_values[:10]]}" | ||
| ) | ||
| return prefix | ||
|
|
||
| # regular bounds violation, if above q99 or below q1 log it and trigger fallback | ||
| below = arr < q1 | ||
| above = arr > q99 | ||
| if torch.any(below) or torch.any(above): | ||
| n_below = below.sum().item() | ||
| n_above = above.sum().item() | ||
| below_vals = arr[below].tolist() | ||
| above_vals = arr[above].tolist() | ||
| below_bounds = q1[below].tolist() | ||
| above_bounds = q99[above].tolist() | ||
| prefix = ( | ||
| f"Bounds violation ep={Path(self.episode_path).name} " | ||
| f"frame={idx} key={k}" | ||
| ) | ||
| warn_key = f"bounds:{Path(self.episode_path).name}:{k}" | ||
| if warn_key not in self._warned_violations: | ||
| self._warned_violations.add(warn_key) | ||
| logger.warning( | ||
| f"{prefix} | " | ||
| f"n_below={int(n_below)} below_vals={[f'{v:.4f}' for v in below_vals[:5]]} below_q1={[f'{b:.4f}' for b in below_bounds[:5]]} " | ||
| f"n_above={int(n_above)} above_vals={[f'{v:.4f}' for v in above_vals[:5]]} above_q99={[f'{b:.4f}' for b in above_bounds[:5]]}" | ||
| ) | ||
| return prefix | ||
|
|
||
| return None | ||
|
|
||
| def _get_fallback_idx( | ||
| self, | ||
| idx: int, | ||
| _fallback_origin: int | None, | ||
| _attempts: int | None, | ||
| log_prefix: str, | ||
| ) -> tuple[int, int, int]: | ||
| """ | ||
| Compute next frame index for fallback when decode/transform fails. | ||
| Strategy: randomly sample a different index from the episode. | ||
| Returns (next_idx, origin, attempts). Raises RuntimeError if attempts exceed episode length. | ||
| """ | ||
| origin = _fallback_origin if _fallback_origin is not None else idx | ||
| attempts = (_attempts or 0) + 1 | ||
| if attempts >= self.total_frames: | ||
| raise RuntimeError( | ||
| f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}" | ||
| ) | ||
| candidates = list(range(0, idx)) + list(range(idx + 1, self.total_frames)) | ||
| next_idx = random.choice(candidates) | ||
| logger.warning( | ||
| f"{log_prefix} | attempt {attempts}, trying random idx {next_idx}" | ||
| ) | ||
| return (next_idx, origin, attempts) | ||
|
|
||
| def _get_image_fallback_idx( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does get_image_fallback_idx have diff logic from get_fallback_idx
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought we didn't want to randomly resample frames, but rather fall back to the closest clean frame |
||
| self, | ||
| idx: int, | ||
| _fallback_origin: int | None, | ||
| _direction: str | None, | ||
| log_prefix: str, | ||
| ) -> tuple[int, int, str]: | ||
| """ | ||
| Compute next frame index for fallback when decode/transform fails. | ||
| Compute next frame index for fallback when image decode fails. | ||
| Strategy: try left from origin (idx-5, idx-10, ...), then right (origin+5, ...). | ||
| Returns (next_idx, origin, direction). Raises RuntimeError if entire episode is bad. | ||
| """ | ||
|
|
@@ -727,10 +861,18 @@ def _get_fallback_idx( | |
| raise RuntimeError( | ||
| f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}" | ||
| ) | ||
| logger.warning(f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}") | ||
| logger.warning( | ||
| f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}" | ||
| ) | ||
| return (next_idx, origin, "right") | ||
|
|
||
| def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: str | None = None) -> dict[str, torch.Tensor]: | ||
| def __getitem__( | ||
| self, | ||
| idx: int, | ||
| _fallback_origin: int | None = None, | ||
| _attempts: int | None = None, | ||
| _direction: str | None = None, | ||
| ) -> dict[str, torch.Tensor]: | ||
| # Build keys_dict with ranges based on whether action chunking is enabled | ||
| data = {} | ||
| for k in self.key_map: | ||
|
|
@@ -759,11 +901,18 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: | |
| try: | ||
| decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB") | ||
| except Exception: | ||
| next_idx, origin, direction = self._get_fallback_idx( | ||
| idx, _fallback_origin, _direction, | ||
| next_idx, origin, direction = self._get_image_fallback_idx( | ||
| idx, | ||
| _fallback_origin, | ||
| _direction, | ||
| f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}", | ||
| ) | ||
| result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction) | ||
| result = self.__getitem__( | ||
| next_idx, | ||
| _fallback_origin=origin, | ||
| _attempts=_attempts, | ||
| _direction=direction, | ||
| ) | ||
| return result | ||
| data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0 | ||
| elif zarr_key in self._json_keys: | ||
|
|
@@ -780,13 +929,33 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: | |
| try: | ||
| data = transform.transform(data) | ||
| except Exception as e: | ||
| next_idx, origin, direction = self._get_fallback_idx( | ||
| idx, _fallback_origin, _direction, | ||
| next_idx, origin, attempts = self._get_fallback_idx( | ||
| idx, | ||
| _fallback_origin, | ||
| _attempts, | ||
| f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})", | ||
| ) | ||
| result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction) | ||
| result = self.__getitem__( | ||
| next_idx, | ||
| _fallback_origin=origin, | ||
| _attempts=attempts, | ||
| _direction=_direction, | ||
| ) | ||
| return result | ||
|
|
||
| if self.norm_stats: | ||
| violation = self._check_bounds(data, idx) | ||
| if violation is not None: | ||
| next_idx, origin, attempts = self._get_fallback_idx( | ||
| idx, _fallback_origin, _attempts, violation | ||
| ) | ||
| return self.__getitem__( | ||
| next_idx, | ||
| _fallback_origin=origin, | ||
| _attempts=attempts, | ||
| _direction=_direction, | ||
| ) | ||
|
|
||
| for k, v in data.items(): | ||
| if isinstance(v, np.ndarray): | ||
| data[k] = torch.from_numpy(v).to(torch.float32) | ||
|
|
@@ -798,6 +967,7 @@ def get_item_keys( | |
| idx: int, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Re: line +965] This function isn't being used so all these changes can be deleted, right? See this comment inline on Graphite. |
||
| keys, | ||
| _fallback_origin: int | None = None, | ||
| _attempts: int | None = None, | ||
| _direction: str | None = None, | ||
| ) -> dict[str, torch.Tensor]: | ||
| requested = self._normalize_keys_arg(keys) | ||
|
|
@@ -823,12 +993,21 @@ def get_item_keys( | |
| val = raw[zarr_key] | ||
|
|
||
| if zarr_key in self._image_keys: | ||
|
|
||
| def _jpeg_fallback() -> dict[str, torch.Tensor]: | ||
| next_idx, origin, direction = self._get_fallback_idx( | ||
| idx, _fallback_origin, _direction, | ||
| next_idx, origin, direction = self._get_image_fallback_idx( | ||
| idx, | ||
| _fallback_origin, | ||
| _direction, | ||
| f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}", | ||
| ) | ||
| result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction) | ||
| result = self.get_item_keys( | ||
| next_idx, | ||
| keys, | ||
| _fallback_origin=origin, | ||
| _attempts=_attempts, | ||
| _direction=direction, | ||
| ) | ||
| return result | ||
|
|
||
| if ( | ||
|
|
@@ -858,13 +1037,35 @@ def _jpeg_fallback() -> dict[str, torch.Tensor]: | |
| try: | ||
| out = transform.transform(out) | ||
| except Exception as e: | ||
| next_idx, origin, direction = self._get_fallback_idx( | ||
| idx, _fallback_origin, _direction, | ||
| next_idx, origin, attempts = self._get_fallback_idx( | ||
| idx, | ||
| _fallback_origin, | ||
| _attempts, | ||
| f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})", | ||
| ) | ||
| result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction) | ||
| result = self.get_item_keys( | ||
| next_idx, | ||
| keys, | ||
| _fallback_origin=origin, | ||
| _attempts=attempts, | ||
| _direction=_direction, | ||
| ) | ||
| return result | ||
|
|
||
| if self.norm_stats: | ||
| violation = self._check_bounds(out, idx) | ||
| if violation is not None: | ||
| next_idx, origin, attempts = self._get_fallback_idx( | ||
| idx, _fallback_origin, _attempts, violation | ||
| ) | ||
| return self.get_item_keys( | ||
| next_idx, | ||
| keys, | ||
| _fallback_origin=origin, | ||
| _attempts=attempts, | ||
| _direction=_direction, | ||
| ) | ||
|
|
||
| for k, v in out.items(): | ||
| if isinstance(v, np.ndarray): | ||
| out[k] = torch.from_numpy(v).to(torch.float32) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed? I thought when we instantiate ZarrDataset we pass in the norm stats?