Skip to content

Commit dff4670

Browse files
authored
EHN: Parallelisation for SMOTEENN and SMOTETomek (#547)
2 parents 4559071 + aba9470 commit dff4670

File tree

5 files changed

+69
-6
lines changed

5 files changed

+69
-6
lines changed

doc/whats_new/v0.5.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ Documentation
1717
:class:`imblearn.over_sampling.SVMSMOTE` in the API documentation.
1818
:issue:`530` by :user:`Guillaume Lemaitre <glemaitre>`.
1919

20+
Enhancement
21+
...........
22+
23+
- Add Parallelisation for SMOTEENN and SMOTETomek.
24+
:issue:`547` by :user:`Michael Hsieh <Microsheep>`.
25+
2026
Maintenance
2127
...........
2228

imblearn/combine/_smote_enn.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ class SMOTEENN(BaseSampler):
3939
a :class:`imblearn.over_sampling.SMOTE` object with default parameters
4040
will be given.
4141
42+
enn : object, optional (default=\
43+
EditedNearestNeighbours(sampling_strategy='all'))
44+
The :class:`imblearn.under_sampling.EditedNearestNeighbours` object
45+
to use. If not given, a
46+
:class:`imblearn.under_sampling.EditedNearestNeighbours` object with
47+
``sampling_strategy='all'`` will be given.
48+
49+
n_jobs : int, optional (default=1)
50+
The number of threads to open if possible.
51+
Not applied to a ``smote`` or ``enn`` object supplied by the user.
52+
4253
ratio : str, dict, or callable
4354
.. deprecated:: 0.4
4455
Use the parameter ``sampling_strategy`` instead. It will be removed
@@ -86,12 +97,14 @@ def __init__(self,
8697
random_state=None,
8798
smote=None,
8899
enn=None,
100+
n_jobs=1,
89101
ratio=None):
90102
super(SMOTEENN, self).__init__()
91103
self.sampling_strategy = sampling_strategy
92104
self.random_state = random_state
93105
self.smote = smote
94106
self.enn = enn
107+
self.n_jobs = n_jobs
95108
self.ratio = ratio
96109

97110
def _validate_estimator(self):
@@ -107,6 +120,7 @@ def _validate_estimator(self):
107120
self.smote_ = SMOTE(
108121
sampling_strategy=self.sampling_strategy,
109122
random_state=self.random_state,
123+
n_jobs=self.n_jobs,
110124
ratio=self.ratio)
111125

112126
if self.enn is not None:
@@ -117,7 +131,9 @@ def _validate_estimator(self):
117131
' Got {} instead.'.format(type(self.enn)))
118132
# Otherwise create a default EditedNearestNeighbours
119133
else:
120-
self.enn_ = EditedNearestNeighbours(sampling_strategy='all')
134+
self.enn_ = EditedNearestNeighbours(
135+
sampling_strategy='all',
136+
n_jobs=self.n_jobs)
121137

122138
def _fit_resample(self, X, y):
123139
self._validate_estimator()

imblearn/combine/_smote_tomek.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@ class SMOTETomek(BaseSampler):
4141
a :class:`imblearn.over_sampling.SMOTE` object with default parameters
4242
will be given.
4343
44-
tomek : object, optional (default=Tomek())
45-
The :class:`imblearn.under_sampling.Tomek` object to use. If not given,
46-
a :class:`imblearn.under_sampling.Tomek` object with default parameters
47-
will be given.
44+
tomek : object, optional (default=TomekLinks(sampling_strategy='all'))
45+
The :class:`imblearn.under_sampling.TomekLinks` object to use. If not
46+
given, a :class:`imblearn.under_sampling.TomekLinks` object with
47+
``sampling_strategy='all'`` will be given.
48+
49+
n_jobs : int, optional (default=1)
50+
The number of threads to open if possible.
51+
Not applied to a ``smote`` or ``tomek`` object supplied by the user.
4852
4953
ratio : str, dict, or callable
5054
.. deprecated:: 0.4
@@ -94,12 +98,14 @@ def __init__(self,
9498
random_state=None,
9599
smote=None,
96100
tomek=None,
101+
n_jobs=1,
97102
ratio=None):
98103
super(SMOTETomek, self).__init__()
99104
self.sampling_strategy = sampling_strategy
100105
self.random_state = random_state
101106
self.smote = smote
102107
self.tomek = tomek
108+
self.n_jobs = n_jobs
103109
self.ratio = ratio
104110

105111
def _validate_estimator(self):
@@ -116,6 +122,7 @@ def _validate_estimator(self):
116122
self.smote_ = SMOTE(
117123
sampling_strategy=self.sampling_strategy,
118124
random_state=self.random_state,
125+
n_jobs=self.n_jobs,
119126
ratio=self.ratio)
120127

121128
if self.tomek is not None:
@@ -126,7 +133,9 @@ def _validate_estimator(self):
126133
'Got {} instead.'.format(type(self.tomek)))
127134
# Otherwise create a default TomekLinks
128135
else:
129-
self.tomek_ = TomekLinks(sampling_strategy='all')
136+
self.tomek_ = TomekLinks(
137+
sampling_strategy='all',
138+
n_jobs=self.n_jobs)
130139

131140
def _fit_resample(self, X, y):
132141
self._validate_estimator()

imblearn/combine/tests/test_smote_enn.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,22 @@ def test_validate_estimator_default():
9898
assert_array_equal(y_resampled, y_gt)
9999

100100

101+
def test_parallelisation():
102+
# Check if default job count is 1
103+
smt = SMOTEENN(random_state=RND_SEED)
104+
smt._validate_estimator()
105+
assert smt.n_jobs == 1
106+
assert smt.smote_.n_jobs == 1
107+
assert smt.enn_.n_jobs == 1
108+
109+
# Check if job count is set
110+
smt = SMOTEENN(random_state=RND_SEED, n_jobs=8)
111+
smt._validate_estimator()
112+
assert smt.n_jobs == 8
113+
assert smt.smote_.n_jobs == 8
114+
assert smt.enn_.n_jobs == 8
115+
116+
101117
@pytest.mark.parametrize(
102118
"smote_params, err_msg",
103119
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),

imblearn/combine/tests/test_smote_tomek.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ def test_validate_estimator_default():
104104
assert_array_equal(y_resampled, y_gt)
105105

106106

107+
def test_parallelisation():
108+
# Check if default job count is 1
109+
smt = SMOTETomek(random_state=RND_SEED)
110+
smt._validate_estimator()
111+
assert smt.n_jobs == 1
112+
assert smt.smote_.n_jobs == 1
113+
assert smt.tomek_.n_jobs == 1
114+
115+
# Check if job count is set
116+
smt = SMOTETomek(random_state=RND_SEED, n_jobs=8)
117+
smt._validate_estimator()
118+
assert smt.n_jobs == 8
119+
assert smt.smote_.n_jobs == 8
120+
assert smt.tomek_.n_jobs == 8
121+
122+
107123
@pytest.mark.parametrize(
108124
"smote_params, err_msg",
109125
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),

0 commit comments

Comments
 (0)