@@ -294,6 +294,41 @@ def signal_error() -> None:
294294 _check_call (_LIB .XGCommunicatorSignalError ())
295295
296296
297+ def _find_nccl () -> Optional [str ]:
298+ from nvidia .nccl import lib
299+
300+ # There are two versions of nvidia-nccl, one is from PyPI, another one from
301+ # nvidia-pyindex. We support only the first one as the second one is too old (2.9.8
302+ # as of writing).
303+ #
304+ # nccl 2.28 doesn't have the __file__ attribute, we use the namespace path instead.
305+ if lib .__file__ is not None :
306+ dirname : Optional [str ] = os .path .dirname (lib .__file__ )
307+ elif hasattr (lib , "__path__" ) and len (lib .__path__ ) > 0 :
308+ dirname = lib .__path__ [0 ]
309+ else :
310+ dirname = None
311+ if not dirname :
312+ return None
313+
314+ # Find the first shared object in the lib directory.
315+ files = os .listdir (dirname )
316+ if not files :
317+ return None
318+ assert len (files ) >= 1
319+
320+ libname : Optional [str ] = None
321+ for name in files :
322+ if name .startswith ("libnccl.so" ):
323+ libname = name
324+ break
325+
326+ if libname is not None :
327+ path = os .path .join (dirname , libname )
328+ return path
329+ return None
330+
331+
297332class CommunicatorContext :
298333 """A context controlling collective communicator initialization and finalization."""
299334
@@ -309,18 +344,8 @@ def __init__(self, **args: _ArgVals) -> None:
309344
310345 try :
311346 # PyPI package of NCCL.
312- from nvidia .nccl import lib
313-
314- # There are two versions of nvidia-nccl, one is from PyPI, another one from
315- # nvidia-pyindex. We support only the first one as the second one is too old
316- # (2.9.8 as of writing).
317- if lib .__file__ is not None :
318- dirname : Optional [str ] = os .path .dirname (lib .__file__ )
319- else :
320- dirname = None
321-
322- if dirname :
323- path = os .path .join (dirname , "libnccl.so.2" )
347+ path = _find_nccl ()
348+ if path :
324349 self .args [key ] = path
325350 except ImportError :
326351 pass
0 commit comments