From 8fd54c2db8d553d8cad0c2627618f737ed602e51 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 2 Oct 2023 14:30:39 -0700 Subject: [PATCH] Test fixes + moving dtypes out of the test --- tests/unit/inference/test_inference.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 894f040be207..f891e69cd13a 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -99,7 +99,16 @@ def model_w_task(request): return request.param -@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"]) +_dtypes = [] +_dtype_ids = [] +for dt, id in [(torch.float32, "fp32"), (torch.float16, "fp16")]: + if dt in get_accelerator().supported_dtypes(): + _dtypes.append(dt) + _dtype_ids.append(id) +assert len(_dtypes) > 0, "Accelerator does not support any tested data types" + + +@pytest.fixture(params=_dtypes, ids=_dtype_ids) def dtype(request): return request.param @@ -280,6 +289,12 @@ def test( if invalid_test_msg: pytest.skip(invalid_test_msg) + if dtype not in get_accelerator().supported_dtypes(): + pytest.skip(f"Acceleraor {get_accelerator().device_name()} does not support {dtype}.") + + if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("This op had not been implemented on this system.", allow_module_level=True) + model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0"))