2424
2525Inputs = namedtuple ("case" , ["x" , "y" ])
2626
27- _cpu_cases = [Inputs (x = torch .randn (100 ,), y = torch .randn (100 ,)),
28- Inputs (x = torch .randn (100 , requires_grad = True ), y = torch .randn (100 ,requires_grad = True )),
29- # test that list/numpy arrays still works
30- Inputs (x = [1 ,2 ,3 ,4 ], y = [1 ,2 ,3 ,4 ]),
31- Inputs (x = np .random .randn (100 ,), y = np .random .randn (100 ,)),
32- # test that we can mix
33- Inputs (x = torch .randn (100 ,), y = torch .randn (100 , requires_grad = True )),
34- Inputs (x = np .random .randn (100 ,), y = torch .randn (100 , requires_grad = True )),
35- Inputs (x = torch .randn (5 ,), y = [1 ,2 ,3 ,4 ,5 ]),
36- ]
27+ _cpu_cases = [
28+ Inputs (x = torch .randn (100 ), y = torch .randn (100 )),
29+ Inputs (x = torch .randn (100 , requires_grad = True ), y = torch .randn (100 , requires_grad = True )),
30+ # test that list/numpy arrays still works
31+ Inputs (x = [1 , 2 , 3 , 4 ], y = [1 , 2 , 3 , 4 ]),
32+ Inputs (x = np .random .randn (100 ), y = np .random .randn (100 )),
33+ # test that we can mix
34+ Inputs (x = torch .randn (100 ), y = torch .randn (100 , requires_grad = True )),
35+ Inputs (x = np .random .randn (100 ), y = torch .randn (100 , requires_grad = True )),
36+ Inputs (x = torch .randn (5 ), y = [1 , 2 , 3 , 4 , 5 ]),
37+ ]
3738
38- _gpu_cases = [Inputs (x = torch .randn (100 , device = 'cuda' ), y = torch .randn (100 , device = 'cuda' )),
39- Inputs (x = torch .randn (100 ,requires_grad = True , device = 'cuda' ), y = torch .randn (100 ,requires_grad = True , device = 'cuda' )),
40- ]
39+ _gpu_cases = [
40+ Inputs (x = torch .randn (100 , device = "cuda" ), y = torch .randn (100 , device = "cuda" )),
41+ Inputs (
42+ x = torch .randn (100 , requires_grad = True , device = "cuda" ), y = torch .randn (100 , requires_grad = True , device = "cuda" )
43+ ),
44+ ]
4145
4246
43-
44- _members_to_check = [name for name , member in getmembers (plt )
45- if isfunction (member ) and not name .startswith ('_' )]
47+ _members_to_check = [name for name , member in getmembers (plt ) if isfunction (member ) and not name .startswith ("_" )]
4648
4749
4850def string_compare (text1 , text2 ):
4951 if text1 is None and text2 is None :
5052 return True
5153 remove = string .punctuation + string .whitespace
52- return text1 .translate (str .maketrans (dict .fromkeys (remove ))) == text2 .translate (str .maketrans (dict .fromkeys (remove )))
54+ return text1 .translate (str .maketrans (dict .fromkeys (remove ))) == text2 .translate (
55+ str .maketrans (dict .fromkeys (remove ))
56+ )
5357
5458
5559@pytest .mark .parametrize ("member" , _members_to_check )
@@ -59,15 +63,14 @@ def test_members(member):
5963 assert member in dir (tp )
6064
6165
62- @pytest .mark .parametrize (' test_case' , _cpu_cases )
66+ @pytest .mark .parametrize (" test_case" , _cpu_cases )
6367def test_cpu (test_case ):
6468 """ test that it works on cpu """
65- assert tp .plot (test_case .x , test_case .y , '.' )
69+ assert tp .plot (test_case .x , test_case .y , "." )
6670
6771
68- @pytest .mark .skipif (not torch .cuda .is_available (), reason = ' test requires cuda' )
69- @pytest .mark .parametrize (' test_case' , _gpu_cases )
72+ @pytest .mark .skipif (not torch .cuda .is_available (), reason = " test requires cuda" )
73+ @pytest .mark .parametrize (" test_case" , _gpu_cases )
7074def test_gpu (test_case ):
7175 """ test that it works on gpu """
72- assert tp .plot (test_case .x , test_case .y , '.' )
73-
76+ assert tp .plot (test_case .x , test_case .y , "." )
0 commit comments