|
1 | 1 | from typing import Any, Callable, List, Optional, Tuple |
2 | 2 | import numpy as np |
3 | | -from onnx import ModelProto |
| 3 | +from onnx import ModelProto, TensorProto |
4 | 4 | from onnx.reference import ReferenceEvaluator |
5 | 5 | from .._helpers import np_dtype_to_tensor_dtype |
6 | 6 | from .npx_numpy_tensors_ops import ConstantOfShape |
@@ -183,6 +183,60 @@ def __array_namespace__(self, api_version: Optional[str] = None): |
183 | 183 | f"Unable to return an implementation for api_version={api_version!r}." |
184 | 184 | ) |
185 | 185 |
|
| 186 | + def __bool__(self): |
| 187 | + "Implicit conversion to bool." |
| 188 | + if self.dtype != DType(TensorProto.BOOL): |
| 189 | + raise TypeError( |
| 190 | + f"Conversion to bool only works for bool scalar, not for {self!r}." |
| 191 | + ) |
| 192 | + if self.shape == (0,): |
| 193 | + return False |
| 194 | + if len(self.shape) != 0: |
| 195 | + raise ValueError( |
| 196 | + f"Conversion to bool only works for scalar, not for {self!r}." |
| 197 | + ) |
| 198 | + return bool(self._tensor) |
| 199 | + |
| 200 | + def __int__(self): |
| 201 | + "Implicit conversion to bool." |
| 202 | + if len(self.shape) != 0: |
| 203 | + raise ValueError( |
| 204 | + f"Conversion to bool only works for scalar, not for {self!r}." |
| 205 | + ) |
| 206 | + if self.dtype not in { |
| 207 | + DType(TensorProto.INT64), |
| 208 | + DType(TensorProto.INT32), |
| 209 | + DType(TensorProto.INT16), |
| 210 | + DType(TensorProto.INT8), |
| 211 | + DType(TensorProto.UINT64), |
| 212 | + DType(TensorProto.UINT32), |
| 213 | + DType(TensorProto.UINT16), |
| 214 | + DType(TensorProto.UINT8), |
| 215 | + }: |
| 216 | + raise TypeError( |
| 217 | + f"Conversion to int only works for int scalar, " |
| 218 | + f"not for dtype={self.dtype}." |
| 219 | + ) |
| 220 | + return int(self._tensor) |
| 221 | + |
| 222 | + def __float__(self): |
| 223 | + "Implicit conversion to bool." |
| 224 | + if len(self.shape) != 0: |
| 225 | + raise ValueError( |
| 226 | + f"Conversion to bool only works for scalar, not for {self!r}." |
| 227 | + ) |
| 228 | + if self.dtype not in { |
| 229 | + DType(TensorProto.FLOAT), |
| 230 | + DType(TensorProto.DOUBLE), |
| 231 | + DType(TensorProto.FLOAT16), |
| 232 | + DType(TensorProto.BFLOAT16), |
| 233 | + }: |
| 234 | + raise TypeError( |
| 235 | + f"Conversion to int only works for float scalar, " |
| 236 | + f"not for dtype={self.dtype}." |
| 237 | + ) |
| 238 | + return float(self._tensor) |
| 239 | + |
186 | 240 |
|
187 | 241 | class JitNumpyTensor(NumpyTensor, JitTensor): |
188 | 242 | """ |
|
0 commit comments