Skip to content

Commit 4d675ec

Browse files
committed
Removed extra code for sample_shift.
1 parent 548d824 commit 4d675ec

File tree

1 file changed

+6
-19
lines changed

1 file changed

+6
-19
lines changed

Diff for: sklbench/datasets/transformer.py

+6-19
Original file line numberDiff line numberDiff line change
@@ -117,34 +117,21 @@ def split_and_transform_data(bench_case, data, data_description):
117117
if distributed_split == "sample_shift":
118118
from mpi4py import MPI
119119

120-
comm = MPI.COMM_WORLD
121-
rank = comm.Get_rank()
122-
size = comm.Get_size()
123-
124-
n_train = len(x_train)
125-
n_test = len(x_test)
126-
127-
train_start = 0
128-
train_end = n_train
129-
test_start = 0
130-
test_end = n_test
131-
120+
rank = MPI.COMM_WORLD.Get_rank()
132121
adjust_number = (math.sqrt(rank) * 0.003) + 1
133122

134123
if "y" in data:
135124
x_train, y_train = (
136-
x_train[train_start:train_end] * adjust_number,
137-
y_train[train_start:train_end],
125+
x_train * adjust_number,
126+
y_train,
138127
)
139128

140129
x_test, y_test = (
141-
x_test[test_start:test_end] * adjust_number,
142-
y_test[test_start:test_end],
130+
x_test * adjust_number,
131+
y_test,
143132
)
144133
else:
145-
x_train = x_train[train_start:train_end]
146-
147-
x_test = x_test[test_start:test_end] * adjust_number
134+
x_test = x_test * adjust_number
148135

149136
elif distributed_split == "rank_based":
150137
from mpi4py import MPI

0 commit comments

Comments
 (0)