Skip to content

Commit 0faa533

Browse files
committed
refactor: MPI standard/legacy ABI support
1 parent 9be5888 commit 0faa533

File tree

9 files changed

+206
-468
lines changed

9 files changed

+206
-468
lines changed

.cibw/install-mpi.sh

+24-40
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,28 @@
11
#!/bin/bash
22
set -euo pipefail
33

4-
MPI_ABI=${1:-mpi31-mpich}
4+
MPI_ABI=${1:-mpich}
55
MACHINE=${PROCESSOR_ARCHITECTURE:-$(uname -m)}
66
MPIARCH=${2:-$MACHINE}
77
MPIARCH=${MPIARCH/native/$MACHINE}
88

99
MPI_CHANNEL=conda-forge
10-
MPI_PACKAGE=${MPI_ABI#*-}
11-
MPI_VERSION="*.*"
12-
case "$(uname)"-"$MPIARCH"-"$MPI_ABI" in
13-
# Linux x86_64/aarch64
14-
Linux-*-mpi41-mpich) MPI_VERSION=4.2;; # >= 4.2
15-
Linux-*-mpi40-mpich) MPI_VERSION=4.1;; # >= 4.0
16-
Linux-*-mpi31-mpich) MPI_VERSION=3.4;; # >= 3.2
17-
Linux-*-mpi31-openmpi) MPI_VERSION=4.1;; # >= 3.1
18-
# Darwin x86_64
19-
Darwin-x86_64-mpi41-mpich) MPI_VERSION=4.2;;
20-
Darwin-x86_64-mpi40-mpich) MPI_VERSION=4.0;;
21-
Darwin-x86_64-mpi31-mpich) MPI_VERSION=3.2;;
22-
Darwin-x86_64-mpi31-openmpi) MPI_VERSION=3.1;;
23-
# Darwin arm64
24-
Darwin-arm64-mpi41-mpich) MPI_VERSION=4.2;;
25-
Darwin-arm64-mpi40-mpich) MPI_VERSION=4.0;;
26-
Darwin-arm64-mpi31-mpich) MPI_VERSION=3.4;;
27-
Darwin-arm64-mpi31-openmpi) MPI_VERSION=4.0;;
28-
# Windows AMD64
29-
*NT*-AMD64-mpi20-msmpi)
30-
MPI_CHANNEL=conda-forge
31-
MPI_PACKAGE=msmpi
32-
MPI_VERSION=10.1.1
33-
MPI_ROOT=${MPI_ROOT:-~/mpi}
34-
;;
35-
*NT*-AMD64-mpi31-impi)
36-
MPI_CHANNEL=conda-forge
37-
MPI_PACKAGE=impi-devel
38-
MPI_VERSION=2021.14.0
39-
MPI_ROOT=${MPI_ROOT:-~/mpi}
40-
;;
10+
MPI_PACKAGE=${MPI_ABI}
11+
case "$MPI_ABI" in
12+
mpich) MPI_VERSION=4 ;;
13+
openmpi) MPI_VERSION=5 ;;
14+
msmpi) MPI_VERSION=10.1.1 ;;
15+
impi) MPI_VERSION=2021.14.0 MPI_PACKAGE=impi-devel ;;
16+
esac
17+
case "$(uname)" in
18+
Linux|Darwin)
19+
MPI_ROOT=${MPI_ROOT:-/usr/local}
20+
sudo() { [ "$(id -u)" -eq 0 ] || set -- command sudo "$@"; "$@"; }
21+
;;
22+
*NT*)
23+
MPI_ROOT=${MPI_ROOT:-~/mpi}
24+
sudo() { "$@"; }
25+
;;
4126
esac
4227

4328
echo "Install Micromamba"
@@ -56,7 +41,7 @@ micromamba create --yes --always-copy \
5641
--prefix "$envdir" \
5742
--relocate-prefix "$MPI_ROOT" \
5843
"$MPI_PACKAGE"="$MPI_VERSION"
59-
test "$(uname)-$MPIARCH-$MPI_ABI" = "Linux-x86_64-mpi41-mpich" && \
44+
test "$(micromamba list --json --prefix "$envdir" attr)" != "[]" && \
6045
micromamba remove --yes --force --prefix "$envdir" attr
6146
micromamba list --prefix "$envdir"
6247

