|
29 | 29 | npxapi_inline, |
30 | 30 | ) |
31 | 31 | from onnx_array_api.npx.npx_functions import absolute as absolute_inline |
| 32 | +from onnx_array_api.npx.npx_functions import all as all_inline |
32 | 33 | from onnx_array_api.npx.npx_functions import arange as arange_inline |
33 | 34 | from onnx_array_api.npx.npx_functions import arccos as arccos_inline |
34 | 35 | from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline |
|
50 | 51 | from onnx_array_api.npx.npx_functions import det as det_inline |
51 | 52 | from onnx_array_api.npx.npx_functions import dot as dot_inline |
52 | 53 | from onnx_array_api.npx.npx_functions import einsum as einsum_inline |
| 54 | +from onnx_array_api.npx.npx_functions import equal as equal_inline |
53 | 55 | from onnx_array_api.npx.npx_functions import erf as erf_inline |
54 | 56 | from onnx_array_api.npx.npx_functions import exp as exp_inline |
55 | 57 | from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline |
|
95 | 97 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor |
96 | 98 | from onnx_array_api.npx.npx_types import ( |
97 | 99 | Bool, |
| 100 | + DType, |
98 | 101 | Float32, |
99 | 102 | Float64, |
100 | 103 | Int64, |
@@ -127,18 +130,25 @@ def test_tensor(self): |
127 | 130 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) |
128 | 131 | self.assertEmpty(dt.shape) |
129 | 132 | self.assertEqual(dt.type_name(), "TensorType['float32']") |
| 133 | + |
130 | 134 | dt = TensorType["float32"] |
131 | 135 | self.assertEqual(len(dt.dtypes), 1) |
132 | 136 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) |
133 | 137 | self.assertEqual(dt.type_name(), "TensorType['float32']") |
| 138 | + |
134 | 139 | dt = TensorType[np.float32] |
135 | 140 | self.assertEqual(len(dt.dtypes), 1) |
136 | 141 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) |
137 | 142 | self.assertEqual(dt.type_name(), "TensorType['float32']") |
138 | 143 | self.assertEmpty(dt.shape) |
139 | 144 |
|
| 145 | + dt = TensorType[np.str_] |
| 146 | + self.assertEqual(len(dt.dtypes), 1) |
| 147 | + self.assertEqual(dt.dtypes[0].dtype, ElemType.str_) |
| 148 | + self.assertEqual(dt.type_name(), "TensorType[strings]") |
| 149 | + self.assertEmpty(dt.shape) |
| 150 | + |
140 | 151 | self.assertRaise(lambda: TensorType[None], TypeError) |
141 | | - self.assertRaise(lambda: TensorType[np.str_], TypeError) |
142 | 152 | self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError) |
143 | 153 |
|
144 | 154 | def test_superset(self): |
@@ -1155,6 +1165,16 @@ def test_astype(self): |
1155 | 1165 | got = ref.run(None, {"A": x}) |
1156 | 1166 | self.assertEqualArray(z, got[0]) |
1157 | 1167 |
|
| 1168 | + def test_astype_dtype(self): |
| 1169 | + f = absolute_inline(copy_inline(Input("A")).astype(DType(7))) |
| 1170 | + self.assertIsInstance(f, Var) |
| 1171 | + onx = f.to_onnx(constraints={"A": Float64[None]}) |
| 1172 | + x = np.array([[-5.4, 6.6]], dtype=np.float64) |
| 1173 | + z = np.abs(x.astype(np.int64)) |
| 1174 | + ref = ReferenceEvaluator(onx) |
| 1175 | + got = ref.run(None, {"A": x}) |
| 1176 | + self.assertEqualArray(z, got[0]) |
| 1177 | + |
1158 | 1178 | def test_astype_int(self): |
1159 | 1179 | f = absolute_inline(copy_inline(Input("A")).astype(1)) |
1160 | 1180 | self.assertIsInstance(f, Var) |
@@ -1413,6 +1433,9 @@ def test_einsum(self): |
1413 | 1433 | lambda x, y: np.einsum(equation, x, y), |
1414 | 1434 | ) |
1415 | 1435 |
|
| 1436 | + def test_equal(self): |
| 1437 | + self.common_test_inline_bin(equal_inline, np.equal) |
| 1438 | + |
1416 | 1439 | @unittest.skipIf(scipy is None, reason="scipy is not installed.") |
1417 | 1440 | def test_erf(self): |
1418 | 1441 | self.common_test_inline(erf_inline, scipy.special.erf) |
@@ -1460,7 +1483,17 @@ def test_hstack(self): |
1460 | 1483 | def test_identity(self): |
1461 | 1484 | f = identity_inline(2, dtype=np.float64) |
1462 | 1485 | onx = f.to_onnx(constraints={(0, False): Float64[None]}) |
1463 | | - z = np.identity(2) |
| 1486 | + self.assertIn('name: "dtype"', str(onx)) |
| 1487 | + z = np.identity(2).astype(np.float64) |
| 1488 | + ref = ReferenceEvaluator(onx) |
| 1489 | + got = ref.run(None, {}) |
| 1490 | + self.assertEqualArray(z, got[0]) |
| 1491 | + |
| 1492 | + def test_identity_uint8(self): |
| 1493 | + f = identity_inline(2, dtype=np.uint8) |
| 1494 | + onx = f.to_onnx(constraints={(0, False): Float64[None]}) |
| 1495 | + self.assertIn('name: "dtype"', str(onx)) |
| 1496 | + z = np.identity(2).astype(np.uint8) |
1464 | 1497 | ref = ReferenceEvaluator(onx) |
1465 | 1498 | got = ref.run(None, {}) |
1466 | 1499 | self.assertEqualArray(z, got[0]) |
@@ -2318,7 +2351,7 @@ def compute_labels(X, centers): |
2318 | 2351 | self.assertEqual(f.n_versions, 1) |
2319 | 2352 | self.assertEqual(len(f.available_versions), 1) |
2320 | 2353 | self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))]) |
2321 | | - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2)) |
| 2354 | + key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2)) |
2322 | 2355 | onx = f.get_onnx(key) |
2323 | 2356 | self.assertIsInstance(onx, ModelProto) |
2324 | 2357 | self.assertRaise(lambda: f.get_onnx(2), ValueError) |
@@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False): |
2379 | 2412 | self.assertEqualArray(got[1], dist) |
2380 | 2413 | self.assertEqual(f.n_versions, 1) |
2381 | 2414 | self.assertEqual(len(f.available_versions), 1) |
2382 | | - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True) |
| 2415 | + key = ( |
| 2416 | + (DType(TensorProto.DOUBLE), 2), |
| 2417 | + (DType(TensorProto.DOUBLE), 2), |
| 2418 | + "use_sqrt", |
| 2419 | + True, |
| 2420 | + ) |
2383 | 2421 | self.assertEqual(f.available_versions, [key]) |
2384 | 2422 | onx = f.get_onnx(key) |
2385 | 2423 | self.assertIsInstance(onx, ModelProto) |
@@ -2452,7 +2490,52 @@ def test_take(self): |
2452 | 2490 | got = ref.run(None, {"A": data, "B": indices}) |
2453 | 2491 | self.assertEqualArray(y, got[0]) |
2454 | 2492 |
|
| 2493 | + def test_numpy_all(self): |
| 2494 | + data = np.array([[1, 0], [1, 1]]).astype(np.bool_) |
| 2495 | + y = np.all(data, axis=1) |
| 2496 | + |
| 2497 | + f = all_inline(Input("A"), axis=1) |
| 2498 | + self.assertIsInstance(f, Var) |
| 2499 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2500 | + ref = ReferenceEvaluator(onx) |
| 2501 | + got = ref.run(None, {"A": data}) |
| 2502 | + self.assertEqualArray(y, got[0]) |
| 2503 | + |
| 2504 | + def test_numpy_all_empty(self): |
| 2505 | + data = np.zeros((0,), dtype=np.bool_) |
| 2506 | + y = np.all(data) |
| 2507 | + |
| 2508 | + f = all_inline(Input("A")) |
| 2509 | + self.assertIsInstance(f, Var) |
| 2510 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2511 | + ref = ReferenceEvaluator(onx) |
| 2512 | + got = ref.run(None, {"A": data}) |
| 2513 | + self.assertEqualArray(y, got[0]) |
| 2514 | + |
| 2515 | + @unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0") |
| 2516 | + def test_numpy_all_empty_axis_0(self): |
| 2517 | + data = np.zeros((0, 1), dtype=np.bool_) |
| 2518 | + y = np.all(data, axis=0) |
| 2519 | + |
| 2520 | + f = all_inline(Input("A"), axis=0) |
| 2521 | + self.assertIsInstance(f, Var) |
| 2522 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2523 | + ref = ReferenceEvaluator(onx) |
| 2524 | + got = ref.run(None, {"A": data}) |
| 2525 | + self.assertEqualArray(y, got[0]) |
| 2526 | + |
| 2527 | + def test_numpy_all_empty_axis_1(self): |
| 2528 | + data = np.zeros((0, 1), dtype=np.bool_) |
| 2529 | + y = np.all(data, axis=1) |
| 2530 | + |
| 2531 | + f = all_inline(Input("A"), axis=1) |
| 2532 | + self.assertIsInstance(f, Var) |
| 2533 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2534 | + ref = ReferenceEvaluator(onx) |
| 2535 | + got = ref.run(None, {"A": data}) |
| 2536 | + self.assertEqualArray(y, got[0]) |
| 2537 | + |
2455 | 2538 |
|
2456 | 2539 | if __name__ == "__main__": |
2457 | | - TestNpx().test_take() |
| 2540 | + # TestNpx().test_numpy_all_empty_axis_0() |
2458 | 2541 | unittest.main(verbosity=2) |
0 commit comments