File tree 1 file changed +8
-3
lines changed
swift/llm/infer/infer_engine
1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -265,16 +265,21 @@ def __init__(self,
265
265
if not load_weights :
266
266
for _ in e .map (self .model_comm .process_weight , self .gpu_list , ranks ):
267
267
pass
268
- for _ in e .map (self .model_comm .create_engine , self .gpu_list , ranks , repeat (self .nccl_params )):
269
- pass
268
+ if version .parse (lmdeploy .__version__ ) < version .parse ('0.7.2' ):
269
+ for _ in e .map (self .model_comm .create_engine , self .gpu_list , ranks , repeat (self .nccl_params )):
270
+ pass
271
+ else :
272
+ for _ in e .map (self .model_comm .create_engine , self .gpu_list , ranks ):
273
+ pass
270
274
271
275
def _create_weight (self , model_comm ):
272
276
"""Allocate weight buffer, load params if from_workspace."""
273
277
274
278
# TODO: support mpi
275
279
self .node_id = 0
276
280
self .node_num = 1
277
- self .nccl_params = model_comm .create_nccl_params (self .node_id )
281
+ if version .parse (lmdeploy .__version__ ) < version .parse ('0.7.2' ):
282
+ self .nccl_params = model_comm .create_nccl_params (self .node_id )
278
283
torch .cuda .synchronize ()
279
284
280
285
# create weight
You can’t perform that action at this time.
0 commit comments