File tree Expand file tree Collapse file tree 11 files changed +27
-15
lines changed
Expand file tree Collapse file tree 11 files changed +27
-15
lines changed Original file line number Diff line number Diff line change @@ -56,6 +56,9 @@ Bug fixes
5656 generating new samples. :issue:`354` by :user:`Guillaume Lemaitre
5757 <glemaitre>`.
5858
59+ - Force cloning of scikit-learn estimators passed as attributes to samplers.
60+ :issue:`446` by :user:`Guillaume Lemaitre <glemaitre>`.
61+
5962Maintenance
6063...........
6164
Original file line number Diff line number Diff line change 99import logging
1010import warnings
1111
12+ from sklearn .base import clone
1213from sklearn .utils import check_X_y
1314
1415from ..base import SamplerMixin
@@ -103,7 +104,7 @@ def _validate_estimator(self):
103104 "Private function to validate SMOTE and ENN objects"
104105 if self .smote is not None :
105106 if isinstance (self .smote , SMOTE ):
106- self .smote_ = self .smote
107+ self .smote_ = clone ( self .smote )
107108 else :
108109 raise ValueError ('smote needs to be a SMOTE object.'
109110 'Got {} instead.' .format (type (self .smote )))
@@ -116,7 +117,7 @@ def _validate_estimator(self):
116117
117118 if self .enn is not None :
118119 if isinstance (self .enn , EditedNearestNeighbours ):
119- self .enn_ = self .enn
120+ self .enn_ = clone ( self .enn )
120121 else :
121122 raise ValueError ('enn needs to be an EditedNearestNeighbours.'
122123 ' Got {} instead.' .format (type (self .enn )))
Original file line number Diff line number Diff line change 1010import logging
1111import warnings
1212
13+ from sklearn .base import clone
1314from sklearn .utils import check_X_y
1415
1516from ..base import SamplerMixin
@@ -111,7 +112,7 @@ def _validate_estimator(self):
111112
112113 if self .smote is not None :
113114 if isinstance (self .smote , SMOTE ):
114- self .smote_ = self .smote
115+ self .smote_ = clone ( self .smote )
115116 else :
116117 raise ValueError ('smote needs to be a SMOTE object.'
117118 'Got {} instead.' .format (type (self .smote )))
@@ -124,7 +125,7 @@ def _validate_estimator(self):
124125
125126 if self .tomek is not None :
126127 if isinstance (self .tomek , TomekLinks ):
127- self .tomek_ = self .tomek
128+ self .tomek_ = clone ( self .tomek )
128129 else :
129130 raise ValueError ('tomek needs to be a TomekLinks object.'
130131 'Got {} instead.' .format (type (self .tomek )))
Original file line number Diff line number Diff line change 88
99import numpy as np
1010
11- from sklearn .base import ClassifierMixin
11+ from sklearn .base import ClassifierMixin , clone
1212from sklearn .neighbors import KNeighborsClassifier
1313from sklearn .utils import check_random_state , safe_indexing
1414from sklearn .model_selection import cross_val_predict
@@ -142,7 +142,7 @@ def _validate_estimator(self):
142142 if (self .estimator is not None and
143143 isinstance (self .estimator , ClassifierMixin ) and
144144 hasattr (self .estimator , 'predict' )):
145- self .estimator_ = self .estimator
145+ self .estimator_ = clone ( self .estimator )
146146 elif self .estimator is None :
147147 self .estimator_ = KNeighborsClassifier ()
148148 else :
Original file line number Diff line number Diff line change 1414
1515from scipy import sparse
1616
17+ from sklearn .base import clone
1718from sklearn .svm import SVC
1819from sklearn .utils import check_random_state , safe_indexing
1920
@@ -448,7 +449,7 @@ def _validate_estimator(self):
448449 if self .svm_estimator is None :
449450 self .svm_estimator_ = SVC (random_state = self .random_state )
450451 elif isinstance (self .svm_estimator , SVC ):
451- self .svm_estimator_ = self .svm_estimator
452+ self .svm_estimator_ = clone ( self .svm_estimator )
452453 else :
453454 raise_isinstance_error ('svm_estimator' , [SVC ],
454455 self .svm_estimator )
@@ -698,7 +699,7 @@ def _validate_estimator(self):
698699 self .svm_estimator == 'deprecated' ):
699700 self .svm_estimator_ = SVC (random_state = self .random_state )
700701 elif isinstance (self .svm_estimator , SVC ):
701- self .svm_estimator_ = self .svm_estimator
702+ self .svm_estimator_ = clone ( self .svm_estimator )
702703 else :
703704 raise_isinstance_error ('svm_estimator' , [SVC ],
704705 self .svm_estimator )
Original file line number Diff line number Diff line change 1111import numpy as np
1212from scipy import sparse
1313
14+ from sklearn .base import clone
1415from sklearn .cluster import KMeans
1516from sklearn .neighbors import NearestNeighbors
1617from sklearn .utils import safe_indexing
@@ -113,7 +114,7 @@ def _validate_estimator(self):
113114 self .estimator_ = KMeans (
114115 random_state = self .random_state , n_jobs = self .n_jobs )
115116 elif isinstance (self .estimator , KMeans ):
116- self .estimator_ = self .estimator
117+ self .estimator_ = clone ( self .estimator )
117118 else :
118119 raise ValueError ('`estimator` has to be a KMeans clustering.'
119120 ' Got {} instead.' .format (type (self .estimator )))
Original file line number Diff line number Diff line change 1313
1414from scipy .sparse import issparse
1515
16+ from sklearn .base import clone
1617from sklearn .neighbors import KNeighborsClassifier
1718from sklearn .utils import check_random_state , safe_indexing
1819
@@ -121,7 +122,7 @@ def _validate_estimator(self):
121122 self .estimator_ = KNeighborsClassifier (
122123 n_neighbors = self .n_neighbors , n_jobs = self .n_jobs )
123124 elif isinstance (self .n_neighbors , KNeighborsClassifier ):
124- self .estimator_ = self .n_neighbors
125+ self .estimator_ = clone ( self .n_neighbors )
125126 else :
126127 raise ValueError ('`n_neighbors` has to be a int or an object'
127128 ' inhereited from KNeighborsClassifier.'
Original file line number Diff line number Diff line change 1212
1313import numpy as np
1414
15- from sklearn .base import ClassifierMixin
15+ from sklearn .base import ClassifierMixin , clone
1616from sklearn .ensemble import RandomForestClassifier
1717from sklearn .model_selection import StratifiedKFold
1818from sklearn .utils import safe_indexing
@@ -117,7 +117,7 @@ def _validate_estimator(self):
117117 if (self .estimator is not None and
118118 isinstance (self .estimator , ClassifierMixin ) and
119119 hasattr (self .estimator , 'predict_proba' )):
120- self .estimator_ = self .estimator
120+ self .estimator_ = clone ( self .estimator )
121121 elif self .estimator is None :
122122 self .estimator_ = RandomForestClassifier (
123123 random_state = self .random_state , n_jobs = self .n_jobs )
Original file line number Diff line number Diff line change 99from collections import Counter
1010
1111import numpy as np
12+
13+ from sklearn .base import clone
1214from sklearn .neighbors import KNeighborsClassifier
1315from sklearn .utils import check_random_state , safe_indexing
1416
@@ -114,7 +116,7 @@ def _validate_estimator(self):
114116 self .estimator_ = KNeighborsClassifier (
115117 n_neighbors = self .n_neighbors , n_jobs = self .n_jobs )
116118 elif isinstance (self .n_neighbors , KNeighborsClassifier ):
117- self .estimator_ = self .n_neighbors
119+ self .estimator_ = clone ( self .n_neighbors )
118120 else :
119121 raise ValueError ('`n_neighbors` has to be a int or an object'
120122 ' inhereited from KNeighborsClassifier.'
Original file line number Diff line number Diff line change @@ -35,7 +35,8 @@ def test_check_neighbors_object():
3535 assert issubclass (type (estimator ), KNeighborsMixin )
3636 assert estimator .n_neighbors == 2
3737 estimator = NearestNeighbors (n_neighbors )
38- assert estimator is check_neighbors_object (name , estimator )
38+ estimator_cloned = check_neighbors_object (name , estimator )
39+ assert estimator .n_neighbors == estimator_cloned .n_neighbors
3940 n_neighbors = 'rnd'
4041 with pytest .raises (ValueError , match = "has to be one of" ):
4142 check_neighbors_object (name , n_neighbors )
You can’t perform that action at this time.
0 commit comments