@@ -73,16 +58,15 @@ if [ "$MPI_PACKAGE" == openmpi ]; then
7358
files=("$envdir"/share/openmpi/mpi{cc,c++,fort}-wrapper-data.txt)
7459
sed -i.orig -E 's/(compiler)=(.*)-(.*)/\1=\3/' "${files[@]}"
7560
sed -i.orig "s%-Wl,-rpath,$MPI_ROOT/lib%%g" "${files[@]}"
61+
sed -i.orig "s%-Wl,-allow-shlib-undefined%%g" "${files[@]}"
62+
sed -i.orig "s%-Wl,-rpath -Wl,\${libdir}%%g" "${files[@]}"
7663
fi
7764

78-
echo
79-
# shellcheck disable=SC2015
80-
SUDO=$(test "$(id -u)" -ne 0 && command -v sudo || true)
8165
echo "Copying MPI to $MPI_ROOT"
82-
$SUDO mkdir -p "$MPI_ROOT"
83-
$SUDO cp -RP "$envdir"/. "$MPI_ROOT"
66+
sudo mkdir -p "$MPI_ROOT"
67+
sudo cp -RP "$envdir"/. "$MPI_ROOT"
8468
echo "Rebuild dynamic linker cache"
85-
$SUDO "$(command -v ldconfig || echo true)"
69+
sudo "$(command -v ldconfig || echo true)"
8670

8771
echo "Display MPI information"
8872
if [ "$MPI_PACKAGE" == mpich ]; then mpichversion; fi
@@ -107,7 +91,7 @@ if [ "$(uname)" == Darwin ] && [ "$MPIARCH" != "$MACHINE" ]; then
10791
libs=$(find "$envdir2/lib" -type f -name 'lib*.dylib')
10892
for lib in $libs; do
10993
lib=$(basename "$lib")
110-
$SUDO lipo -create \
94+
sudo lipo -create \
11195
"$envdir1/lib/$lib" \
11296
"$envdir2/lib/$lib" \
11397
-output "$MPI_ROOT/lib/$lib"

.cibw/mpi4py_mpiabi.py

+43-64
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import importlib.machinery
55
import importlib.util
66
import os
7-
import re
87
import sys
98
import warnings
109

@@ -196,88 +195,68 @@ def _get_mpiabi_from_libmpi(libmpi=None):
196195
# pylint: disable=import-outside-toplevel
197196
import ctypes as ct
198197
lib = _dlopen_libmpi(libmpi)
199-
lib.MPI_Get_version.restype = ct.c_int
200-
lib.MPI_Get_version.argtypes = [ct.POINTER(ct.c_int)] * 2
201-
vmajor, vminor = ct.c_int(0), ct.c_int(0)
202-
ierr = lib.MPI_Get_version(ct.byref(vmajor), ct.byref(vminor))
203-
if ierr: # pragma: no cover
204-
message = f"MPI_Get_version [ierr={ierr}]"
205-
raise RuntimeError(message)
206-
vmajor, vminor = vmajor.value, vminor.value
198+
abi_get_version = getattr(lib, "MPI_Abi_get_version", None)
199+
if abi_get_version:
200+
abi_get_version.restype = ct.c_int
201+
abi_get_version.argtypes = [ct.POINTER(ct.c_int)] * 2
202+
abi_major, abi_minor = ct.c_int(0), ct.c_int(0)
203+
ierr = abi_get_version(ct.byref(abi_major), ct.byref(abi_minor))
204+
if ierr: # pragma: no cover
205+
message = f"MPI_Abi_get_version [ierr={ierr}]"
206+
raise RuntimeError(message)
207+
if abi_major.value > 0:
208+
return "mpiabi"
207209
if os.name == "posix":
208210
openmpi = hasattr(lib, "ompi_mpi_comm_self")
209-
family = "openmpi" if openmpi else "mpich"
211+
mpiabi = "openmpi" if openmpi else "mpich"
210212
else:
211213
msmpi = hasattr(lib, "MSMPI_Get_version")
212-
family = "msmpi" if msmpi else "impi"
213-
return (vmajor, vminor), family
214-
215-
216-
_pattern = re.compile(
217-
r"""
218-
\.? (
219-
(mpi)? (\W|_)* (
220-
(?P<vmajor>\d+) \.?
221-
(?P<vminor>\d)
222-
) )? (\W|_)*
223-
(?P<family>\w+)?
224-
""",
225-
re.VERBOSE | re.IGNORECASE,
226-
)
227-
_pattern_strict = re.compile(
228-
r"mpi(?P<vmajor>\d+)(?P<vminor>\d)(-(?P<family>\w+))?",
229-
re.VERBOSE | re.IGNORECASE,
230-
)
231-
232-
233-
def _get_mpiabi_from_string(string, strict=False):
234-
pattern = _pattern_strict if strict else _pattern
235-
match = pattern.match(string)
236-
if match is None:
237-
message = f"invalid MPI ABI string {string!r}"
238-
raise RuntimeError(message)
239-
vmajor = match.group("vmajor") or "4"
240-
vminor = match.group("vminor") or "0"
241-
family = match.group("family") or ""
242-
return (int(vmajor), int(vminor)), family.lower() or None
214+
mpiabi = "msmpi" if msmpi else "impi"
215+
return mpiabi
216+
217+
218+
def _get_mpiabi_from_string(string):
219+
table = {ord(c): "" for c in " -_"}
220+
mpiabi = string.translate(table).lower()
221+
if os.name == "posix":
222+
if mpiabi == "impi":
223+
mpiabi = "mpich"
224+
else:
225+
if mpiabi == "mpich":
226+
mpiabi = "impi"
227+
return mpiabi
243228

