forked from weipeilun/Nested-Learning-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_hope.py
More file actions
35 lines (26 loc) · 936 Bytes
/
train_hope.py
File metadata and controls
35 lines (26 loc) · 936 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import annotations
import hydra
import torch
from omegaconf import DictConfig
from nested_learning_pytorch.training import build_model_from_cfg, unwrap_config
@hydra.main(config_path="configs", config_name="hope_tiny", version_base=None)
def main(cfg: DictConfig) -> None:
    """Smoke-train the HOPE model for a fixed number of steps.

    Builds the model from the `hope_tiny` Hydra config, moves it to the
    configured device, and runs 10 inner-loop forward passes, each followed
    by an outer update, on a single random (1, 10240, 64) batch.
    """
    cfg = unwrap_config(cfg)
    device = cfg.train.device

    model = build_model_from_cfg(cfg.model).to(device)
    model.is_training = True

    num_steps = 10
    # One fixed random batch reused every step: (batch=1, seq=10240, dim=64).
    inputs = torch.randn(1, 10240, 64).to(device)
    targets = torch.randn(1, 10240, 64).to(device)

    state = None
    for step in range(num_steps):
        print(f"step {step}")
        grads_dict, state = model.forward_inner_loop(x=inputs, y=targets, state=state)
        model.outer_update(grads_dict=grads_dict)
        # Drop references so the tensors can be freed between steps.
        # NOTE(review): clearing `state` here means every step starts
        # stateless, so the state=state threading above never carries
        # anything over — confirm this is intended.
        grads_dict = None
        state = None
# Script entry point: Hydra constructs `cfg` (see the @hydra.main decorator)
# and invokes main(); CLI arguments are interpreted as config overrides.
if __name__ == "__main__":
    main()