Skip to content

Commit fa6f9b6

Browse files
authored
Minor cleanup in torchchat/cli/builder.py (#1308)
Beautify a series of similar checks & fix a spelling error.
1 parent 4f2f4fb commit fa6f9b6

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

torchchat/cli/builder.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,16 @@ def __post_init__(self):
7979
if self.dso_path and self.pte_path:
8080
raise RuntimeError("specify either DSO path or PTE path, but not both")
8181

82-
if self.checkpoint_path and (self.dso_path or self.pte_path):
83-
print(
84-
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
85-
)
86-
if self.checkpoint_dir and (self.dso_path or self.pte_path):
87-
print(
88-
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
89-
)
90-
if self.gguf_path and (self.dso_path or self.pte_path):
91-
print(
92-
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
93-
)
94-
if not (self.dso_path) and not (self.pte_path):
82+
if self.dso_path or self.pte_path:
83+
ignored_params = [
84+
(self.checkpoint_path, "checkpoint path"),
85+
(self.checkpoint_dir, "checkpoint dir"),
86+
(self.gguf_path, "GGUF path"),
87+
]
88+
for param, param_msg in ignored_params:
89+
if param:
90+
print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
91+
else:
9592
self.prefill_possible = True
9693

9794
@classmethod
@@ -446,7 +443,7 @@ def _maybe_init_distributed(
446443
return world_mesh, parallel_dims
447444

448445

449-
def _maybe_parellelize_model(
446+
def _maybe_parallelize_model(
450447
model: nn.Module,
451448
builder_args: BuilderArgs,
452449
world_mesh: DeviceMesh,
@@ -486,7 +483,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
486483
model = _init_model_on_meta_device(builder_args)
487484
else:
488485
model = _load_model_default(builder_args)
489-
model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
486+
model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
490487

491488
model = model.to(device=builder_args.device, dtype=builder_args.precision)
492489
return model.eval()

0 commit comments

Comments
 (0)