import torch
import threading

from torchrl.collectors.weight_update import RemoteWeightUpdaterBase
from torchrl.collectors.weight_update import LocalWeightUpdaterBase


VLLM_ERR = None
try:
    import vllm
    from vllm.worker.worker import Worker

    _has_vllm = True
except ImportError as err:
    _has_vllm = False
    VLLM_ERR = err

# These utilities are copied from vLLM's example code.
def stateless_init_process_group(
    master_address: str,
    master_port: int,
    rank: int,
    world_size: int,
    device: torch.device,
):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without touching the global process group in torch.distributed.
    The recommended pattern is to create the `StatelessProcessGroup` first
    and then initialize the data-plane communication (NCCL) between the
    external (training) processes and the vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl
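

# Illustrative sketch (an assumption about the surrounding training script, not
# something used by the classes below): the trainer joins the stateless group as
# rank 0 and broadcasts each parameter to the vLLM tensor-parallel workers,
# which occupy ranks 1..tp_size. Address and port values are placeholders.
def _example_trainer_broadcast(
    weights, tp_size, master_address="localhost", master_port=29500
):
    group = stateless_init_process_group(
        master_address,
        master_port,
        rank=0,
        world_size=tp_size + 1,
        device=torch.device("cuda:0"),
    )
    for _name, weight in weights.items():
        # Every rank in the group must issue a matching broadcast call with a
        # tensor of the same shape and dtype.
        group.broadcast(weight, src=0, stream=torch.cuda.current_stream())
    torch.cuda.synchronize()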


if _has_vllm:
    # I should use the worker_extension_cls arg and not inherit from Worker,
    # but that is only available on main and not in vLLM 0.7.3.
    class WorkerExtension(Worker):
        """
        The class for vLLM's worker to inherit from.
        By defining an extension class, the code can work regardless of the
        underlying worker class, which keeps it compatible with both vLLM V0
        and V1.
        NOTE: we define this class in a separate module, and the main module
        should pass the fully qualified name as the `worker_extension_cls`
        argument.
        """

        def init_weight_update_group(self, master_address, master_port,
                                     rank_offset, world_size):
            from vllm.distributed.parallel_state import get_world_group

            # rank = get_world_group().rank + rank_offset
            rank = rank_offset
            self.model_update_group = stateless_init_process_group(
                master_address,
                master_port,
                rank,
                world_size,
                self.device,
            )
            self.version = torch.tensor([0], device="cuda")

        def update_weight(self, name, dtype, shape):
            # Receive the parameter broadcast from rank 0 (the trainer) and
            # load it into the model through vLLM's load_weights API.
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            self.model_update_group.broadcast(weight,
                                              src=0,
                                              stream=torch.cuda.current_stream())

            self.model_runner.model.load_weights(weights=[(name, weight)])

            del weight

        def update_policy_version(self):
            self.model_update_group.broadcast(self.version,
                                              src=0,
                                              stream=torch.cuda.current_stream())
            torch.cuda.synchronize()
            # print(f"{self=} {self.model_runner.model=}")
            self.policy_version = self.version

        def check_weights_changed(self):
            """
            Check whether the weights have been updated to zero.
            """
            weights_updated = True
            for name, p in self.model_runner.model.named_parameters():
                weights_updated = weights_updated and torch.allclose(
                    p, torch.zeros_like(p))
            return weights_updated
else:
    class WorkerExtension:
        pass
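

# How this class reaches vLLM (an assumption about the calling script, not
# enforced by this module): on vLLM 0.7.x one would point the engine at a
# Worker subclass like the one above, while newer releases accept the fully
# qualified name of a plain extension class via `worker_extension_cls`, e.g.
#
#   llm = vllm.LLM(model=..., worker_cls="<this_module>.WorkerExtension")
#
# The exact keyword depends on the vLLM version in use.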


class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
    """Weight updater that runs alongside a vLLM inference engine.

    It asks the vLLM workers (via ``collective_rpc``) to join the NCCL weight
    sync group and to receive the broadcast Hugging Face weights.
    """

    def __init__(self, master_address, master_port, model_metadata):
        print(f"{master_address=}, {master_port=}")
        self.master_address = master_address
        self.master_port = master_port
        self.model_metadata = model_metadata
        self.initialized_group = None

    def _get_server_weights(self):
        return None

    def _get_local_weights(self):
        # We don't implement this because we let vLLM's update_weights API
        # handle everything for now.
        return None

    def _maybe_map_weights(self, server_weights, local_weights):
        # vLLM's update_weights function handles the mapping from Hugging Face
        # names, so we don't implement this for now.
        return None

    def _update_local_weights(self, local_weights, mapped_weights):
        llm = self.collector.policy["generate"].module
        if self.initialized_group is None:
            weight_sync_world_size = (
                llm.llm_engine.parallel_config.tensor_parallel_size + 1
            )
            llm.collective_rpc(
                "init_weight_update_group",
                args=(self.master_address, self.master_port, 1, weight_sync_world_size),
            )
            self.initialized_group = True

        for k, (dtype, shape) in self.model_metadata.items():
            llm.collective_rpc(
                "update_weight",
                args=(k, dtype, shape),
            )

        llm.collective_rpc("update_policy_version")
        print("done local update_weight")
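

# Illustrative construction (names like `train_model` are placeholders): the
# `model_metadata` argument maps every parameter name to its (dtype, shape),
# typically derived from the training model's state_dict, e.g.
#
#   model_metadata = {
#       k: (v.dtype, v.shape) for k, v in train_model.state_dict().items()
#   }
#   local_updater = vLLMHFLocalWeightUpdater(
#       master_address="localhost",
#       master_port=29500,
#       model_metadata=model_metadata,
#   )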


class ReadWriteLock:
    """A lock object that allows many simultaneous "read locks", but
    only one "write lock"."""

    def __init__(self):
        self._read_ready = threading.Condition(threading.Lock())
        self._readers = 0

    def acquire_read(self):
        """Acquire a read lock. Blocks only if a thread has
        acquired the write lock."""
        self._read_ready.acquire()
        try:
            self._readers += 1
        finally:
            self._read_ready.release()

    def release_read(self):
        """Release a read lock."""
        self._read_ready.acquire()
        try:
            self._readers -= 1
            if not self._readers:
                self._read_ready.notify_all()
        finally:
            self._read_ready.release()

    def acquire_write(self):
        """Acquire a write lock. Blocks until there are no
        acquired read or write locks."""
        self._read_ready.acquire()
        while self._readers > 0:
            self._read_ready.wait()

    def release_write(self):
        """Release a write lock."""
        self._read_ready.release()
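

# Usage sketch for ReadWriteLock (illustrative): many readers may hold the lock
# at once, but a writer blocks until every reader has released it, and new
# readers block while the writer holds the underlying condition lock.
#
#   lock = ReadWriteLock()
#   lock.acquire_read()
#   try:
#       pass  # concurrent readers go here
#   finally:
#       lock.release_read()
#
#   lock.acquire_write()
#   try:
#       pass  # exclusive writer goes here
#   finally:
#       lock.release_write()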


class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
    """Trainer-side updater that keeps the latest policy weights on CUDA and
    broadcasts them to remote vLLM workers over per-worker NCCL groups."""

    def __init__(self, vllm_master_addresses, vllm_master_ports):
        super().__init__()
        from transformers import AutoModel

        self.vllm_master_addresses = vllm_master_addresses
        self.vllm_master_ports = vllm_master_ports
        # state_dict = dict()
        # for k, (dtype, shape) in model_metadata.items():
        #     self.state_dict[k] = torch.zeros(shape, dtype=dtype, device="cuda")
        # self.state_dict = state_dict()
        # self.state_dict_lock = ReadWriteLock()
        self.vllm_comm_groups = dict()
        self.vllm_weight_versions = dict()
        # self.version = -1

    def register_model_metadata(self, model_metadata):
        self.model_metadata = model_metadata
        self.state_dict = dict()
        for k, (dtype, shape) in model_metadata.items():
            self.state_dict[k] = torch.zeros(shape, dtype=dtype, device="cuda")
        self.state_dict_lock = ReadWriteLock()
        self.version = 0
        self.version_tensor = torch.tensor([0], device="cuda")

    def acquire_state_dict_lock(self):
        self.state_dict_lock.acquire_write()

    def release_state_dict_lock(self):
        self.version += 1
        self.version_tensor += 1
        torch.cuda.synchronize()
        self.state_dict_lock.release_write()

    def all_worker_ids(self):
        return list(range(len(self.collector._remote_collectors)))

    def _get_server_weights(self):
        return self.state_dict

    def _maybe_map_weights(self, server_weights):
        return server_weights

    def _skip_update(self, worker_id):
        if self.version == 0:
            return True
        if worker_id not in self.vllm_weight_versions:
            return False
        if self.vllm_weight_versions[worker_id] == self.version:
            print(
                f"skipping update for {worker_id=}, {self.version=}, "
                f"{self.vllm_weight_versions[worker_id]=}"
            )
            return True
        return False

    def _init_model_update_group(self, worker_id):
        # here again, I want to grab the tp size from the vLLM worker... :(
        # llm.llm_engine.parallel_config.tensor_parallel_size
        vllm_tp_size = 1
        weight_sync_world_size = vllm_tp_size + 1
        print("before stateless_init_process_group")
        model_update_group = stateless_init_process_group(
            self.vllm_master_addresses[worker_id],
            self.vllm_master_ports[worker_id],
            0,
            weight_sync_world_size,
            torch.device("cuda:0"),
        )
        print("after stateless_init_process_group")
        self.vllm_comm_groups[worker_id] = model_update_group

    def _sync_weights_with_worker(self, worker_id: int, server_weights):
        print(f"in _sync_weights_with_worker {worker_id}", flush=True)
        self.collector._remote_collectors[worker_id].update_policy_weights_.remote()
        if worker_id not in self.vllm_comm_groups:
            print("init model update group")
            self._init_model_update_group(worker_id)
            print("done init model update group")
        self.state_dict_lock.acquire_read()
        for i, k in enumerate(server_weights.keys()):
            # if i == 0:
            #     print(f"{server_weights[k][0]=}")
            self.vllm_comm_groups[worker_id].broadcast(
                server_weights[k], src=0, stream=torch.cuda.current_stream()
            )
        self.vllm_comm_groups[worker_id].broadcast(
            self.version_tensor, src=0, stream=torch.cuda.current_stream()
        )
        torch.cuda.synchronize()
        print(f"_sync_weights_with_worker done broadcast {worker_id} {self.version=}")
        self.vllm_weight_versions[worker_id] = self.version
        self.state_dict_lock.release_read()
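

# End-to-end sketch (illustrative; how the updater is attached to a collector is
# an assumption about the surrounding training script, not part of this module):
# after each optimizer step, copy the trained weights into the shared state_dict
# under the write lock; the collector then calls _sync_weights_with_worker for
# each remote vLLM worker, broadcasting the weights and the new version under a
# read lock.
#
#   remote_updater = vLLMRemoteWeightUpdaterBase(
#       vllm_master_addresses=["localhost"],
#       vllm_master_ports=[29500],
#   )
#   remote_updater.register_model_metadata(model_metadata)
#   ...
#   remote_updater.acquire_state_dict_lock()
#   try:
#       # `policy_model` is a placeholder for the trained Hugging Face model
#       for k, v in policy_model.state_dict().items():
#           remote_updater.state_dict[k].copy_(v)
#   finally:
#       remote_updater.release_state_dict_lock()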