@@ -10,14 +10,17 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
+
+from torchrl.collectors import SyncDataCollector
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
+from torchrl.envs import LLMEnv
 from torchrl.modules import (
     from_hf_transformers,
     from_vllm,
@@ -42,10 +45,10 @@
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
-    from pytorch.rl.test.mocking_classes import NestedCountingEnv
+    from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
 else:
     from _utils_internal import get_default_devices
-    from mocking_classes import NestedCountingEnv
+    from mocking_classes import DummyStrDataLoader, NestedCountingEnv
 
 _has_vllm = importlib.util.find_spec("vllm") is not None
 
@@ -922,6 +925,18 @@ def test_lmhead_actorvalueoperator(device):
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
 @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
 class TestLLMActor:
+    @pytest.fixture(scope="module")
+    def vllm_instance(self):
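+        # Module-scoped fixture: build the vLLM engine for GPT-2 once and share it
+        # across the tests below; skip them entirely when vllm is not installed.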
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
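+        # GPT-2 ships without a pad token; reuse the EOS token so padded batches can be built.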
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
@@ -1005,12 +1020,17 @@ def test_from_hf_transformers(
         ],
     )
     def test_from_vllm(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        vllm_instance,
     ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m = from_vllm(
             model,
             from_text=from_text,
@@ -1122,6 +1142,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
         # The convention is that the response only has new tokens
         assert (
@@ -1166,28 +1188,43 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(
+        self, from_text, tokens, attention_mask, pad_output, vllm_instance
+    ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
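+        # A second wrapper over the same engine, configured to score the provided
+        # responses (generate=False) instead of sampling new ones.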
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
-            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+            m_generate,
+            m_logprobs,
+            tokens,
+            attention_mask,
+            from_text,
+            has_logits=False,
+            tol=1e-1,
         )
 
     def _check_lps(
@@ -1198,6 +1235,7 @@ def _check_lps(
         attention_mask,
         from_text,
         has_logits,
+        tol=1e-2,
     ):
         # Checks that the log-probs gathered with generate=False equate those with generate=True
         tdin_genetate = self._make_data(
@@ -1218,8 +1256,114 @@ def _check_lps(
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
-            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+            td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
+        )
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            vllm_instance,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
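+        # A lazy stack of two string prompts of different lengths; in log-probs mode
+        # (generate=False) pre-written responses are supplied so the policy only scores them.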
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
+ try :
1296
+ log_probs = output .get ("log_probs" )
1297
+ except Exception :
1298
+ log_probs = output .get ("log_probs" , as_list = True )
1299
+ if pad :
1300
+ assert isinstance (log_probs , torch .Tensor )
1301
+ else :
1302
+ assert isinstance (log_probs , list )
1303
+ text = output .get ("text" , as_list = True )
1304
+ # TODO: this is not ideal...
1305
+ if use_tensorclass :
1306
+ assert isinstance (text , list )
1307
+ else :
1308
+ assert isinstance (text , NonTensorStack )
1309
+ text_response = output .get ("text_response" , as_list = True )
1310
+ if use_tensorclass :
1311
+ assert isinstance (text_response , list )
1312
+ else :
1313
+ assert isinstance (text_response , NonTensorStack )
1314
+ try :
1315
+ tokens_response = output .get ("tokens_response" )
1316
+ except Exception :
1317
+ tokens_response = output .get ("tokens_response" , as_list = True )
1318
+ if pad :
1319
+ assert isinstance (tokens_response , torch .Tensor )
1320
+ else :
1321
+ assert isinstance (tokens_response , list )
1322
+ try :
1323
+ tokens = output .get ("tokens" )
1324
+ except Exception :
1325
+ tokens = output .get ("tokens" , as_list = True )
1326
+ if not generate :
1327
+ assert tokens is None
1328
+ elif pad :
1329
+ assert isinstance (tokens , torch .Tensor ), tokens
1330
+ else :
1331
+ assert isinstance (tokens , list )
1332
+
1333
+ def test_vllm_collection (self , vllm_instance ):
1334
+ policy = from_vllm (
1335
+ vllm_instance ,
1336
+ from_text = True ,
1337
+ generate = True ,
1338
+ return_log_probs = True ,
1339
+ pad_output = False ,
1340
+ generate_kwargs = {"max_tokens" : 10 },
1341
+ )
1342
+ self ._run_check_collector (policy )
1343
+
1344
+ def test_transformers_collection (self ):
1345
+ ...
1346
+
1347
+ @classmethod
1348
+ def env_constructor (cls ):
1349
+ dl = DummyStrDataLoader (batch_size = 32 )
1350
+ env = LLMEnv .from_dataloader (
1351
+ dl , batch_size = 16 , repeats = 4 , str2str = True , group_repeats = True
1352
+ )
1353
+ assert env .batch_size == (64 ,)
1354
+ return env
1355
+
1356
+ def _run_check_collector (self , policy ):
1357
+ collector = SyncDataCollector (
1358
+ self .env_constructor ,
1359
+ policy = policy ,
1360
+ frames_per_batch = 128 ,
1361
+ total_frames = 512 ,
1362
+ use_buffers = False ,
1222
1363
)
1364
+ for data in collector :
1365
+ assert isinstance (data , LazyStackedTensorDict )
1366
+ assert isinstance (data .reshape (- 1 ).get ("text_response" ), NonTensorStack )
1223
1367
1224
1368
1225
1369
if __name__ == "__main__" :