244229

245230
def _get_mpiabi():
246-
version = getattr(_get_mpiabi, "version", None)
247-
family = getattr(_get_mpiabi, "family", None)
248-
if version is None:
249-
string = os.environ.get("MPI4PY_MPIABI") or None
231+
mpiabi = getattr(_get_mpiabi, "mpiabi", None)
232+
if mpiabi is None:
233+
mpiabi = os.environ.get("MPI4PY_MPIABI") or None
250234
libmpi = os.environ.get("MPI4PY_LIBMPI") or None
251-
if string is not None:
252-
version, family = _get_mpiabi_from_string(string)
235+
if mpiabi is not None:
236+
mpiabi = _get_mpiabi_from_string(mpiabi)
253237
else:
254-
version, family = _get_mpiabi_from_libmpi(libmpi)
255-
_get_mpiabi.version = version # pyright: ignore
256-
_get_mpiabi.family = family # pyright: ignore
257-
return version, family
238+
mpiabi = _get_mpiabi_from_libmpi(libmpi)
239+
_get_mpiabi.mpiabi = mpiabi # pyright: ignore
240+
return mpiabi
258241

259242

260-
_registry = {} # type: dict[str, dict[str, list[tuple[int, int]]]]
243+
_registry = {} # type: dict[str, list[str]]
261244

262245

263246
def _register(module, mpiabi):
264-
version, family = _get_mpiabi_from_string(mpiabi, strict=True)
265-
versions = _registry.setdefault(module, {}).setdefault(family, [])
266-
versions.append(version)
267-
versions.sort()
247+
mpiabi = _get_mpiabi_from_string(mpiabi)
248+
registered = _registry.setdefault(module, [])
249+
if mpiabi not in registered:
250+
registered.append(mpiabi)
268251

269252

270253
def _get_mpiabi_suffix(module):
271254
if module not in _registry:
272255
return None
273-
version, family = _get_mpiabi()
274-
versions = _registry[module].get(family)
275-
if versions:
276-
vmin, vmax = versions[0], versions[-1]
277-
version = max(vmin, min(version, vmax))
278-
vmajor, vminor = version
279-
family_tag = f"-{family}" if family else ""
280-
return f".mpi{vmajor}{vminor}{family_tag}"
256+
mpiabi = _get_mpiabi()
257+
if mpiabi not in _registry[module]:
258+
return None
259+
return f".{mpiabi}" if mpiabi else ""
281260

282261

283262
class _Finder:

.cibw/run-tests-conda.sh

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/bin/bash
2+
set -euo pipefail
3+
4+
mpich=("4.3" "4.1" "3.4")
5+
openmpi=("5.0" "4.1")
6+
impi=("2021.14.1" "2021.10.0")
7+
msmpi=("10.1.1")
8+
9+
mpi="$1"
10+
mpipackage="$mpi"
11+
mpiversion="${mpi}[@]"
12+
test "$mpi" = impi && mpipackage=impi_rt
13+
14+
CONDA=$(command -v micromamba || command -v mamba || command -v conda)
15+
scriptdir=$(dirname "${BASH_SOURCE[0]}")
16+
for version in "${!mpiversion}"; do
17+
echo "::group::$mpipackage=$version"
18+
"$CONDA" install -qy "$mpipackage=$version"
19+
"$CONDA" list
20+
"$scriptdir"/run-tests-mpi.sh
21+
echo "::endgroup::"
22+
done

