Skip to content

Commit fc79c7a

Browse files
sarckkfacebook-github-bot
authored andcommitted
Rename preproc to postproc for pipelining (#2661)
Summary: Pull Request resolved: #2661 The data transformation that happens during model fwd invocation should be post-processing, not pre-processing. renaming accordingly. Reviewed By: dstaay-fb Differential Revision: D67756024 fbshipit-source-id: 79231c608de078381f308a88c71eba515b04753a
1 parent 455de88 commit fc79c7a

File tree

6 files changed

+323
-318
lines changed

6 files changed

+323
-318
lines changed

torchrec/distributed/test_utils/test_model.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ def __init__(
11921192
max_feature_lengths: Optional[Dict[str, int]] = None,
11931193
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
11941194
over_arch_clazz: Type[nn.Module] = TestOverArch,
1195-
preproc_module: Optional[nn.Module] = None,
1195+
postproc_module: Optional[nn.Module] = None,
11961196
) -> None:
11971197
super().__init__(
11981198
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1229,7 +1229,7 @@ def __init__(
12291229
"dummy_ones",
12301230
torch.ones(1, device=dense_device),
12311231
)
1232-
self.preproc_module = preproc_module
1232+
self.postproc_module = postproc_module
12331233

12341234
def sparse_forward(self, input: ModelInput) -> KeyedTensor:
12351235
return self.sparse(
@@ -1256,8 +1256,8 @@ def forward(
12561256
self,
12571257
input: ModelInput,
12581258
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
1259-
if self.preproc_module:
1260-
input = self.preproc_module(input)
1259+
if self.postproc_module:
1260+
input = self.postproc_module(input)
12611261
return self.dense_forward(input, self.sparse_forward(input))
12621262

12631263

@@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
17491749

17501750
class TestModelWithPreproc(nn.Module):
17511751
"""
1752-
Basic module with up to 3 preproc modules:
1753-
- preproc on idlist_features for non-weighted EBC
1754-
- preproc on idscore_features for weighted EBC
1755-
- optional preproc on model input shared by both EBCs
1752+
Basic module with up to 3 postproc modules:
1753+
- postproc on idlist_features for non-weighted EBC
1754+
- postproc on idscore_features for weighted EBC
1755+
- optional postproc on model input shared by both EBCs
17561756
17571757
Args:
17581758
tables,
17591759
weighted_tables,
17601760
device,
1761-
preproc_module,
1761+
postproc_module,
17621762
num_float_features,
1763-
run_preproc_inline,
1763+
run_postproc_inline,
17641764
17651765
Example:
17661766
>>> TestModelWithPreproc(tables, weighted_tables, device)
@@ -1774,9 +1774,9 @@ def __init__(
17741774
tables: List[EmbeddingBagConfig],
17751775
weighted_tables: List[EmbeddingBagConfig],
17761776
device: torch.device,
1777-
preproc_module: Optional[nn.Module] = None,
1777+
postproc_module: Optional[nn.Module] = None,
17781778
num_float_features: int = 10,
1779-
run_preproc_inline: bool = False,
1779+
run_postproc_inline: bool = False,
17801780
) -> None:
17811781
super().__init__()
17821782
self.dense = TestDenseArch(num_float_features, device)
@@ -1790,17 +1790,17 @@ def __init__(
17901790
is_weighted=True,
17911791
device=device,
17921792
)
1793-
self.preproc_nonweighted = TestPreprocNonWeighted()
1794-
self.preproc_weighted = TestPreprocWeighted()
1795-
self._preproc_module = preproc_module
1796-
self._run_preproc_inline = run_preproc_inline
1793+
self.postproc_nonweighted = TestPreprocNonWeighted()
1794+
self.postproc_weighted = TestPreprocWeighted()
1795+
self._postproc_module = postproc_module
1796+
self._run_postproc_inline = run_postproc_inline
17971797

17981798
def forward(
17991799
self,
18001800
input: ModelInput,
18011801
) -> Tuple[torch.Tensor, torch.Tensor]:
18021802
"""
1803-
Runs preprco for EBC and weighted EBC, optionally runs preproc for input
1803+
Runs preprco for EBC and weighted EBC, optionally runs postproc for input
18041804
18051805
Args:
18061806
input
@@ -1809,20 +1809,20 @@ def forward(
18091809
"""
18101810
modified_input = input
18111811

1812-
if self._preproc_module is not None:
1813-
modified_input = self._preproc_module(modified_input)
1814-
elif self._run_preproc_inline:
1812+
if self._postproc_module is not None:
1813+
modified_input = self._postproc_module(modified_input)
1814+
elif self._run_postproc_inline:
18151815
idlist_features = modified_input.idlist_features
18161816
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
18171817
idlist_features.keys(), # pyre-ignore [6]
18181818
idlist_features.values(), # pyre-ignore [6]
18191819
idlist_features.lengths(), # pyre-ignore [16]
18201820
)
18211821

1822-
modified_idlist_features = self.preproc_nonweighted(
1822+
modified_idlist_features = self.postproc_nonweighted(
18231823
modified_input.idlist_features
18241824
)
1825-
modified_idscore_features = self.preproc_weighted(
1825+
modified_idscore_features = self.postproc_weighted(
18261826
modified_input.idscore_features
18271827
)
18281828
ebc_out = self.ebc(modified_idlist_features[0])
@@ -1834,15 +1834,15 @@ def forward(
18341834

18351835
class TestNegSamplingModule(torch.nn.Module):
18361836
"""
1837-
Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
1837+
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
18381838
18391839
Args:
18401840
extra_input
18411841
has_params
18421842
18431843
Example:
1844-
>>> preproc = TestNegSamplingModule(extra_input)
1845-
>>> out = preproc(in)
1844+
>>> postproc = TestNegSamplingModule(extra_input)
1845+
>>> out = postproc(in)
18461846
18471847
Returns:
18481848
ModelInput
@@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module):
19061906
19071907
Args: None
19081908
Example:
1909-
>>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910-
>>> out = preproc(in)
1909+
>>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910+
>>> out = postproc(in)
19111911
Returns:
19121912
ModelInput
19131913
"""

0 commit comments

Comments
 (0)