# Add ModalityTransform Methods in DataConfigs

## Description

Add `transform()` methods for `ModalityTransform` in the DataConfigs. While applying no transforms is an option, adding these methods is needed for subsequent training workflows.

## References

## Current State (strands-robots)

Our `BaseDataConfig` in `strands_robots/policies/groot/data_config.py` only has:

- `modality_config()` method ✅

Missing:

- `transform()` method ❌
## NVIDIA Implementation Analysis

### 1. Base Transform Architecture

NVIDIA has an abstract `ModalityTransform` class in `gr00t/data/transform/base.py`:

```python
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel


class ModalityTransform(BaseModel, ABC):
    """Abstract class for transforming data modalities."""

    apply_to: list[str]     # Keys to apply the transform to
    training: bool = True   # Training vs. eval mode

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation to data."""

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


class InvertibleModalityTransform(ModalityTransform):
    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse the transformation."""


class ComposedModalityTransform(ModalityTransform):
    """Compose multiple transforms into a pipeline."""

    transforms: list[ModalityTransform]
```
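To show how these pieces fit together, here is a minimal runnable sketch of the same interface. It uses `dataclasses` instead of pydantic to stay dependency-free, and `ScaleTransform` is a toy transform invented for illustration; the composed transform's apply-in-order and train/eval propagation are assumptions about how NVIDIA's `ComposedModalityTransform` behaves.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ModalityTransform(ABC):
    """Dependency-free stand-in for the pydantic-based base class."""

    apply_to: list[str]
    training: bool = True

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]: ...

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


@dataclass
class ScaleTransform(ModalityTransform):
    """Toy transform: multiplies the selected keys by a constant."""

    scale: float = 2.0

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for key in self.apply_to:
            data[key] = data[key] * self.scale
        return data


@dataclass
class ComposedModalityTransform(ModalityTransform):
    """Runs child transforms in order; propagates train/eval mode."""

    transforms: list[ModalityTransform] = field(default_factory=list)

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for t in self.transforms:
            data = t.apply(data)
        return data

    def train(self):
        super().train()
        for t in self.transforms:
            t.train()

    def eval(self):
        super().eval()
        for t in self.transforms:
            t.eval()


pipeline = ComposedModalityTransform(
    apply_to=[],
    transforms=[
        ScaleTransform(apply_to=["state.pos"], scale=2.0),
        ScaleTransform(apply_to=["state.pos"], scale=0.5),
    ],
)
out = pipeline.apply({"state.pos": 3.0})
print(out["state.pos"])  # 3.0 (scaled up, then back down)
pipeline.eval()
print(pipeline.transforms[0].training)  # False
```

Propagating `eval()` through the composition is what lets a config-level pipeline disable augmentations like color jitter in one call at inference time.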
### 2. Transform Pipeline Structure

NVIDIA's `BaseDataConfig.transform()` returns a `ComposedModalityTransform` with this pipeline:

```python
def transform(self) -> ModalityTransform:
    transforms = [
        # 1. VIDEO TRANSFORMS
        VideoToTensor(apply_to=self.video_keys),
        VideoCrop(apply_to=self.video_keys, scale=0.95),
        VideoResize(apply_to=self.video_keys, height=224, width=224),
        VideoColorJitter(apply_to=self.video_keys, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08),
        VideoToNumpy(apply_to=self.video_keys),
        # 2. STATE TRANSFORMS
        StateActionToTensor(apply_to=self.state_keys),
        StateActionTransform(apply_to=self.state_keys, normalization_modes={...}),
        # 3. ACTION TRANSFORMS
        StateActionToTensor(apply_to=self.action_keys),
        StateActionTransform(apply_to=self.action_keys, normalization_modes={...}),
        # 4. CONCAT TRANSFORM
        ConcatTransform(video_concat_order=..., state_concat_order=..., action_concat_order=...),
        # 5. MODEL-SPECIFIC TRANSFORM
        GR00TTransform(state_horizon=..., action_horizon=..., max_state_dim=64, max_action_dim=32),
    ]
    return ComposedModalityTransform(transforms=transforms)
```
### 3. Key Transform Classes Needed

#### Video Transforms (`video.py`)

- `VideoToTensor`: numpy `[T, H, W, C]` uint8 → torch `[T, C, H, W]` float32
- `VideoCrop`: random/center crop with a scale factor
- `VideoResize`: resize to target resolution (224x224)
- `VideoColorJitter`: augmentation (brightness, contrast, saturation, hue)
- `VideoToNumpy`: torch → numpy for inference
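The layout change in `VideoToTensor` can be sketched concretely. This version uses numpy only to stay dependency-light (the real transform would wrap the result with `torch.from_numpy`), and the `/255` scaling to `[0, 1]` is an assumption, not something the source confirms:

```python
import numpy as np


def video_to_tensor(frames: np.ndarray) -> np.ndarray:
    """Convert [T, H, W, C] uint8 frames to [T, C, H, W] float32 in [0, 1]."""
    assert frames.dtype == np.uint8 and frames.ndim == 4
    chw = frames.transpose(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
    return chw.astype(np.float32) / 255.0


frames = np.zeros((8, 480, 640, 3), dtype=np.uint8)  # 8 RGB frames
out = video_to_tensor(frames)
print(out.shape, out.dtype)  # (8, 3, 480, 640) float32
```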
#### State/Action Transforms (`state_action.py`)

- `StateActionToTensor`: numpy → torch
- `StateActionTransform`: normalization (min_max, mean_std, q99, binary) plus rotation conversion
- `StateActionSinCosTransform`: sin-cos encoding for joint angles
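As a worked example of one normalization mode, here is `min_max` with its inverse, which is what an invertible transform's `unapply()` needs to recover raw action values. Mapping to `[-1, 1]` is the common convention but an assumption here:

```python
def min_max_normalize(x: float, lo: float, hi: float) -> float:
    """Map a value from [lo, hi] to [-1, 1] (the min_max mode listed above)."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0


def min_max_unnormalize(x: float, lo: float, hi: float) -> float:
    """Inverse mapping, used when splitting model output back to raw units."""
    return (x + 1.0) * (hi - lo) / 2.0 + lo


y = min_max_normalize(5.0, 0.0, 10.0)
print(y)  # 0.0
print(min_max_unnormalize(y, 0.0, 10.0))  # 5.0
```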
#### Concat Transform (`concat.py`)

- Concatenates video views along a new axis
- Concatenates state/action keys into single tensors
- Tracks dimensions for `unapply` (splitting back)
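The dimension-tracking idea can be sketched with plain lists (the real `ConcatTransform` operates on tensors; function names here are illustrative):

```python
def concat_keys(data: dict[str, list[float]], order: list[str]):
    """Concatenate per-key vectors into one flat vector, recording split sizes."""
    sizes = {key: len(data[key]) for key in order}
    flat = [v for key in order for v in data[key]]
    return flat, sizes


def split_keys(flat: list[float], order: list[str], sizes: dict[str, int]):
    """Inverse of concat_keys, as unapply() needs for splitting back."""
    out, i = {}, 0
    for key in order:
        out[key] = flat[i:i + sizes[key]]
        i += sizes[key]
    return out


state = {"state.arm": [0.1, 0.2], "state.gripper": [1.0]}
flat, sizes = concat_keys(state, ["state.arm", "state.gripper"])
print(flat)  # [0.1, 0.2, 1.0]
print(split_keys(flat, ["state.arm", "state.gripper"], sizes) == state)  # True
```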
#### Model Transform (`transforms.py`)

- `GR00TTransform`: pads state/action to max dims, applies VLM processing
## Implementation Plan

### Phase 1: Base Transform Infrastructure

Create a `strands_robots/policies/groot/transforms/` directory:

```
strands_robots/policies/groot/transforms/
├── __init__.py
├── base.py          # ModalityTransform, InvertibleModalityTransform, ComposedModalityTransform
├── video.py         # VideoToTensor, VideoCrop, VideoResize, VideoColorJitter, VideoToNumpy
├── state_action.py  # StateActionToTensor, StateActionTransform, StateActionSinCosTransform
└── concat.py        # ConcatTransform
```
### Phase 2: Update BaseDataConfig

Add an abstract `transform()` method to `BaseDataConfig`:

```python
@dataclass
class BaseDataConfig(ABC):
    # ... existing fields ...

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for this data config."""
        pass
```
### Phase 3: Implement Transforms for Each Config

Update each concrete config (`So100DataConfig`, `FourierGr1ArmsOnlyDataConfig`, etc.) with its specific transform pipeline.
### Phase 4: Integration with Policy

Update `Gr00tPolicy` to optionally use transforms for training workflows.
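One possible shape for that integration, as a hypothetical sketch: every name except `transform()`/`apply()`/`eval()` is illustrative, not the actual `Gr00tPolicy` API.

```python
class Gr00tPolicySketch:
    """Hypothetical sketch of optional transform support in a policy."""

    def __init__(self, data_config, model=None):
        self.model = model
        # Fall back to None when the config predates transform() support.
        transform_fn = getattr(data_config, "transform", lambda: None)
        self.transform = transform_fn()

    def prepare_observation(self, obs: dict) -> dict:
        if self.transform is not None:
            self.transform.eval()  # disable augmentations at inference time
            obs = self.transform.apply(obs)
        return obs


class _IdentityTransform:
    """Stand-in exposing the ModalityTransform interface used above."""

    def eval(self):
        self.training = False

    def apply(self, data):
        return data


class _DummyConfig:
    def transform(self):
        return _IdentityTransform()


policy = Gr00tPolicySketch(_DummyConfig())
print(policy.prepare_observation({"state.pos": [0.0]}))  # {'state.pos': [0.0]}
```

Keeping the transform optional means existing inference-only callers are unaffected until a config actually provides a pipeline.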
## Simplified Initial Implementation (Inference-Only)

For inference-only use cases, we can start with a minimal implementation:

```python
class IdentityTransform(ModalityTransform):
    """No-op transform for inference."""

    def apply(self, data):
        return data


class BaseDataConfig:
    def transform(self) -> ModalityTransform:
        """Default: no transforms (identity)."""
        return IdentityTransform(apply_to=[])
```
## Full Implementation Priority

| Transform | Priority | Reason |
| --- | --- | --- |
| Base classes | High | Foundation for all transforms |
| VideoToTensor/ToNumpy | High | Basic type conversion |
| StateActionToTensor | High | Basic type conversion |
| VideoCrop/Resize | Medium | Required for proper image preprocessing |
| StateActionTransform (normalization) | Medium | Required for training |
| VideoColorJitter | Low | Augmentation (training only) |
| ConcatTransform | Medium | Required for multi-modal data |
| GR00TTransform | Low | Complex, depends on Eagle VLM |
## Acceptance Criteria

- [ ] `ModalityTransform` base classes implemented
- [ ] `transform()` method added to `BaseDataConfig`
- [ ] Concrete data configs provide a `transform()` method

### Code Changes Required
#### File: `strands_robots/policies/groot/data_config.py`

**Before:**

```python
@dataclass
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...

    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...
```

**After:**

```python
from .transforms import ModalityTransform, ComposedModalityTransform


@dataclass
class BaseDataConfig(ABC):
    video_keys: List[str]
    # ... other fields ...

    def modality_config(self) -> Dict[str, ModalityConfig]:
        # ... existing implementation ...

    @abstractmethod
    def transform(self) -> ModalityTransform:
        """Return the transform pipeline for training/inference."""
        pass
```
## Priority

Medium - Important for training workflows