@@ -294,6 +294,40 @@ def signal_error() -> None:
294294 _check_call (_LIB .XGCommunicatorSignalError ())
295295
296296
297+ def _find_nccl () -> Optional [str ]:
298+ from nvidia .nccl import lib
299+
300+ # There are two versions of nvidia-nccl, one is from PyPI, another one from
301+ # nvidia-pyindex. We support only the first one as the second one is too old (2.9.8
302+ # as of writing).
303+ #
304+ # nccl 2.28 doesn't have the __file__ attribute, we use the namespace path instead.
305+ if lib .__file__ is not None :
306+ dirname : Optional [str ] = os .path .dirname (lib .__file__ )
307+ elif hasattr (lib , "__path__" ) and len (lib .__path__ ) > 0 :
308+ dirname = lib .__path__ [0 ]
309+ else :
310+ dirname = None
311+ if not dirname :
312+ return None
313+
314+ # Find the first shared object in the lib directory.
315+ files = os .listdir (dirname )
316+ if not files :
317+ return None
318+
319+ libname : Optional [str ] = None
320+ for name in files :
321+ if name .startswith ("libnccl.so" ):
322+ libname = name
323+ break
324+
325+ if libname is not None :
326+ path = os .path .join (dirname , libname )
327+ return path
328+ return None
329+
330+
297331class CommunicatorContext :
298332 """A context controlling collective communicator initialization and finalization."""
299333
@@ -309,18 +343,8 @@ def __init__(self, **args: _ArgVals) -> None:
309343
310344 try :
311345 # PyPI package of NCCL.
312- from nvidia .nccl import lib
313-
314- # There are two versions of nvidia-nccl, one is from PyPI, another one from
315- # nvidia-pyindex. We support only the first one as the second one is too old
316- # (2.9.8 as of writing).
317- if lib .__file__ is not None :
318- dirname : Optional [str ] = os .path .dirname (lib .__file__ )
319- else :
320- dirname = None
321-
322- if dirname :
323- path = os .path .join (dirname , "libnccl.so.2" )
346+ path = _find_nccl ()
347+ if path :
324348 self .args [key ] = path
325349 except ImportError :
326350 pass
0 commit comments