Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 219 additions & 18 deletions egomimic/rldb/zarr/zarr_dataset_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ def __init__(
folder_path: Path,
key_map: dict | None = None,
transform_list: list | None = None,
norm_stats: dict | None = None,
):
self.folder_path = Path(folder_path)
self.key_map = key_map
self.transform_list = transform_list
self.norm_stats = norm_stats

def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]):
"""
Expand Down Expand Up @@ -121,7 +123,10 @@ def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]):
continue
try:
ds_obj = ZarrDataset(
p, key_map=self.key_map, transform_list=self.transform_list
p,
key_map=self.key_map,
transform_list=self.transform_list,
norm_stats=self.norm_stats,
)
datasets[name] = ds_obj
except Exception as e:
Expand Down Expand Up @@ -149,12 +154,18 @@ def __init__(
main_prefix: str = "processed_v3",
key_map: dict | None = None,
transform_list: list | None = None,
norm_stats: dict | None = None,
debug: bool = False,
):
self.bucket_name = bucket_name
self.main_prefix = main_prefix
self.debug = debug
super().__init__(folder_path, key_map=key_map, transform_list=transform_list)
super().__init__(
folder_path,
key_map=key_map,
transform_list=transform_list,
norm_stats=norm_stats,
)

def resolve(
self,
Expand Down Expand Up @@ -378,9 +389,10 @@ def __init__(
folder_path: Path,
key_map: dict | None = None,
transform_list: list | None = None,
norm_stats: dict | None = None,
debug=False,
):
super().__init__(folder_path, key_map, transform_list)
super().__init__(folder_path, key_map, transform_list, norm_stats=norm_stats)
self.debug = debug

@staticmethod
Expand Down Expand Up @@ -543,6 +555,22 @@ def __getitem__(self, idx):

return data

def set_norm_stats(self, norm_stats: dict) -> None:
    """
    Propagate norm_stats to all child datasets.

    Children that themselves expose ``set_norm_stats`` (i.e. a nested
    multi-dataset) are recursed into, so stacking two or more
    multi-datasets now forwards the stats all the way down to the leaf
    ZarrDatasets instead of stopping one level deep. Leaf datasets simply
    have their ``norm_stats`` attribute assigned.

    Args:
        norm_stats: dict mapping key names to {"quantile_1": tensor, "quantile_99": tensor, ...}
            (typically data_schematic.norm_stats[embodiment_id])
    """
    for ds in self.datasets.values():
        # Recurse into nested multi-datasets; plain ZarrDatasets just
        # store the dict directly.
        if hasattr(ds, "set_norm_stats"):
            ds.set_norm_stats(norm_stats)
        else:
            ds.norm_stats = norm_stats
    logger.info(
        f"Set norm_stats with {len(norm_stats)} keys on "
        f"{len(self.datasets)} ZarrDatasets"
    )

@classmethod
def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
"""
Expand Down Expand Up @@ -583,12 +611,17 @@ def __init__(
Episode_path: Path,
key_map: dict,
transform_list: list | None = None,
norm_stats: dict | None = None,
):
"""
Args:
episode_path: just a path to the designated zarr episode
key_map: dict mapping from dataset keys to zarr keys and horizon info, e.g. {"obs/image/front": {"zarr_key": "observations.images.front", "horizon": 4}, ...}
transform_list: list of Transform objects to apply to the data after loading, e.g. for action chunk transformations. Should be in order of application.
norm_stats: optional dict mapping dataset key names (same keys as key_map) to
{"quantile_1": tensor, "quantile_99": tensor} bounds. When provided, any
loaded sample whose values fall outside [quantile_1, quantile_99] for any
tracked key triggers the random index fallback.
"""
self.episode_path = Episode_path
self.metadata = None
Expand All @@ -599,6 +632,8 @@ def __init__(

self.key_map = key_map
self.transform = transform_list
self.norm_stats = norm_stats or {}
self._warned_violations: set[str] = set()
super().__init__()

def init_episode(self):
Expand Down Expand Up @@ -697,15 +732,114 @@ def _pad_sequences(self, data, horizon: int | None) -> dict:

return data

def _check_bounds(self, data: dict, idx: int) -> str | None:
"""
Check whether any tracked key's values fall outside [quantile_1, quantile_99].

Logs detailed violation info and returns a log-prefix string if a violation
is found, else None.
"""
for k, stats in self.norm_stats.items():
if k not in data:
continue
v = data[k]
if isinstance(v, torch.Tensor):
arr = v.float()
elif isinstance(v, np.ndarray):
arr = torch.from_numpy(v).float()
else:
continue
q1 = stats["quantile_1"]
q99 = stats["quantile_99"]
if isinstance(q1, np.ndarray):
q1 = torch.from_numpy(q1).float()
if isinstance(q99, np.ndarray):
q99 = torch.from_numpy(q99).float()

# sanity nan/inf check, don't think this is needed but going to keep just in case
has_nan = torch.any(torch.isnan(arr))
has_inf = torch.any(torch.isinf(arr))
if has_nan or has_inf:
nan_mask = torch.isnan(arr)
inf_mask = torch.isinf(arr)
n_nan = nan_mask.sum().item()
n_inf = inf_mask.sum().item()
bad_mask = nan_mask | inf_mask
bad_indices = bad_mask.nonzero(as_tuple=False).tolist()
bad_values = arr[bad_mask].tolist()
prefix = (
f"NaN/Inf violation ep={Path(self.episode_path).name} "
f"frame={idx} key={k}"
)
warn_key = f"nan_inf:{Path(self.episode_path).name}:{k}"
if warn_key not in self._warned_violations:
self._warned_violations.add(warn_key)
logger.warning(
f"{prefix} | n_nan={int(n_nan)} n_inf={int(n_inf)} "
f"indices={bad_indices[:10]} values={[f'{v:.4f}' for v in bad_values[:10]]}"
)
return prefix

# regular bounds violation, if above q99 or below q1 log it and trigger fallback
below = arr < q1
above = arr > q99
if torch.any(below) or torch.any(above):
n_below = below.sum().item()
n_above = above.sum().item()
below_vals = arr[below].tolist()
above_vals = arr[above].tolist()
below_bounds = q1[below].tolist()
above_bounds = q99[above].tolist()
prefix = (
f"Bounds violation ep={Path(self.episode_path).name} "
f"frame={idx} key={k}"
)
warn_key = f"bounds:{Path(self.episode_path).name}:{k}"
if warn_key not in self._warned_violations:
self._warned_violations.add(warn_key)
logger.warning(
f"{prefix} | "
f"n_below={int(n_below)} below_vals={[f'{v:.4f}' for v in below_vals[:5]]} below_q1={[f'{b:.4f}' for b in below_bounds[:5]]} "
f"n_above={int(n_above)} above_vals={[f'{v:.4f}' for v in above_vals[:5]]} above_q99={[f'{b:.4f}' for b in above_bounds[:5]]}"
)
return prefix

return None

def _get_fallback_idx(
self,
idx: int,
_fallback_origin: int | None,
_attempts: int | None,
log_prefix: str,
) -> tuple[int, int, int]:
"""
Compute next frame index for fallback when decode/transform fails.
Strategy: randomly sample a different index from the episode.
Returns (next_idx, origin, attempts). Raises RuntimeError if attempts exceed episode length.
"""
origin = _fallback_origin if _fallback_origin is not None else idx
attempts = (_attempts or 0) + 1
if attempts >= self.total_frames:
raise RuntimeError(
f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}"
)
candidates = list(range(0, idx)) + list(range(idx + 1, self.total_frames))
next_idx = random.choice(candidates)
logger.warning(
f"{log_prefix} | attempt {attempts}, trying random idx {next_idx}"
)
return (next_idx, origin, attempts)