.cibw/run-tests-mpi.sh

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
set -euo pipefail
3+
4+
: "${MPIEXEC=mpiexec}"
5+
: "${PYTHON=python}"
6+
7+
{ set -x; } 2>/dev/null
8+
"$PYTHON" -m mpi4py --prefix
9+
"$PYTHON" -m mpi4py --version
10+
"$PYTHON" -m mpi4py --mpi-library
11+
"$PYTHON" -m mpi4py --mpi-std-version
12+
"$PYTHON" -m mpi4py --mpi-lib-version | { head -n 1; } 2>/dev/null
13+
"$MPIEXEC" -n 2 "$PYTHON" -m mpi4py.bench ringtest
14+
"$MPIEXEC" -n 2 "$PYTHON" -m mpi4py.bench helloworld
15+
{ set +x; } 2>/dev/null

.cibw/run-tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ set -x
2020
python --version
2121
python -m mpi4py --prefix
2222
python -m mpi4py --version
23+
python -m mpi4py --mpi-library
2324
python -m mpi4py --mpi-std-version
2425
python -m mpi4py --mpi-lib-version

.cibw/setup-build.py renamed to .cibw/setup-matrix.py

+38-22
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,12 @@ def pp3(y_min=8, y_max=10):
3737
}
3838

3939
MPI_ABI_POSIX = [
40-
f"mpi{std}-{mpi}"
41-
for std in (31, 40, 41)
42-
for mpi in ("mpich", "openmpi")
43-
if (std, mpi) not in (
44-
(40, "openmpi"),
45-
(41, "openmpi"),
46-
)
40+
"mpich",
41+
"openmpi",
4742
]
4843
MPI_ABI_WINNT = [
49-
f"mpi{std}-{mpi}"
50-
for std, mpi in (
51-
(20, "msmpi"),
52-
(31, "impi"),
53-
)
44+
"impi",
45+
"msmpi",
5446
]
5547
MPI_ABI = {
5648
"Linux": MPI_ABI_POSIX[:],
@@ -65,12 +57,12 @@ def pp3(y_min=8, y_max=10):
6557
None: "ubuntu-latest"
6658
},
6759
"macOS": {
68-
"arm64": "macos-14",
60+
"arm64": "macos-15",
6961
"x86_64": "macos-13",
7062
None: "macos-latest"
7163
},
7264
"Windows": {
73-
"AMD64": "windows-2019",
65+
"AMD64": "windows-2022",
7466
None: "windows-latest"
7567
},
7668
}
@@ -107,28 +99,52 @@ def pp3(y_min=8, y_max=10):
10799

108100
matrix_build = [
109101
{
110-
"os": os, "arch": arch,
111-
"py": py, "mpi-abi": mpi_abi,
102+
"os": os,
103+
"arch": arch,
104+
"py": py,
105+
"mpi-abi": mpi_abi,
112106
"runner": GHA_RUNNER[os][arch],
113107
}
114108
for os in os_arch_py
115109
for arch in os_arch_py[os]
116110
for py in os_arch_py[os][arch]
117111
for mpi_abi in MPI_ABI[os]
118112
]
113+
119114
matrix_merge = [
120115
{
121-
"os": os, "arch": arch,
116+
"os": os,
117+
"arch": arch,
122118
"runner": GHA_RUNNER[os][None],
123119
}
124120
for os in os_arch_py
125121
for arch in os_arch_py[os]
126122
]
127-
os_arch_list = [
128-
"{os}-{arch}".format(**row)
129-
for row in matrix_merge
130-
]
123+
124+
matrix_test = []
125+
for build in matrix_build:
126+
os = build["os"]
127+
arch = build["arch"]
128+
pytag = build["py"]
129+
mpi_abi = build["mpi-abi"]
130+
runner = GHA_RUNNER[os][arch]
131+
if pytag.startswith("pp"):
132+
continue
133+
pyver = pytag[2:3] + "." + pytag[3:]
134+
mpilist = [mpi_abi]
135+
if (os, arch, mpi_abi) == ("Linux", "x86_64", "mpich"):
136+
mpilist.insert(0, "impi")
137+
matrix_test += [
138+
{
139+
"mpi": mpi,
140+
"py": pyver,
141+
"os": os,
142+
"arch": arch,
143+
"runner": runner,
144+
}
145+
for mpi in mpilist
146+
]
131147

132148
print(f"matrix-build={json.dumps(matrix_build)}")
133149
print(f"matrix-merge={json.dumps(matrix_merge)}")
134-
print(f"os-arch-list={json.dumps(os_arch_list)}")
150+
print(f"matrix-test={json.dumps(matrix_test)}")

0 commit comments

Comments
 (0)