import ray

from argparse import ArgumentParser
from functools import partial

import torch
from datasets import load_dataset
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.collectors.weight_update import RayRemoteWeightUpdater
from transformers import AutoTokenizer, AutoModel
from vllm import LLM

from vllm.utils import get_ip, get_open_port

from torchrl.collectors.distributed import RayCollector
from torchrl.envs import LLMEnv
from torchrl.modules import from_vllm

from torchrl.collectors.vllm_weight_update import (
    vLLMHFLocalWeightUpdater,
    vLLMRemoteWeightUpdaterBase,
    WorkerExtension,
)

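# A rough sketch of how the pieces below fit together:
#   * two Ray collector workers each run a vLLM copy of facebook/opt-125m inside an
#     LLMEnv built from a GSM8K dataloader;
#   * FSDP trainer actors "train" the HF model (here they simply zero its weights) and
#     send the full tensors to a dedicated parameter-server rank;
#   * the parameter server, together with the local weight updaters attached to each
#     collector, pushes the new weights into the running vLLM workers.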
parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="gsm8k")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--repeats", type=int, default=10)
parser.add_argument("--steps_per_batch", type=int, default=16)
parser.add_argument("--optim_batch_size", type=int, default=4)

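# Inference side: a vLLM engine whose workers use the WorkerExtension class (needed for
# the weight-update path), wrapped as a TorchRL policy with from_vllm. Greedy decoding
# (temperature=0.0) keeps the rollouts deterministic.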
def make_policy():
    inference_model = LLM(
        "facebook/opt-125m",
        enforce_eager=True,
        # change to worker_extension_cls when available in stable release
        worker_cls=WorkerExtension,
    )

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    policy = from_vllm(
        inference_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=True,
        return_log_probs=True,
        generate_kwargs={"temperature": 0.0},
    )
    return policy

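# Environment side: wrap a GSM8K dataloader in an LLMEnv. Each prompt is repeated
# `repeats` times, so the env batch is batch_size * repeats prompts wide.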
def make_env(dataset, batch_size):
    dataset = load_dataset(dataset, "main")
    train_dataset = dataset["train"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Env
    dataloader = DataLoader(  # noqa: TOR401
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    env = LLMEnv.from_dataloader(
        dataloader=dataloader,
        tokenizer=tokenizer,
        str2str=True,
        batch_size=(batch_size * args.repeats,),
        repeats=args.repeats,
    )
    return env

def collate_fn(batch):
    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
    batch.rename_key_("question", "text")
    return batch


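# Training side: one TrainerActor per FSDP rank. All ranks (plus the parameter server)
# join a single NCCL process group; the last rank is held back for the parameter server
# and the remaining ranks form the mesh over which the HF model is sharded.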
@ray.remote(num_cpus=1, num_gpus=1)
class TrainerActor:
    def __init__(self, model, env_vars):
        import os
        import torch
        import torch.distributed
        from torch.distributed._composable.fsdp import fully_shard

        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device("cuda:0"))
            print("initialized process group")

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print(world_size, rank)
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])

        # hold back one rank for the parameter server
        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))
        self.device_mesh = torch.distributed.device_mesh.DeviceMesh.from_group(self.fsdp_group, device_type="cuda")

        self.model = AutoModel.from_pretrained(model).cuda()

        fully_shard(self.model, mesh=self.device_mesh)

    def register_parameter_server(self, param_server):
        assert self.rank == 0
        self.param_server = param_server

    def send_weights_to_param_server(self):
        if self.rank == 0:
            ray.get(self.param_server.acquire_state_dict_lock.remote())
            self.param_server.receive_from_trainer.remote()
        for k, v in self.model.state_dict().items():
            # gather the sharded parameter into a full tensor before sending
            replicated_v = v.full_tensor()
            if self.rank == 0:
                # dst is the global rank; can switch to the group_dst arg if not on torch 2.5.1
                torch.distributed.send(replicated_v, dst=2)
        if self.rank == 0:
            ray.get(self.param_server.release_state_dict_lock.remote())

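    # zero_ stands in for a real optimization step: zeroing every weight makes it easy to
    # verify downstream that the update actually reached the vLLM workers (see
    # check_weights_changed on the parameter server).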
    def zero_(self):
        sd = self.model.state_dict()
        for k, v in sd.items():
            sd[k] = v.data.zero_()

    def train(self):
        import time
        for _ in range(1):
            # actually run train loop
            # ...
            self.zero_()
            torch.distributed.barrier(group=self.fsdp_group)
            self.send_weights_to_param_server()
            torch.distributed.barrier(group=self.fsdp_group)


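# Parameter server: the held-back rank. It joins the same process group as the trainers,
# receives the full (unsharded) weights from trainer rank 0 under a lock, and relies on
# vLLMRemoteWeightUpdaterBase to push them on to the vLLM workers.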
@ray.remote(num_cpus=1, num_gpus=1)
class vLLMParameterServer(vLLMRemoteWeightUpdaterBase):
    def __init__(self, model, vllm_master_address, vllm_master_port, env_vars):
        super().__init__(model, vllm_master_address, vllm_master_port)
        import os
        import torch
        import torch.distributed

        torch.cuda.set_device(torch.device("cuda", 0))

        for var in env_vars:
            os.environ[var] = str(env_vars[var])

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl", device_id=torch.device("cuda:0"))

        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank == self.world_size - 1

        self.fsdp_group = torch.distributed.new_group(ranks=list(range(self.world_size - 1)))

    def receive_from_trainer(self):
        for k, v in self.state_dict.items():
            torch.distributed.recv(v, src=0)

    def _skip_update(self, worker_id: int) -> bool:
        pass

    def check_weights_changed(self):
        """Check if the weights are updated to 0."""
        weights_updated = True
        for name, p in self.state_dict.items():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p)
            )
        return weights_updated


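# Rendezvous helper: trainer ranks and the parameter server share one MASTER_ADDR /
# MASTER_PORT so they land in the same process group; the parameter server always takes
# the last rank.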
def _create_trainer_group(
    worker_cls,
    param_server_cls,
    world_size: int,
    vllm_master_address,
    vllm_master_port,
    model,
):
    addr, port = get_ip(), get_open_port()
    trainer_workers = []
    fsdp_world_size = world_size - 1
    for i in range(fsdp_world_size):
        env_vars = {
            "RANK": str(i),
            "WORLD_SIZE": world_size,
            "MASTER_ADDR": str(addr),
            "MASTER_PORT": str(port),
        }
        worker = worker_cls.remote(model, env_vars)
        trainer_workers.append(worker)

    env_vars = {
        "RANK": str(world_size - 1),
        "WORLD_SIZE": world_size,
        "MASTER_ADDR": str(addr),
        "MASTER_PORT": str(port),
    }
    parameter_server = param_server_cls.remote(model, vllm_master_address, vllm_master_port, env_vars)
    trainer_workers[0].register_parameter_server.remote(parameter_server)
    return trainer_workers, parameter_server


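# Resource layout assumed below: 5 GPUs in total, split across two FSDP trainer ranks,
# one parameter server, and two vLLM inference workers (one per collector).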
if __name__ == "__main__":
    args = parser.parse_args()

    remote_configs = {
        "num_cpus": 1,
        "num_gpus": 1,
        "memory": 2 * 1024**3,
    }

    model = "facebook/opt-125m"

    ray.init(num_cpus=5, num_gpus=5)

    vllm_addresses = [get_ip()] * 2
    vllm_ports = [get_open_port() for _ in range(2)]
    print(vllm_ports)

    trainer_workers, parameter_server = _create_trainer_group(
        TrainerActor,
        vLLMParameterServer,
        3,
        vllm_addresses,
        vllm_ports,
        model,
    )

    handles = []
    for trainer_worker in trainer_workers:
        handles.append(trainer_worker.train.remote())

    model_metadata = ray.get(parameter_server.get_model_metadata.remote())
    local_weight_updaters = [
        vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata)
        for vllm_master_address, vllm_update_port in zip(vllm_addresses, vllm_ports)
    ]

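    # Two collector workers, each with its own env factory and local weight updater; the
    # parameter server acts as the remote weight updater, and update_after_each_batch
    # requests a weight sync after every collected batch.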
    make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset)
    collector = RayCollector(
        [make_env_parsed, make_env_parsed],
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        num_collectors=2,
        collector_kwargs=[
            {
                "local_weight_updater": local_weight_updaters[0],
            },
            {
                "local_weight_updater": local_weight_updaters[1],
            },
        ],
        update_after_each_batch=True,
    )
    print("done collector init")

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    for i, data in enumerate(collector):
        print(tokenizer.decode(data["tokens"][0].squeeze()))
        print(tokenizer.decode(data["tokens_response"][0].squeeze()))
        if i == 3:
            break
    collector.shutdown()