Skip to content

Commit a937963

Browse files
committed
fixed mpi4py
1 parent edceece commit a937963

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

Diff for: envs/conda-env-rapids.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ dependencies:
1919
- psutil
2020
- requests
2121
- py-cpuinfo
22-
- mpi4py
22+

Diff for: envs/conda-env-sklearn.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ dependencies:
2121
- psutil
2222
- requests
2323
- py-cpuinfo
24-
- mpi4py
24+

Diff for: sklbench/datasets/transformer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import numpy as np
2121
import pandas as pd
22-
from mpi4py import MPI
2322
from scipy.sparse import csr_matrix
2423
from sklearn.model_selection import train_test_split
2524

@@ -117,6 +116,8 @@ def split_and_transform_data(bench_case, data, data_description):
117116
)
118117

119118
if distributed_split == "sample_shift":
119+
from mpi4py import MPI
120+
120121
comm = MPI.COMM_WORLD
121122
rank = comm.Get_rank()
122123
size = comm.Get_size()
@@ -147,6 +148,8 @@ def split_and_transform_data(bench_case, data, data_description):
147148
x_test = x_test[test_start:test_end] * adjust_number
148149

149150
elif distributed_split == "rank_based" or knn_split_train:
151+
from mpi4py import MPI
152+
150153
comm = MPI.COMM_WORLD
151154
rank = comm.Get_rank()
152155
size = comm.Get_size()

0 commit comments

Comments
 (0)