88from collections import Counter
99
1010import numpy as np
11+ from sklearn .utils .testing import assert_allclose
1112from sklearn .utils .testing import assert_array_equal
1213
1314from imblearn .over_sampling import RandomOverSampler
@@ -40,7 +41,7 @@ def test_ros_fit_sample():
4041 [0.92923648 , 0.76103773 ], [0.47104475 , 0.44386323 ],
4142 [0.92923648 , 0.76103773 ], [0.47104475 , 0.44386323 ]])
4243 y_gt = np .array ([1 , 0 , 1 , 0 , 1 , 1 , 1 , 1 , 0 , 1 , 0 , 0 , 0 , 0 ])
43- assert_array_equal (X_resampled , X_gt )
44+ assert_allclose (X_resampled , X_gt )
4445 assert_array_equal (y_resampled , y_gt )
4546
4647
@@ -56,10 +57,27 @@ def test_ros_fit_sample_half():
5657 [0.09125309 , - 0.85409574 ], [0.12372842 , 0.6536186 ],
5758 [0.13347175 , 0.12167502 ], [0.094035 , - 2.55298982 ]])
5859 y_gt = np .array ([1 , 0 , 1 , 0 , 1 , 1 , 1 , 1 , 0 , 1 ])
59- assert_array_equal (X_resampled , X_gt )
60+ assert_allclose (X_resampled , X_gt )
6061 assert_array_equal (y_resampled , y_gt )
6162
6263
64+ def test_random_over_sampling_return_indices ():
65+ ros = RandomOverSampler (return_indices = True , random_state = RND_SEED )
66+ X_resampled , y_resampled , sample_indices = ros .fit_sample (X , Y )
67+ X_gt = np .array ([[0.04352327 , - 0.20515826 ], [0.92923648 , 0.76103773 ], [
68+ 0.20792588 , 1.49407907
69+ ], [0.47104475 , 0.44386323 ], [0.22950086 , 0.33367433 ], [
70+ 0.15490546 , 0.3130677
71+ ], [0.09125309 , - 0.85409574 ], [0.12372842 , 0.6536186 ],
72+ [0.13347175 , 0.12167502 ], [0.094035 , - 2.55298982 ],
73+ [0.92923648 , 0.76103773 ], [0.47104475 , 0.44386323 ],
74+ [0.92923648 , 0.76103773 ], [0.47104475 , 0.44386323 ]])
75+ y_gt = np .array ([1 , 0 , 1 , 0 , 1 , 1 , 1 , 1 , 0 , 1 , 0 , 0 , 0 , 0 ])
76+ assert_allclose (X_resampled , X_gt )
77+ assert_array_equal (y_resampled , y_gt )
78+ assert_array_equal (np .sort (np .unique (sample_indices )), np .arange (len (X )))
79+
80+
6381def test_multiclass_fit_sample ():
6482 y = Y .copy ()
6583 y [5 ] = 2
0 commit comments