# ✍🏻 Create a New Model


This guide shows you, step by step, how to plug a **new end-to-end policy model** into the InternManip framework. Follow the checklist below and you will be able to train your custom model with the stock training script (`scripts/train/train.py`), with no core code edits required.


## File Structure and Why

Currently, leading manipulation models strive to leverage existing pretrained large models for better generalization. For example, **GR00T-N1** and **Pi-0** typically consist of a pretrained VLM backbone and a compact downstream action expert that maps extracted context to the action space. Reflecting this design, InternManip organizes model files into three main components:

- **Backbone**: The pretrained VLM backbone responsible for understanding visual and textual inputs.
- **Action Head**: The downstream expert that consumes backbone features and predicts actions.
- **Policy Model**: The wrapper that integrates the backbone and action head into a single end-to-end policy.

Model definitions reside in the `internmanip/model` directory, which contains three sub-folders:

```text
internmanip
├── model
│   ├── action_head        # task-specific experts
│   ├── backbone           # pretrained encoders (ViT, CLIP, …)
│   └── basemodel          # full end-to-end policies
│       └── base.py        # <-- the universal policy interface
├── ...
└── configs
    └── model              # config classes (inherit PretrainedConfig)
scripts
└── train                  # trainers, entry points
```

## 1. Outline

To integrate a new model into the framework, you need to create the following files:

1. A **Config** class that stores architecture-related hyper-parameters.
2. A **Model** class that inherits `BasePolicyModel` and implements the model structure.
3. A **data collator** that shapes raw samples into model-ready tensors.

Finally, you need to **register** the model with the framework, after which you can start training. We will guide you through the process step by step.

## 2. Create the Model Configuration File

The config file stores the architecture-related hyper-parameters. Add it at `internmanip/configs/model/{model_name}_cfg.py`; the config class should inherit `transformers.PretrainedConfig`.

The following is **an example** of a model configuration file:

```python
from typing import List, Tuple

from transformers import PretrainedConfig

# NOTE: `Transform` below refers to the transform interface used by InternManip's
# dataset pipeline (see `internmanip/dataset/transform`); import it from your checkout.


class CustomPolicyConfig(PretrainedConfig):
    """Configuration for CustomPolicy."""
    model_type = "custom_model"

    def __init__(self,
                 vit_name="google/vit-base-patch16-224-in21k",
                 freeze_vit=True,
                 hidden_dim=256,
                 output_dim=8,
                 dropout=0.0,
                 n_obs_steps=1,
                 horizon=10,
                 **kwargs):
        super().__init__(**kwargs)
        self.vit_name = vit_name
        self.freeze_vit = freeze_vit
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.n_obs_steps = n_obs_steps
        self.horizon = horizon

    def transform(self) -> Tuple[List["Transform"], List[int], List[int]]:
        # No model-specific transforms in this minimal example; also return the
        # observation-step and action-horizon indices.
        transforms = None
        return transforms, list(range(self.n_obs_steps)), list(range(self.horizon))
```

As shown in the example above, the config class defines key architectural hyperparameters, such as the backbone model name, whether to freeze the backbone, and the hidden/output dimensions of the action head. You are free to extend this config with any additional parameters required by your custom model.

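Because the config inherits `transformers.PretrainedConfig`, it can be saved and reloaded with the standard helpers. A quick sketch (the checkpoint path and override values are illustrative):

```python
from internmanip.configs.model.custom_policy_cfg import CustomPolicyConfig

cfg = CustomPolicyConfig(hidden_dim=512, output_dim=7)       # override defaults as needed
cfg.save_pretrained("checkpoints/custom_policy")             # writes config.json
cfg = CustomPolicyConfig.from_pretrained("checkpoints/custom_policy")
print(cfg.hidden_dim, cfg.model_type)                        # 512 custom_model
```
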
Additionally, you can implement a **model-specific `transform` method** within the config class. This method lets you apply custom data transformations that are *not* included in the dataset-specific transform list defined in `internmanip/configs/dataset/data_config.py`.

During training, `scripts/train/train.py` will automatically call this method and apply your custom transforms alongside the default ones. Your `transform` method should follow the same input/output format as the dataset-specific transforms. For implementation guidance, refer to the examples in the `internmanip/dataset/transform` directory.

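For a concrete sense of what the hook returns, here is the example config exercised directly; the three return values are the extra transforms (`None` in the example) plus the observation-step and action-horizon index lists produced by the `transform` method shown above:

```python
from internmanip.configs.model.custom_policy_cfg import CustomPolicyConfig

cfg = CustomPolicyConfig(n_obs_steps=2, horizon=16)
transforms, obs_indices, action_indices = cfg.transform()
print(transforms)        # None -> no model-specific transforms in this example
print(obs_indices)       # [0, 1]
print(action_indices)    # [0, 1, ..., 15]
```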

## 3. Implement the Model

To implement the model, create a class that inherits `BasePolicyModel` and register it with `@BasePolicyModel.register("custom_model")`.

The model configuration is passed to the `__init__` method of the model class. Within `__init__`, you should define the model structure and initialize its submodules.

You should also implement the `forward` method, which defines the forward pass used during training; it should return a dictionary of tensors that includes the loss. The `inference` method is used to generate actions from the model at evaluation time.

```python
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel

from internmanip.configs.model.custom_policy_cfg import CustomPolicyConfig
from internmanip.model.basemodel.base import BasePolicyModel


@BasePolicyModel.register("custom_model")
class CustomPolicyModel(BasePolicyModel):
    """ViT backbone + 2-layer MLP head."""

    def __init__(self, config: CustomPolicyConfig):
        super().__init__(config)
        self.config = config

        # 1. Backbone: pretrained ViT visual encoder
        vit_conf = ViTConfig.from_pretrained(config.vit_name)
        self.vit = ViTModel.from_pretrained(config.vit_name, config=vit_conf)
        if config.freeze_vit:
            for p in self.vit.parameters():
                p.requires_grad = False

        # 2. Action head: two-layer MLP on top of the CLS token
        self.mlp = nn.Sequential(
            nn.Linear(vit_conf.hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.output_dim),
        )

    # ---- Training / Inference ----
    def forward(self, batch: Dict[str, torch.Tensor], noise=None, time=None) -> Dict[str, torch.Tensor]:
        imgs, tgt = batch["images"], batch.get("actions")
        feats = self.vit(imgs).last_hidden_state[:, 0]  # CLS token features
        pred = self.mlp(feats)
        out = {"prediction": pred}
        if self.training and tgt is not None:
            out["loss"] = F.mse_loss(pred, tgt.view_as(pred))
        return out

    def inference(self, batch: Dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
        actions = self.forward(batch, noise=None, time=None)["prediction"]
        return actions
```

In the example above, the model is composed of a ViT backbone and a simple 2-layer MLP action head. The `forward` method handles loss computation during training, while the `inference` method generates actions during evaluation.

When designing your own model, you can follow this backbone–head pattern or adopt a completely different architecture. If needed, you can define custom `backbone` and `action_head` modules, typically by subclassing `nn.Module`. Just ensure that your model's `inference` output has the shape `(n_actions, action_dim)`.

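To sanity-check the wiring before a full run, you can feed random tensors through the example model. A minimal sketch, assuming the class was saved as `internmanip/model/basemodel/custom_model.py` (adjust the import to your actual file location); note that constructing the model downloads the pretrained ViT weights:

```python
import torch

from internmanip.configs.model.custom_policy_cfg import CustomPolicyConfig
from internmanip.model.basemodel.custom_model import CustomPolicyModel  # hypothetical path

cfg = CustomPolicyConfig(output_dim=8)
model = CustomPolicyModel(cfg).eval()

# Dummy batch: two RGB frames at the ViT's expected 224x224 resolution.
batch = {"images": torch.randn(2, 3, 224, 224)}
with torch.no_grad():
    actions = model.inference(batch)
print(actions.shape)  # torch.Size([2, 8]) -> (n_actions, action_dim)
```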

## 4. Write a Data Collator

You need to define a `data_collator` function that converts a list of raw samples from the default data loader into a single batch dictionary compatible with the model's `forward` method.

```python
import torch

# `DataCollatorRegistry` is provided by InternManip; import it from the framework
# (the exact module path depends on your checkout) before registering the collator.

@DataCollatorRegistry.register("custom_model")
def custom_data_collator(samples):
    # Stack per-sample tensors into the batch keys that the model's forward expects.
    imgs = torch.stack([s["image"] for s in samples])
    acts = torch.stack([s["action"] for s in samples])
    return {"images": imgs, "actions": acts}
```

> **Why?** The built-in `BaseTrainer` accepts any callable named `data_collator` so long as it returns a dictionary of tensors compatible with your model's `forward` signature.

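Outside the framework, you can hand the same collator to a plain PyTorch `DataLoader` to check the batch layout. A small sketch, where `your_dataset` is a placeholder for any dataset whose samples carry `"image"` and `"action"` tensors:

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    your_dataset,                 # placeholder: samples must contain "image" and "action"
    batch_size=32,
    shuffle=True,
    collate_fn=custom_data_collator,
)
batch = next(iter(loader))
print(batch["images"].shape, batch["actions"].shape)
```
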
## 5. Register Everything

Add the following **one-time** registration lines (typically at the end of your model file) to enable seamless dynamic loading with `AutoConfig` and `AutoModel`:

```python
from transformers import AutoConfig, AutoModel

AutoConfig.register("custom_model", CustomPolicyConfig)
AutoModel.register(CustomPolicyConfig, CustomPolicyModel)
```
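
Once these registrations have executed (for instance via the `__init__.py` import described below), the generic auto classes can resolve your model from its `model_type` string. A short illustration; the stock trainer performs the equivalent lookup for you:

```python
from transformers import AutoConfig, AutoModel

cfg = AutoConfig.for_model("custom_model", hidden_dim=512)   # resolves to CustomPolicyConfig
model = AutoModel.from_config(cfg)                           # resolves to CustomPolicyModel
print(type(cfg).__name__, type(model).__name__)
```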

Make sure the string `"custom_model"` passed to `AutoConfig.register` matches the `model_type` declared in your config class, the name used in `@BasePolicyModel.register(...)`, and the name used in the data collator registration.

Don't forget to import the module in the relevant `__init__.py`, so that your custom model is imported and its registration code runs at startup. For example:

```python
# In internmanip/model/basemodel/__init__.py
from internmanip.model.basemodel.base import BasePolicyModel

__all__ = ["BasePolicyModel"]

# Import all model modules so that their registration logic is executed at import time.
from internmanip.model.basemodel.custom import custom_model  # <- your custom model module
```

Once registered, InternManip's trainer can instantiate your model, and you can start training.
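
For reference, an earlier revision of this guide launched training on the `genmanip-demo` dataset with a command along these lines (flag names may have changed since; see the documents linked below for the current interface):

```bash
# Single-process example; increase --nproc_per_node for multi-GPU training.
torchrun --nnodes 1 --nproc_per_node 1 \
    scripts/train/train.py \
    --model_name custom_model \
    --dataset-path genmanip-demo \
    --data-config genmanip-v1
```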

📚 For more details on training and evaluation, please refer to [train_eval.md](./train_eval.md) and [training.md](../tutorials/training.md).