Commit a0ac84d

MattEding authored and glemaitre committed
ENH Vectorized ADASYN (#649)
* vectorized adasyn; fixed adasyn module docstring; todo: update unit tests due to random state changes
* fix indentation error
* fixed row selection indices; fixed n_samples to work with non-ints
* fixed row & col shape occasional mismatch due to rounding in algorithm
* update unit tests to reflect random state changes
1 parent d8472f4 commit a0ac84d
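
The gist of the change: instead of looping over each minority sample and drawing neighbours one at a time, the new code repeats each row index by its per-sample quota and interpolates all synthetic points in one vectorized step. Below is a minimal standalone sketch of that idea with made-up toy data; the variable names mirror the diff but this is illustrative only, not the library code.

import numpy as np

rng = np.random.RandomState(0)

X_class = rng.uniform(size=(5, 2))               # minority-class samples (toy data)
nns = rng.randint(0, 5, size=(5, 3))             # 3 neighbour indices per sample (toy data)
n_samples_generate = np.array([2, 0, 1, 3, 1])   # per-sample generation quota
n_samples = n_samples_generate.sum()

# Repeat each row index as many times as samples it must generate,
# then draw one neighbour column per generated sample.
rows = np.repeat(np.arange(len(X_class)), n_samples_generate)
cols = rng.choice(nns.shape[1], size=n_samples)

# Interpolate between every seed point and its chosen neighbour in one shot.
diffs = X_class[nns[rows, cols]] - X_class[rows]
steps = rng.uniform(size=(n_samples, 1))
X_new = X_class[rows] + steps * diffs
print(X_new.shape)  # (7, 2): one synthetic point per quota entry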

File tree: 3 files changed, +41 -77 lines


imblearn/over_sampling/_adasyn.py (+31, -66)
@@ -1,4 +1,4 @@
-"""Class to perform random over-sampling."""
+"""Class to perform over-sampling using ADASYN."""
 
 # Authors: Guillaume Lemaitre <[email protected]>
 #          Christos Aridas
@@ -104,8 +104,8 @@ def _fit_resample(self, X, y):
         self._validate_estimator()
         random_state = check_random_state(self.random_state)
 
-        X_resampled = X.copy()
-        y_resampled = y.copy()
+        X_resampled = [X.copy()]
+        y_resampled = [y.copy()]
 
         for class_sample, n_samples in self.sampling_strategy_.items():
             if n_samples == 0:
@@ -114,13 +114,12 @@ def _fit_resample(self, X, y):
             X_class = _safe_indexing(X, target_class_indices)
 
             self.nn_.fit(X)
-            _, nn_index = self.nn_.kneighbors(X_class)
+            nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
             # The ratio is computed using a one-vs-rest manner. Using majority
             # in multi-class would lead to slightly different results at the
             # cost of introducing a new parameter.
-            ratio_nn = np.sum(y[nn_index[:, 1:]] != class_sample, axis=1) / (
-                self.nn_.n_neighbors - 1
-            )
+            n_neighbors = self.nn_.n_neighbors - 1
+            ratio_nn = np.sum(y[nns] != class_sample, axis=1) / n_neighbors
             if not np.sum(ratio_nn):
                 raise RuntimeError(
                     "Not any neigbours belong to the majority"
@@ -131,7 +130,9 @@ def _fit_resample(self, X, y):
                 )
             ratio_nn /= np.sum(ratio_nn)
             n_samples_generate = np.rint(ratio_nn * n_samples).astype(int)
-            if not np.sum(n_samples_generate):
+            # rounding may cause new amount for n_samples
+            n_samples = np.sum(n_samples_generate)
+            if not n_samples:
                 raise ValueError(
                     "No samples will be generated with the"
                     " provided ratio settings."
@@ -140,66 +141,30 @@ def _fit_resample(self, X, y):
             # the nearest neighbors need to be fitted only on the current class
             # to find the class NN to generate new samples
             self.nn_.fit(X_class)
-            _, nn_index = self.nn_.kneighbors(X_class)
+            nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
 
-            if sparse.issparse(X):
-                row_indices, col_indices, samples = [], [], []
-                n_samples_generated = 0
-                for x_i, x_i_nn, num_sample_i in zip(
-                    X_class, nn_index, n_samples_generate
-                ):
-                    if num_sample_i == 0:
-                        continue
-                    nn_zs = random_state.randint(
-                        1, high=self.nn_.n_neighbors, size=num_sample_i
-                    )
-                    steps = random_state.uniform(size=len(nn_zs))
-                    if x_i.nnz:
-                        for step, nn_z in zip(steps, nn_zs):
-                            sample = x_i + step * (
-                                X_class[x_i_nn[nn_z], :] - x_i
-                            )
-                            row_indices += [n_samples_generated] * len(
-                                sample.indices
-                            )
-                            col_indices += sample.indices.tolist()
-                            samples += sample.data.tolist()
-                            n_samples_generated += 1
-                X_new = sparse.csr_matrix(
-                    (samples, (row_indices, col_indices)),
-                    [np.sum(n_samples_generate), X.shape[1]],
-                    dtype=X.dtype,
-                )
-                y_new = np.array(
-                    [class_sample] * np.sum(n_samples_generate), dtype=y.dtype
-                )
-            else:
-                x_class_gen = []
-                for x_i, x_i_nn, num_sample_i in zip(
-                    X_class, nn_index, n_samples_generate
-                ):
-                    if num_sample_i == 0:
-                        continue
-                    nn_zs = random_state.randint(
-                        1, high=self.nn_.n_neighbors, size=num_sample_i
-                    )
-                    steps = random_state.uniform(size=len(nn_zs))
-                    x_class_gen.append(
-                        [
-                            x_i + step * (X_class[x_i_nn[nn_z], :] - x_i)
-                            for step, nn_z in zip(steps, nn_zs)
-                        ]
-                    )
-
-                X_new = np.concatenate(x_class_gen).astype(X.dtype)
-                y_new = np.array(
-                    [class_sample] * np.sum(n_samples_generate), dtype=y.dtype
-                )
+            enumerated_class_indices = np.arange(len(target_class_indices))
+            rows = np.repeat(enumerated_class_indices, n_samples_generate)
+            cols = random_state.choice(n_neighbors, size=n_samples)
+            diffs = X_class[nns[rows, cols]] - X_class[rows]
+            steps = random_state.uniform(size=(n_samples, 1))
 
-            if sparse.issparse(X_new):
-                X_resampled = sparse.vstack([X_resampled, X_new])
+            if sparse.issparse(X):
+                sparse_func = type(X).__name__
+                steps = getattr(sparse, sparse_func)(steps)
+                X_new = X_class[rows] + steps.multiply(diffs)
             else:
-                X_resampled = np.vstack((X_resampled, X_new))
-            y_resampled = np.hstack((y_resampled, y_new))
+                X_new = X_class[rows] + steps * diffs
+
+            X_new = X_new.astype(X.dtype)
+            y_new = np.full(n_samples, fill_value=class_sample, dtype=y.dtype)
+            X_resampled.append(X_new)
+            y_resampled.append(y_new)
+
+        if sparse.issparse(X):
+            X_resampled = sparse.vstack(X_resampled, format=X.format)
+        else:
+            X_resampled = np.vstack(X_resampled)
+        y_resampled = np.hstack(y_resampled)
 
         return X_resampled, y_resampled
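
Since the change is internal to _fit_resample, caller-facing usage is unchanged; only the synthetic points drawn for a fixed random_state differ. A small usage sketch follows; the make_classification parameters are arbitrary and chosen only to produce an imbalanced toy problem.

from collections import Counter

from sklearn.datasets import make_classification
from imblearn.over_sampling import ADASYN

# Toy imbalanced dataset (parameters are arbitrary).
X, y = make_classification(
    n_classes=2, class_sep=2, weights=[0.1, 0.9], n_informative=3,
    n_redundant=1, flip_y=0, n_features=20, n_clusters_per_class=1,
    n_samples=1000, random_state=10,
)
ada = ADASYN(random_state=42)
X_res, y_res = ada.fit_resample(X, y)
print(Counter(y), Counter(y_res))  # minority class oversampled towards parity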

imblearn/over_sampling/_smote.py (+2, -3)
@@ -98,7 +98,7 @@ def _make_samples(
         """
         random_state = check_random_state(self.random_state)
         samples_indices = random_state.randint(
-            low=0, high=len(nn_num.flatten()), size=n_samples
+            low=0, high=nn_num.size, size=n_samples
         )
 
         # np.newaxis for backwards compatability with random_state
@@ -731,13 +731,12 @@ def _fit_resample(self, X, y):
             X_resampled.append(X_new)
             y_resampled.append(y_new)
 
-        if sparse.issparse(X_new):
+        if sparse.issparse(X):
            X_resampled = sparse.vstack(X_resampled, format=X.format)
         else:
             X_resampled = np.vstack(X_resampled)
         y_resampled = np.hstack(y_resampled)
 
-
         return X_resampled, y_resampled
 
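The _smote.py tweak to _make_samples is a micro-simplification: ndarray.size gives the element count directly, whereas len(nn_num.flatten()) first materialises a flattened copy. A quick equivalence check, with a stand-in array in place of the real neighbour-index matrix:

import numpy as np

nn_num = np.arange(12).reshape(4, 3)  # stand-in for the neighbour-index array
assert nn_num.size == len(nn_num.flatten())  # same count, no intermediate copy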

imblearn/over_sampling/tests/test_adasyn.py (+8, -8)
@@ -72,10 +72,10 @@ def test_ada_fit_resample():
             [-0.41635887, -0.38299653],
             [0.08711622, 0.93259929],
             [1.70580611, -0.11219234],
-            [0.94899098, -0.30508981],
-            [0.28204936, -0.13953426],
-            [1.58028868, -0.04089947],
-            [0.66117333, -0.28009063],
+            [0.88161986, -0.2829741],
+            [0.35681689, -0.18814597],
+            [1.4148276, 0.05308106],
+            [0.3136591, -0.31327875],
         ]
     )
     y_gt = np.array(
@@ -136,10 +136,10 @@ def test_ada_fit_resample_nn_obj():
             [-0.41635887, -0.38299653],
             [0.08711622, 0.93259929],
             [1.70580611, -0.11219234],
-            [0.94899098, -0.30508981],
-            [0.28204936, -0.13953426],
-            [1.58028868, -0.04089947],
-            [0.66117333, -0.28009063],
+            [0.88161986, -0.2829741],
+            [0.35681689, -0.18814597],
+            [1.4148276, 0.05308106],
+            [0.3136591, -0.31327875],
         ]
     )
     y_gt = np.array(
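
The updated X_gt literals reflect the new order in which the random state is consumed by the vectorized code. A hypothetical way to regenerate such expected values; X, Y and the seed are assumed to be the module-level fixtures of the test file, which are not shown in this diff.

import numpy as np
from imblearn.over_sampling import ADASYN

def regenerate_expected(X, Y, seed=0):
    # Re-run the resampler and print arrays with enough precision
    # to paste back into the X_gt / y_gt literals.
    X_res, y_res = ADASYN(random_state=seed).fit_resample(X, Y)
    print(np.array_repr(X_res))
    print(np.array_repr(y_res))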
