Skip to content

Commit ca7d301

Browse files
authored
MAINT use sklearn estimator tag in our test suite (#577)
1 parent e12aab3 commit ca7d301

File tree

11 files changed

+45
-34
lines changed

11 files changed

+45
-34
lines changed

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,9 @@ def _fit_resample(self, X, y):
119119
safe_indexing(y, sample_indices), sample_indices)
120120
return (safe_indexing(X, sample_indices),
121121
safe_indexing(y, sample_indices))
122+
123+
def _more_tags(self):
124+
# TODO: remove the str tag once the following PR is merged:
125+
# https://github.com/scikit-learn/scikit-learn/pull/14043
126+
return {'X_types': ['2darray', 'str', 'string'],
127+
'sample_indices': True}

imblearn/under_sampling/_prototype_generation/_cluster_centroids.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,6 @@ def _fit_resample(self, X, y):
170170
y_resampled = np.hstack(y_resampled)
171171

172172
return X_resampled, np.array(y_resampled, dtype=y.dtype)
173+
174+
def _more_tags(self):
175+
return {'sample_indices': False}

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,6 @@ def _fit_resample(self, X, y):
220220
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
221221
idx_under)
222222
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
223+
224+
def _more_tags(self):
225+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def _fit_resample(self, X, y):
186186
idx_under)
187187
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
188188

189+
def _more_tags(self):
190+
return {'sample_indices': True}
191+
189192

190193
@Substitution(
191194
sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
@@ -377,6 +380,9 @@ def _fit_resample(self, X, y):
377380
return X_resampled, y_resampled, self.sample_indices_
378381
return X_resampled, y_resampled
379382

383+
def _more_tags(self):
384+
return {'sample_indices': True}
385+
380386

381387
@Substitution(
382388
sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
@@ -564,3 +570,6 @@ def _fit_resample(self, X, y):
564570
if self.return_indices:
565571
return X_resampled, y_resampled, self.sample_indices_
566572
return X_resampled, y_resampled
573+
574+
def _more_tags(self):
575+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,6 @@ def _fit_resample(self, X, y):
187187
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
188188
idx_under)
189189
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
190+
191+
def _more_tags(self):
192+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_nearmiss.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,6 @@ def _fit_resample(self, X, y):
293293
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
294294
idx_under)
295295
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
296+
297+
def _more_tags(self):
298+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,6 @@ def _fit_resample(self, X, y):
204204
self.sample_indices_)
205205
return (safe_indexing(X, self.sample_indices_),
206206
safe_indexing(y, self.sample_indices_))
207+
208+
def _more_tags(self):
209+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_one_sided_selection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,6 @@ def _fit_resample(self, X, y):
189189
if self.return_indices:
190190
return (X_cleaned, y_cleaned, self.sample_indices_)
191191
return X_cleaned, y_cleaned
192+
193+
def _more_tags(self):
194+
return {'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,9 @@ def _fit_resample(self, X, y):
135135
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
136136
idx_under)
137137
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
138+
139+
def _more_tags(self):
140+
# TODO: remove the str tag once the following PR is merged:
141+
# https://github.com/scikit-learn/scikit-learn/pull/14043
142+
return {'X_types': ['2darray', 'str', 'string'],
143+
'sample_indices': True}

imblearn/under_sampling/_prototype_selection/_tomek_links.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,6 @@ def _fit_resample(self, X, y):
166166
self.sample_indices_)
167167
return (safe_indexing(X, self.sample_indices_),
168168
safe_indexing(y, self.sample_indices_))
169+
170+
def _more_tags(self):
171+
return {'sample_indices': True}

0 commit comments

Comments
 (0)