Skip to content

Commit 43644f8

Browse files
committed
Fix loading nccl 2.28.
1 parent e739915 commit 43644f8

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

python-package/xgboost/collective.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,41 @@ def signal_error() -> None:
294294
_check_call(_LIB.XGCommunicatorSignalError())
295295

296296

297+
def _find_nccl() -> Optional[str]:
    """Locate the NCCL shared object shipped with the PyPI ``nvidia-nccl`` package.

    Returns
    -------
    Path to the first file (in sorted order) whose name starts with ``libnccl.so``
    inside the package's ``lib`` directory, or None if no such file can be found.

    Raises
    ------
    ImportError
        If ``nvidia.nccl`` is not installed. Callers are expected to handle this.

    """
    from nvidia.nccl import lib

    # There are two versions of nvidia-nccl, one is from PyPI, another one from
    # nvidia-pyindex. We support only the first one as the second one is too old (2.9.8
    # as of writing).
    #
    # nccl 2.28 doesn't have a usable __file__ attribute (it can be None or missing
    # altogether), in which case we fall back to the namespace path instead.
    lib_file = getattr(lib, "__file__", None)
    if lib_file is not None:
        search_dirs = [os.path.dirname(lib_file)]
    else:
        # A namespace package can span multiple directories, check all of them.
        search_dirs = list(getattr(lib, "__path__", ()))

    for dirname in search_dirs:
        if not dirname:
            continue
        try:
            # Sort the entries so that the result is deterministic when multiple
            # versioned shared objects are present (e.g. libnccl.so.2 is preferred
            # over libnccl.so.2.28.3).
            names = sorted(os.listdir(dirname))
        except OSError:
            # Stale or unreadable path entry, treat it as "not found" instead of
            # failing the communicator initialization.
            continue
        for name in names:
            if name.startswith("libnccl.so"):
                return os.path.join(dirname, name)
    return None
330+
331+
297332
class CommunicatorContext:
298333
"""A context controlling collective communicator initialization and finalization."""
299334

@@ -309,18 +344,8 @@ def __init__(self, **args: _ArgVals) -> None:
309344

310345
try:
311346
# PyPI package of NCCL.
312-
from nvidia.nccl import lib
313-
314-
# There are two versions of nvidia-nccl, one is from PyPI, another one from
315-
# nvidia-pyindex. We support only the first one as the second one is too old
316-
# (2.9.8 as of writing).
317-
if lib.__file__ is not None:
318-
dirname: Optional[str] = os.path.dirname(lib.__file__)
319-
else:
320-
dirname = None
321-
322-
if dirname:
323-
path = os.path.join(dirname, "libnccl.so.2")
347+
path = _find_nccl()
348+
if path:
324349
self.args[key] = path
325350
except ImportError:
326351
pass

0 commit comments

Comments
 (0)