@@ -147,7 +147,6 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +177,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
-        for _ in range(num_inputs):
-            inputs.append(
+    for _ in range(num_inputs):
+        inputs.append(
+            (
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,26 +188,8 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-            )
-    elif generate == ModelInput.generate:
-        for _ in range(num_inputs):
-            inputs.append(
-                ModelInput.generate(
-                    world_size=world_size,
-                    tables=tables,
-                    dedup_tables=dedup_tables,
-                    weighted_tables=weighted_tables or [],
-                    num_float_features=num_float_features,
-                    variable_batch_size=variable_batch_size,
-                    batch_size=batch_size,
-                    long_indices=long_indices,
-                    input_type=input_type,
-                )
-            )
-    else:
-        for _ in range(num_inputs):
-            inputs.append(
-                cast(ModelInputCallable, generate)(
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -219,6 +200,7 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
+        )
     return (model, inputs)

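In effect, the three if/elif/else loops collapse into one loop whose single append picks the generation callable with a conditional expression. A minimal, self-contained sketch of that pattern, using illustrative stand-in functions rather than TorchRec's actual ModelInput API:

from typing import Callable, List

def generate_variable_batch_input(average_batch_size: int) -> str:
    # Stand-in for ModelInput.generate_variable_batch_input (illustrative only).
    return f"variable-batch input (~{average_batch_size} per rank)"

def generate_fixed_input(batch_size: int) -> str:
    # Stand-in for a ModelInputCallable such as ModelInput.generate (illustrative only).
    return f"fixed-batch input ({batch_size} per rank)"

def gen_inputs(generate: Callable[..., str], num_inputs: int, batch_size: int) -> List[str]:
    inputs: List[str] = []
    for _ in range(num_inputs):
        # Single append; the conditional expression selects which callable to
        # invoke, mirroring the refactored loop in gen_model_and_input above.
        inputs.append(
            generate_variable_batch_input(average_batch_size=batch_size)
            if generate == generate_variable_batch_input
            else generate_fixed_input(batch_size=batch_size)
        )
    return inputs

print(gen_inputs(generate_variable_batch_input, num_inputs=2, batch_size=4))
print(gen_inputs(generate_fixed_input, num_inputs=2, batch_size=4))
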
@@ -315,7 +297,6 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -338,7 +319,6 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
-            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
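Because input_type is removed outright rather than deprecated, any test caller that still passes it will fail at call time. A tiny illustration of the failure mode with a hypothetical stub (not the real sharding_single_rank_test signature):

def sharding_single_rank_test_stub(rank: int, world_size: int, batch_size: int = 4) -> None:
    # Hypothetical stub whose input_type parameter has been removed (illustrative only).
    pass

try:
    sharding_single_rank_test_stub(rank=0, world_size=2, input_type="td")  # type: ignore[call-arg]
except TypeError as err:
    print(err)  # e.g. "... got an unexpected keyword argument 'input_type'"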