Skip to content

Commit c3fb9db

Browse files
Fix lmdeploy 0.7.3 (#3584)
1 parent 8c0d00b commit c3fb9db

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

swift/llm/infer/infer_engine/utils.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,21 @@ def __init__(self,
265265
if not load_weights:
266266
for _ in e.map(self.model_comm.process_weight, self.gpu_list, ranks):
267267
pass
268-
for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
269-
pass
268+
if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
269+
for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
270+
pass
271+
else:
272+
for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks):
273+
pass
270274

271275
def _create_weight(self, model_comm):
272276
"""Allocate weight buffer, load params if from_workspace."""
273277

274278
# TODO: support mpi
275279
self.node_id = 0
276280
self.node_num = 1
277-
self.nccl_params = model_comm.create_nccl_params(self.node_id)
281+
if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
282+
self.nccl_params = model_comm.create_nccl_params(self.node_id)
278283
torch.cuda.synchronize()
279284

280285
# create weight

0 commit comments

Comments
 (0)