def _get_image_fallback_idx(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does get_image_fallback_idx have diff logic from get_fallback_idx

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i thought we didnt want to random resample frames, but just do closest clean frame

self,
idx: int,
_fallback_origin: int | None,
_direction: str | None,
log_prefix: str,
) -> tuple[int, int, str]:
"""
Compute next frame index for fallback when decode/transform fails.
Compute next frame index for fallback when image decode fails.
Strategy: try left from origin (idx-5, idx-10, ...), then right (origin+5, ...).
Returns (next_idx, origin, direction). Raises RuntimeError if entire episode is bad.
"""
Expand All @@ -727,10 +861,18 @@ def _get_fallback_idx(
raise RuntimeError(
f"Entire episode bad (no valid indices): ep={Path(self.episode_path).name}"
)
logger.warning(f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}")
logger.warning(
f"{log_prefix} | left exhausted, trying right from origin idx {next_idx}"
)
return (next_idx, origin, "right")

def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction: str | None = None) -> dict[str, torch.Tensor]:
def __getitem__(
self,
idx: int,
_fallback_origin: int | None = None,
_attempts: int | None = None,
_direction: str | None = None,
) -> dict[str, torch.Tensor]:
# Build keys_dict with ranges based on whether action chunking is enabled
data = {}
for k in self.key_map:
Expand Down Expand Up @@ -759,11 +901,18 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction:
try:
decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB")
except Exception:
next_idx, origin, direction = self._get_fallback_idx(
idx, _fallback_origin, _direction,
next_idx, origin, direction = self._get_image_fallback_idx(
idx,
_fallback_origin,
_direction,
f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}",
)
result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction)
result = self.__getitem__(
next_idx,
_fallback_origin=origin,
_attempts=_attempts,
_direction=direction,
)
return result
data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0
elif zarr_key in self._json_keys:
Expand All @@ -780,13 +929,33 @@ def __getitem__(self, idx: int, _fallback_origin: int | None = None, _direction:
try:
data = transform.transform(data)
except Exception as e:
next_idx, origin, direction = self._get_fallback_idx(
idx, _fallback_origin, _direction,
next_idx, origin, attempts = self._get_fallback_idx(
idx,
_fallback_origin,
_attempts,
f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})",
)
result = self.__getitem__(next_idx, _fallback_origin=origin, _direction=direction)
result = self.__getitem__(
next_idx,
_fallback_origin=origin,
_attempts=attempts,
_direction=_direction,
)
return result

