import ray

from argparse import ArgumentParser
from functools import partial

import torch
from datasets import load_dataset
from tensordict import TensorDict
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from vllm import LLM
from vllm.utils import get_ip, get_open_port

from torchrl.collectors.distributed import RayCollector
from torchrl.collectors.vllm_weight_update import (
    vLLMHFLocalWeightUpdater,
    vLLMRemoteWeightUpdaterBase,
    WorkerExtension,
)
from torchrl.envs import LLMEnv
from torchrl.modules import from_vllm

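# Layout (one GPU per Ray actor):
#   * TrainerActor ranks 0..world_size-2 hold the FSDP-sharded training model.
#   * vLLMParameterServer (the last rank) receives full weights from trainer
#     rank 0 and relays them to the vLLM inference workers.
#   * RayCollector runs the vLLM policy remotely and refreshes its weights
#     after each collected batch.
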
parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)


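# Build the inference side: a vLLM engine wrapped as a TorchRL policy module.
# The WorkerExtension hook is what lets the parameter server push new weights
# into the running vLLM workers.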
def make_policy():
    inference_model = LLM(
        "facebook/opt-125m",
        enforce_eager=True,
        # change to worker_extension_cls once it is available in a stable vLLM release
        worker_cls=WorkerExtension,
    )

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    policy = from_vllm(
        inference_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=True,
        return_log_probs=True,
        generate_kwargs={"temperature": 0.0},
    )
    return policy


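# Build the data environment: GSM8K prompts flow through an LLMEnv that
# tokenizes them and repeats each prompt `repeats` times.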
def make_env(dataset, batch_size, repeats):
    dataset = load_dataset(dataset, "main")
    train_dataset = dataset["train"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(batch_size * repeats,),
        repeats=repeats,
    )
    return env


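# Stack raw dataset rows into a TensorDict and rename "question" to the
# "text" key that LLMEnv expects.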
def collate_fn(batch):
    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
    batch.rename_key_("question", "text")
    return batch


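# One FSDP training rank. The torch.distributed world spans all trainer ranks
# plus the parameter server; the FSDP mesh itself excludes the last rank,
# which is held back for the parameter server.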
@ray.remote(num_cpus=1, num_gpus=1)
class TrainerActor:
    def __init__(self, model, env_vars):
        import os

        import torch
        import torch.distributed
        from torch.distributed._composable.fsdp import fully_shard

        # each Ray actor sees exactly one GPU
        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl", device_id=torch.device("cuda:0")
            )
            print("initialized process group")

        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        print(f"trainer rank {self.rank} / world size {self.world_size}")

        # hold back one rank for the parameter server
        self.fsdp_group = torch.distributed.new_group(
            ranks=list(range(self.world_size - 1))
        )
        self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(
            self.fsdp_group, device_type="cuda"
        )

        self.model = AutoModel.from_pretrained(model).cuda()

        fully_shard(self.model, mesh=self.device_mesh)

    def register_parameter_server(self, param_server):
        assert self.rank == 0
        self.param_server = param_server

    def send_weights_to_param_server(self):
        # every FSDP rank participates in full_tensor() (an all-gather), but
        # only rank 0 talks to the parameter server
        if self.rank == 0:
            ray.get(self.param_server.acquire_state_dict_lock.remote())
            self.param_server.receive_from_trainer.remote()
        for k, v in self.model.state_dict().items():
            # gather the DTensor shards into a full tensor on every rank
            replicated_v = v.full_tensor()
            if self.rank == 0:
                # dst is a global rank (the parameter server holds the last one);
                # newer torch releases also accept a group_dst argument
                torch.distributed.send(replicated_v, dst=self.world_size - 1)
        if self.rank == 0:
            ray.get(self.param_server.release_state_dict_lock.remote())

    def zero_(self):
        # zero every parameter so that check_weights_changed can later verify
        # the update reached the inference side
        sd = self.model.state_dict()
        for k, v in sd.items():
            sd[k] = v.data.zero_()

    def train(self):
        for _ in range(1):
            # a real optimization loop would go here; zeroing the weights
            # stands in for a training step
            self.zero_()
            torch.distributed.barrier(group=self.fsdp_group)
            self.send_weights_to_param_server()
            torch.distributed.barrier(group=self.fsdp_group)


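# The parameter server owns the last rank of the trainer process group. It
# receives full weights from trainer rank 0 and hands them to the vLLM
# workers through the vLLMRemoteWeightUpdaterBase machinery.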
@ray.remote(num_cpus=1, num_gpus=1)
class vLLMParameterServer(vLLMRemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port, env_vars):
        super().__init__(model, vllm_master_address, vllm_master_port)
        import os

        import torch
        import torch.distributed

        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl", device_id=torch.device("cuda:0")
            )

        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank == self.world_size - 1

        # new_group is collective, so this rank must join the call even though
        # it is excluded from the FSDP group itself
        self.fsdp_group = torch.distributed.new_group(
            ranks=list(range(self.world_size - 1))
        )

    def receive_from_trainer(self):
        # matches the torch.distributed.send calls issued by trainer rank 0
        for k, v in self.state_dict.items():
            torch.distributed.recv(v, src=0)

    def _skip_update(self, worker_id: int) -> bool:
        # no-op stub: never skip, every vLLM worker gets the fresh weights
        pass

    def check_weights_changed(self):
        """Sanity check: return True once every parameter has been zeroed."""
        weights_updated = True
        for name, p in self.state_dict.items():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p)
            )
        return weights_updated


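# Spawn world_size - 1 trainer actors plus one parameter server, all joining a
# single torch.distributed process group rendezvoused at (addr, port).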
def _create_trainer_group(
    worker_cls,
    param_server_cls,
    world_size: int,
    vllm_master_address,
    vllm_master_port,
    model,
):
    addr, port = get_ip(), get_open_port()
    trainer_workers = []
    fsdp_world_size = world_size - 1
    for i in range(fsdp_world_size):
        env_vars = {
            "RANK": str(i),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": str(addr),
            "MASTER_PORT": str(port),
        }
        worker = worker_cls.remote(model, env_vars)
        trainer_workers.append(worker)

    env_vars = {
        "RANK": str(world_size - 1),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": str(addr),
        "MASTER_PORT": str(port),
    }
    parameter_server = param_server_cls.remote(
        model, vllm_master_address, vllm_master_port, env_vars
    )
    # actor tasks execute in submission order, so registration completes
    # before the train() call issued later in __main__
    trainer_workers[0].register_parameter_server.remote(parameter_server)
    return trainer_workers, parameter_server


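# End-to-end flow: start the trainer group, launch training, then build a
# RayCollector whose vLLM policy is refreshed from the parameter server.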
if __name__ == "__main__":
    args = parser.parse_args()

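    # resources for the vLLM inference actor that RayCollector will launch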
    remote_configs = {
        "num_cpus": 1,
        "num_gpus": 1,
        "memory": 2 * 1024**3,
    }

    model = "facebook/opt-125m"

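    # 4 GPUs in total: 2 FSDP trainer ranks + 1 parameter server + 1 vLLM worker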
    ray.init(num_cpus=4, num_gpus=4)

    vllm_master_address, vllm_update_port = get_ip(), get_open_port()

    trainer_workers, parameter_server = _create_trainer_group(
        TrainerActor,
        vLLMParameterServer,
        3,  # world_size: 2 FSDP ranks + 1 parameter server
        vllm_master_address,
        vllm_update_port,
        model,
    )

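    # kick off training on every FSDP rank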
    handles = []
    for trainer_worker in trainer_workers:
        handles.append(trainer_worker.train.remote())

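    # the parameter server describes the model so the collector-side updater
    # knows which weights to expect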
    model_metadata = ray.get(parameter_server.get_model_metadata.remote())
    local_weight_updater = vLLMHFLocalWeightUpdater(
        vllm_master_address, vllm_update_port, model_metadata
    )

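    # update_after_each_batch=True makes the collector pull fresh weights from
    # the parameter server between collected batches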
    make_env_parsed = partial(
        make_env, dataset=args.dataset, batch_size=args.batch_size, repeats=args.repeats
    )
    collector = RayCollector(
        [make_env_parsed],
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        collector_kwargs={
            "local_weight_updater": local_weight_updater,
        },
        update_after_each_batch=True,
    )
    print("done collector init")

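    # sanity check: decode a couple of prompt/response pairs, then shut down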
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    for i, data in enumerate(collector):
        print(tokenizer.decode(data["tokens"][0].squeeze()))
        print(tokenizer.decode(data["tokens_response"][0].squeeze()))
        if i == 1:
            break
    collector.shutdown()