|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | import torch.nn as nn
|
17 |
| -from tensordict import TensorDict |
18 | 17 | from torchrec.distributed.embedding_tower_sharding import (
|
19 | 18 | EmbeddingTowerCollectionSharder,
|
20 | 19 | EmbeddingTowerSharder,
|
|
47 | 46 | @dataclass
|
48 | 47 | class ModelInput(Pipelineable):
|
49 | 48 | float_features: torch.Tensor
|
50 |
| - idlist_features: Union[KeyedJaggedTensor, TensorDict] |
51 |
| - idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] |
| 49 | + idlist_features: KeyedJaggedTensor |
| 50 | + idscore_features: Optional[KeyedJaggedTensor] |
52 | 51 | label: torch.Tensor
|
53 | 52 |
|
54 | 53 | @staticmethod
|
@@ -77,13 +76,11 @@ def generate(
|
77 | 76 | randomize_indices: bool = True,
|
78 | 77 | device: Optional[torch.device] = None,
|
79 | 78 | max_feature_lengths: Optional[List[int]] = None,
|
80 |
| - input_type: str = "kjt", |
81 | 79 | ) -> Tuple["ModelInput", List["ModelInput"]]:
|
82 | 80 | """
|
83 | 81 | Returns a global (single-rank training) batch
|
84 | 82 | and a list of local (multi-rank training) batches of world_size.
|
85 | 83 | """
|
86 |
| - |
87 | 84 | batch_size_by_rank = [batch_size] * world_size
|
88 | 85 | if variable_batch_size:
|
89 | 86 | batch_size_by_rank = [
|
@@ -202,26 +199,11 @@ def _validate_pooling_factor(
|
202 | 199 | )
|
203 | 200 | global_idlist_lengths.append(lengths)
|
204 | 201 | global_idlist_indices.append(indices)
|
205 |
| - |
206 |
| - if input_type == "kjt": |
207 |
| - global_idlist_input = KeyedJaggedTensor( |
208 |
| - keys=idlist_features, |
209 |
| - values=torch.cat(global_idlist_indices), |
210 |
| - lengths=torch.cat(global_idlist_lengths), |
211 |
| - ) |
212 |
| - elif input_type == "td": |
213 |
| - dict_of_nt = { |
214 |
| - k: torch.nested.nested_tensor_from_jagged( |
215 |
| - values=values, |
216 |
| - lengths=lengths, |
217 |
| - ) |
218 |
| - for k, values, lengths in zip( |
219 |
| - idlist_features, global_idlist_indices, global_idlist_lengths |
220 |
| - ) |
221 |
| - } |
222 |
| - global_idlist_input = TensorDict(source=dict_of_nt) |
223 |
| - else: |
224 |
| - raise ValueError(f"For IdList features, unknown input type {input_type}") |
| 202 | + global_idlist_kjt = KeyedJaggedTensor( |
| 203 | + keys=idlist_features, |
| 204 | + values=torch.cat(global_idlist_indices), |
| 205 | + lengths=torch.cat(global_idlist_lengths), |
| 206 | + ) |
225 | 207 |
|
226 | 208 | for idx in range(len(idscore_ind_ranges)):
|
227 | 209 | ind_range = idscore_ind_ranges[idx]
|
@@ -263,25 +245,16 @@ def _validate_pooling_factor(
|
263 | 245 | global_idscore_lengths.append(lengths)
|
264 | 246 | global_idscore_indices.append(indices)
|
265 | 247 | global_idscore_weights.append(weights)
|
266 |
| - |
267 |
| - if input_type == "kjt": |
268 |
| - global_idscore_input = ( |
269 |
| - KeyedJaggedTensor( |
270 |
| - keys=idscore_features, |
271 |
| - values=torch.cat(global_idscore_indices), |
272 |
| - lengths=torch.cat(global_idscore_lengths), |
273 |
| - weights=torch.cat(global_idscore_weights), |
274 |
| - ) |
275 |
| - if global_idscore_indices |
276 |
| - else None |
| 248 | + global_idscore_kjt = ( |
| 249 | + KeyedJaggedTensor( |
| 250 | + keys=idscore_features, |
| 251 | + values=torch.cat(global_idscore_indices), |
| 252 | + lengths=torch.cat(global_idscore_lengths), |
| 253 | + weights=torch.cat(global_idscore_weights), |
277 | 254 | )
|
278 |
| - elif input_type == "td": |
279 |
| - assert ( |
280 |
| - len(idscore_features) == 0 |
281 |
| - ), "TensorDict does not support weighted features" |
282 |
| - global_idscore_input = None |
283 |
| - else: |
284 |
| - raise ValueError(f"For weighted features, unknown input type {input_type}") |
| 255 | + if global_idscore_indices |
| 256 | + else None |
| 257 | + ) |
285 | 258 |
|
286 | 259 | if randomize_indices:
|
287 | 260 | global_float = torch.rand(
|
@@ -330,57 +303,36 @@ def _validate_pooling_factor(
|
330 | 303 | weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
|
331 | 304 | )
|
332 | 305 |
|
333 |
| - if input_type == "kjt": |
334 |
| - local_idlist_input = KeyedJaggedTensor( |
335 |
| - keys=idlist_features, |
336 |
| - values=torch.cat(local_idlist_indices), |
337 |
| - lengths=torch.cat(local_idlist_lengths), |
338 |
| - ) |
339 |
| - |
340 |
| - local_idscore_input = ( |
341 |
| - KeyedJaggedTensor( |
342 |
| - keys=idscore_features, |
343 |
| - values=torch.cat(local_idscore_indices), |
344 |
| - lengths=torch.cat(local_idscore_lengths), |
345 |
| - weights=torch.cat(local_idscore_weights), |
346 |
| - ) |
347 |
| - if local_idscore_indices |
348 |
| - else None |
349 |
| - ) |
350 |
| - elif input_type == "td": |
351 |
| - dict_of_nt = { |
352 |
| - k: torch.nested.nested_tensor_from_jagged( |
353 |
| - values=values, |
354 |
| - lengths=lengths, |
355 |
| - ) |
356 |
| - for k, values, lengths in zip( |
357 |
| - idlist_features, local_idlist_indices, local_idlist_lengths |
358 |
| - ) |
359 |
| - } |
360 |
| - local_idlist_input = TensorDict(source=dict_of_nt) |
361 |
| - assert ( |
362 |
| - len(idscore_features) == 0 |
363 |
| - ), "TensorDict does not support weighted features" |
364 |
| - local_idscore_input = None |
| 306 | + local_idlist_kjt = KeyedJaggedTensor( |
| 307 | + keys=idlist_features, |
| 308 | + values=torch.cat(local_idlist_indices), |
| 309 | + lengths=torch.cat(local_idlist_lengths), |
| 310 | + ) |
365 | 311 |
|
366 |
| - else: |
367 |
| - raise ValueError( |
368 |
| - f"For weighted features, unknown input type {input_type}" |
| 312 | + local_idscore_kjt = ( |
| 313 | + KeyedJaggedTensor( |
| 314 | + keys=idscore_features, |
| 315 | + values=torch.cat(local_idscore_indices), |
| 316 | + lengths=torch.cat(local_idscore_lengths), |
| 317 | + weights=torch.cat(local_idscore_weights), |
369 | 318 | )
|
| 319 | + if local_idscore_indices |
| 320 | + else None |
| 321 | + ) |
370 | 322 |
|
371 | 323 | local_input = ModelInput(
|
372 | 324 | float_features=global_float[r * batch_size : (r + 1) * batch_size],
|
373 |
| - idlist_features=local_idlist_input, |
374 |
| - idscore_features=local_idscore_input, |
| 325 | + idlist_features=local_idlist_kjt, |
| 326 | + idscore_features=local_idscore_kjt, |
375 | 327 | label=global_label[r * batch_size : (r + 1) * batch_size],
|
376 | 328 | )
|
377 | 329 | local_inputs.append(local_input)
|
378 | 330 |
|
379 | 331 | return (
|
380 | 332 | ModelInput(
|
381 | 333 | float_features=global_float,
|
382 |
| - idlist_features=global_idlist_input, |
383 |
| - idscore_features=global_idscore_input, |
| 334 | + idlist_features=global_idlist_kjt, |
| 335 | + idscore_features=global_idscore_kjt, |
384 | 336 | label=global_label,
|
385 | 337 | ),
|
386 | 338 | local_inputs,
|
@@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
|
671 | 623 |
|
672 | 624 | def record_stream(self, stream: torch.Stream) -> None:
|
673 | 625 | self.float_features.record_stream(stream)
|
674 |
| - if isinstance(self.idlist_features, KeyedJaggedTensor): |
675 |
| - self.idlist_features.record_stream(stream) |
676 |
| - if isinstance(self.idscore_features, KeyedJaggedTensor): |
| 626 | + self.idlist_features.record_stream(stream) |
| 627 | + if self.idscore_features is not None: |
677 | 628 | self.idscore_features.record_stream(stream)
|
678 | 629 | self.label.record_stream(stream)
|
679 | 630 |
|
@@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput:
|
1880 | 1831 | )
|
1881 | 1832 |
|
1882 | 1833 | # stride will be same but features will be joined
|
1883 |
| - assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) |
1884 |
| - assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) |
1885 | 1834 | modified_input.idlist_features = KeyedJaggedTensor.concat(
|
1886 | 1835 | [modified_input.idlist_features, self._extra_input.idlist_features]
|
1887 | 1836 | )
|
|
0 commit comments