-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·122 lines (97 loc) · 3.67 KB
/
main.py
File metadata and controls
executable file
·122 lines (97 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Main entry for EVA.
The public phase-1 release is inference-first, but we keep the same
Hydra/experiment/algorithm/dataset structure that will later host
supervised training and RL post-training.
"""
from pathlib import Path
import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from utils.print_utils import cyan
from utils.distributed_utils import is_rank_zero
from utils.ckpt_utils import is_run_id
def run_local(cfg: DictConfig):
    """Run the configured experiment in the current process.

    Steps:
      1. Tag each selected config group with the name Hydra chose for it.
      2. On rank zero, print the output directory and refresh a
         ``latest-run`` symlink two levels above it.
      3. Resolve the checkpoint to load.
      4. Build the experiment and execute every task in
         ``cfg.experiment.tasks`` in order.

    Args:
        cfg: The composed Hydra config. ``cfg.experiment``, ``cfg.dataset``
            and ``cfg.algorithm`` receive a ``_name`` key when Hydra reports a
            non-null choice for that group. The optional ``resume`` / ``load``
            keys select the checkpoint:
              - a ``load`` value that is not a run id is passed through as a
                local checkpoint path;
              - ``resume`` (or a run-id ``load``) is mapped to
                ``outputs/downloaded/<entity>/<project>/<id>/model.ckpt``.
    """
    # delay some imports in case they are not needed in non-local envs for submission
    from experiments import build_experiment

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    # Record the chosen YAML name for each config group. ``.get`` tolerates a
    # group that is missing from Hydra's choices instead of raising KeyError;
    # a null choice is skipped just like before.
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
    with open_dict(cfg):
        for group in ("experiment", "dataset", "algorithm"):
            choice = cfg_choice.get(group)
            if choice is not None:
                cfg[group]._name = choice

    # Set up the output directory.
    output_dir = Path(hydra_cfg.runtime.output_dir)
    if is_rank_zero:
        print(cyan("Outputs will be saved to:"), output_dir)
        # Only rank zero replaces the ``latest-run`` link so it points at this run.
        latest_run = output_dir.parents[1] / "latest-run"
        latest_run.unlink(missing_ok=True)
        latest_run.symlink_to(output_dir, target_is_directory=True)

    # Resolve ckpt path.
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    load_is_run_id = bool(load) and is_run_id(load)
    # A non-run-id ``load`` is treated as a local checkpoint path.
    checkpoint_path = load if load and not load_is_run_id else None
    if resume:
        load_id = resume
    elif load_is_run_id:
        load_id = load
    else:
        load_id = None
    if load_id:
        run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
        checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt"

    # launch experiment
    experiment = build_experiment(cfg, output_dir, checkpoint_path)
    # for those who are searching, this is where we call tasks like 'training, validation, main'
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="config",
)
def run(cfg: DictConfig):
    """Hydra entry point: validate the composed config, then call ``run_local``.

    Defaults ``cfg.wandb.project`` to the name of the directory containing
    this file when it is ``None``.

    Raises:
        ValueError:
          - if ``name`` is not in the config;
          - if ``wandb.mode`` is ``"online"`` without a ``wandb.entity``;
          - if both ``resume`` and ``load`` are given.
        NotImplementedError:
          - if a wandb run id is given via ``resume`` or via a run-id ``load``;
          - if ``cluster`` is set.
          Neither is supported in the public phase-1 release.
    """
    if "name" not in cfg:
        raise ValueError(
            "must specify a name for the run with command line argument '+name=[name]'"
        )
    if cfg.wandb.mode == "online" and not cfg.wandb.get("entity", None):
        raise ValueError("wandb.entity is required when wandb.mode=online")
    if cfg.wandb.project is None:
        cfg.wandb.project = str(Path(__file__).parent.name)

    # Only local checkpoints are supported: any wandb run id (via ``resume`` or
    # a run-id ``load``) is rejected because downloading is not included here.
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    if resume and load:
        raise ValueError(
            "When resuming a wandb run with `resume=[wandb id]`, checkpoint will be loaded from the cloud "
            "and `load` should not be specified."
        )
    if resume or (load and is_run_id(load)):
        raise NotImplementedError(
            "wandb checkpoint download is not included in the public phase-1 EVA release. "
            "Please pass a local checkpoint path with load=/path/to/model.ckpt."
        )
    if cfg.get("cluster") is not None:
        raise NotImplementedError(
            "Cluster submission is not included in the public phase-1 EVA release."
        )
    run_local(cfg)
if __name__ == "__main__":
    # The ``hydra.main`` decorator parses command-line overrides and supplies ``cfg``.
    run()