|
24 | 24 |
|
25 | 25 | Inputs = namedtuple("case", ["x", "y"]) |
26 | 26 |
|
27 | | -_cpu_cases = [ |
| 27 | +_cases = [ |
28 | 28 | Inputs(x=torch.randn(100), y=torch.randn(100)), |
29 | 29 | Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100, requires_grad=True)), |
30 | 30 | # test that list/numpy arrays still works |
|
36 | 36 | Inputs(x=torch.randn(5), y=[1, 2, 3, 4, 5]), |
37 | 37 | ] |
38 | 38 |
|
39 | | -_gpu_cases = [ |
40 | | - Inputs(x=torch.randn(100, device="cuda"), y=torch.randn(100, device="cuda")), |
41 | | - Inputs( |
42 | | - x=torch.randn(100, requires_grad=True, device="cuda"), y=torch.randn(100, requires_grad=True, device="cuda") |
43 | | - ), |
44 | | -] |
45 | | - |
46 | 39 |
|
47 | 40 | _members_to_check = [name for name, member in getmembers(plt) if isfunction(member) and not name.startswith("_")] |
48 | 41 |
|
49 | 42 |
|
50 | | -def string_compare(text1, text2): |
51 | | - if text1 is None and text2 is None: |
52 | | - return True |
53 | | - remove = string.punctuation + string.whitespace |
54 | | - return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate( |
55 | | - str.maketrans(dict.fromkeys(remove)) |
56 | | - ) |
57 | | - |
58 | | - |
59 | 43 | @pytest.mark.parametrize("member", _members_to_check) |
60 | 44 | def test_members(member): |
61 | 45 | """ test that all members have been copied """ |
62 | 46 | assert member in dir(plt) |
63 | 47 | assert member in dir(tp) |
64 | 48 |
|
65 | 49 |
|
66 | | -@pytest.mark.parametrize("test_case", _cpu_cases) |
| 50 | +@pytest.mark.parametrize("test_case", _cases) |
67 | 51 | def test_cpu(test_case): |
68 | 52 | """ test that it works on cpu """ |
69 | 53 | assert tp.plot(test_case.x, test_case.y, ".") |
70 | 54 |
|
71 | 55 |
|
72 | 56 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") |
73 | | -@pytest.mark.parametrize("test_case", _gpu_cases) |
| 57 | +@pytest.mark.parametrize("test_case", _cases) |
74 | 58 | def test_gpu(test_case): |
75 | 59 | """ test that it works on gpu """ |
76 | | - assert tp.plot(test_case.x, test_case.y, ".") |
| 60 | + assert tp.plot( |
| 61 | + test_case.x.cuda() if isinstance(test_case.x, torch.Tensor) else test_case.x, |
| 62 | + test_case.y.cuda() if isinstance(test_case.y, torch.Tensor) else test_case.y, |
| 63 | + "." |
| 64 | + ) |
0 commit comments