if self.norm_stats:
violation = self._check_bounds(data, idx)
if violation is not None:
next_idx, origin, attempts = self._get_fallback_idx(
idx, _fallback_origin, _attempts, violation
)
return self.__getitem__(
next_idx,
_fallback_origin=origin,
_attempts=attempts,
_direction=_direction,
)

for k, v in data.items():
if isinstance(v, np.ndarray):
data[k] = torch.from_numpy(v).to(torch.float32)
Expand All @@ -798,6 +967,7 @@ def get_item_keys(
idx: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Re: line +965]

This function isn't being used so all these changes can be deleted, right?

See this comment inline on Graphite.

keys,
_fallback_origin: int | None = None,
_attempts: int | None = None,
_direction: str | None = None,
) -> dict[str, torch.Tensor]:
requested = self._normalize_keys_arg(keys)
Expand All @@ -823,12 +993,21 @@ def get_item_keys(
val = raw[zarr_key]

if zarr_key in self._image_keys:

def _jpeg_fallback() -> dict[str, torch.Tensor]:
next_idx, origin, direction = self._get_fallback_idx(
idx, _fallback_origin, _direction,
next_idx, origin, direction = self._get_image_fallback_idx(
idx,
_fallback_origin,
_direction,
f"JPEG decode failed ep={Path(self.episode_path).name} frame={idx} key={k}",
)
result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction)
result = self.get_item_keys(
next_idx,
keys,
_fallback_origin=origin,
_attempts=_attempts,
_direction=direction,
)
return result

if (
Expand Down Expand Up @@ -858,13 +1037,35 @@ def _jpeg_fallback() -> dict[str, torch.Tensor]:
try:
out = transform.transform(out)
except Exception as e:
next_idx, origin, direction = self._get_fallback_idx(
idx, _fallback_origin, _direction,
next_idx, origin, attempts = self._get_fallback_idx(
idx,
_fallback_origin,
_attempts,
f"Transform failed ep={Path(self.episode_path).name} frame={idx} ({type(e).__name__}: {e})",
)
result = self.get_item_keys(next_idx, keys, _fallback_origin=origin, _direction=direction)
result = self.get_item_keys(
next_idx,
keys,
_fallback_origin=origin,
_attempts=attempts,
_direction=_direction,
)
return result

if self.norm_stats:
violation = self._check_bounds(out, idx)
if violation is not None:
next_idx, origin, attempts = self._get_fallback_idx(
idx, _fallback_origin, _attempts, violation
)
return self.get_item_keys(
next_idx,
keys,
_fallback_origin=origin,
_attempts=attempts,
_direction=_direction,
)

for k, v in out.items():
if isinstance(v, np.ndarray):
out[k] = torch.from_numpy(v).to(torch.float32)
Expand Down
Loading