@@ -10,14 +10,17 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
+
+from torchrl.collectors import SyncDataCollector
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
+from torchrl.envs import LLMEnv
 from torchrl.modules import (
     from_hf_transformers,
     from_vllm,
@@ -42,10 +45,10 @@
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
-    from pytorch.rl.test.mocking_classes import NestedCountingEnv
+    from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
 else:
     from _utils_internal import get_default_devices
-    from mocking_classes import NestedCountingEnv
+    from mocking_classes import DummyStrDataLoader, NestedCountingEnv
 
 _has_vllm = importlib.util.find_spec("vllm") is not None
 
@@ -922,6 +925,18 @@ def test_lmhead_actorvalueoperator(device):
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
 @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
 class TestLLMActor:
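+    # Module-scoped fixture: building a vLLM engine is expensive, so the small
+    # "gpt2" model is loaded once and shared by every test in this class.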
+    @pytest.fixture(scope="module")
+    def vllm_instance(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("gpt2")
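+        # GPT-2 ships without a pad token; reuse the EOS token so that padded
+        # batches can be built.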
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
@@ -1005,12 +1020,17 @@ def test_from_hf_transformers(
         ],
     )
     def test_from_vllm(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        vllm_instance,
     ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m = from_vllm(
             model,
             from_text=from_text,
@@ -1122,6 +1142,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,28 +1188,43 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(
+        self, from_text, tokens, attention_mask, pad_output, vllm_instance
+    ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
-            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+            m_generate,
+            m_logprobs,
+            tokens,
+            attention_mask,
+            from_text,
+            has_logits=False,
+            tol=1e-1,
         )
 
     def _check_lps(
@@ -1198,6 +1235,7 @@ def _check_lps(
         attention_mask,
         from_text,
         has_logits,
+        tol=1e-2,
     ):
         # Checks that the log-probs gathered with generate=False equate those with generate=True
         tdin_genetate = self._make_data(
@@ -1218,8 +1256,114 @@ def _check_lps(
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
-            td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
+        )
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            vllm_instance,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
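+        # Stack two prompts of different lengths lazily: unbinding the
+        # TensorDict and restacking it keeps the ragged entries unpadded.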
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
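+        # With pad_output=False the entries are ragged, so a plain `get` may
+        # raise; fall back to fetching the values as a list.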
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        # TODO: this is not ideal...
+        if use_tensorclass:
+            assert isinstance(text, list)
+        else:
+            assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        if use_tensorclass:
+            assert isinstance(text_response, list)
+        else:
+            assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
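+        # Prompt tokens are only returned when generating (see the note in
+        # _run_check above).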
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
+    def test_vllm_collection(self, vllm_instance):
+        policy = from_vllm(
+            vllm_instance,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+            pad_output=False,
+            generate_kwargs={"max_tokens": 10},
+        )
+        self._run_check_collector(policy)
+
+    def test_transformers_collection(self):
+        ...
+
+    @classmethod
+    def env_constructor(cls):
+        dl = DummyStrDataLoader(batch_size=32)
+        env = LLMEnv.from_dataloader(
+            dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
+        )
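+        # batch_size=16 with repeats=4 (grouped) gives 64 batched envs.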
+        assert env.batch_size == (64,)
+        return env
+
+    def _run_check_collector(self, policy):
+        collector = SyncDataCollector(
+            self.env_constructor,
+            policy=policy,
+            frames_per_batch=128,
+            total_frames=512,
+            use_buffers=False,
         )
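+        # Text entries are ragged (pad_output=False), hence use_buffers=False
+        # above: each batch comes back as a lazy stack rather than a
+        # preallocated contiguous tensordict.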
+        for data in collector:
+            assert isinstance(data, LazyStackedTensorDict)
+            assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
 
 
 if __name__ == "__main__":