Skip to content

Commit 508118e

Browse files
committed
Implement Number protocol for PyComplex
1 parent cc90bc0 commit 508118e

File tree

1 file changed

+96
-3
lines changed

1 file changed

+96
-3
lines changed

vm/src/builtins/complex.rs

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use super::{float, PyStr, PyType, PyTypeRef};
22
use crate::{
33
class::PyClassImpl,
4-
convert::ToPyObject,
4+
convert::{ToPyObject, ToPyResult},
55
function::{
66
OptionalArg, OptionalOption,
77
PyArithmeticValue::{self, *},
88
PyComparisonValue,
99
},
1010
identifier,
11-
types::{Comparable, Constructor, Hashable, PyComparisonOp},
11+
protocol::{PyNumber, PyNumberMethods},
12+
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp},
1213
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
1314
};
1415
use num_complex::Complex64;
@@ -203,7 +204,7 @@ impl PyComplex {
203204
}
204205
}
205206

206-
#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor))]
207+
#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor, AsNumber))]
207208
impl PyComplex {
208209
#[pymethod(magic)]
209210
fn complex(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyRef<PyComplex> {
@@ -419,6 +420,98 @@ impl Hashable for PyComplex {
419420
}
420421
}
421422

423+
impl AsNumber for PyComplex {
424+
const AS_NUMBER: PyNumberMethods = PyNumberMethods {
425+
add: Some(|number, other, vm| Self::number_complex_op(number, other, |a, b| a + b, vm)),
426+
subtract: Some(|number, other, vm| {
427+
Self::number_complex_op(number, other, |a, b| a - b, vm)
428+
}),
429+
multiply: Some(|number, other, vm| {
430+
Self::number_complex_op(number, other, |a, b| a * b, vm)
431+
}),
432+
remainder: None,
433+
divmod: None,
434+
power: Some(|number, other, vm| Self::number_general_op(number, other, inner_pow, vm)),
435+
negative: Some(|number, vm| {
436+
let value = Self::number_downcast(number).value;
437+
(-value).to_pyresult(vm)
438+
}),
439+
positive: Some(|number, vm| Self::number_complex(number, vm).to_pyresult(vm)),
440+
absolute: Some(|number, vm| {
441+
let value = Self::number_downcast(number).value;
442+
value.norm().to_pyresult(vm)
443+
}),
444+
boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())),
445+
invert: None,
446+
lshift: None,
447+
rshift: None,
448+
and: None,
449+
xor: None,
450+
or: None,
451+
int: None,
452+
float: None,
453+
inplace_add: None,
454+
inplace_subtract: None,
455+
inplace_multiply: None,
456+
inplace_remainder: None,
457+
inplace_divmod: None,
458+
inplace_power: None,
459+
inplace_lshift: None,
460+
inplace_rshift: None,
461+
inplace_and: None,
462+
inplace_xor: None,
463+
inplace_or: None,
464+
floor_divide: None,
465+
true_divide: Some(|number, other, vm| {
466+
Self::number_general_op(number, other, inner_div, vm)
467+
}),
468+
inplace_floor_divide: None,
469+
inplace_true_divide: None,
470+
index: None,
471+
matrix_multiply: None,
472+
inplace_matrix_multiply: None,
473+
};
474+
}
475+
476+
impl PyComplex {
477+
fn number_general_op<F, R>(
478+
number: &PyNumber,
479+
other: &PyObject,
480+
op: F,
481+
vm: &VirtualMachine,
482+
) -> PyResult
483+
where
484+
F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
485+
R: ToPyResult,
486+
{
487+
if let (Some(a), Some(b)) = (number.obj.payload::<Self>(), other.payload::<Self>()) {
488+
op(a.value, b.value, vm).to_pyresult(vm)
489+
} else {
490+
Ok(vm.ctx.not_implemented())
491+
}
492+
}
493+
494+
fn number_complex_op<F>(
495+
number: &PyNumber,
496+
other: &PyObject,
497+
op: F,
498+
vm: &VirtualMachine,
499+
) -> PyResult
500+
where
501+
F: FnOnce(Complex64, Complex64) -> Complex64,
502+
{
503+
Self::number_general_op(number, other, |a, b, _vm| op(a, b), vm)
504+
}
505+
506+
fn number_complex(number: &PyNumber, vm: &VirtualMachine) -> PyRef<PyComplex> {
507+
if let Some(zelf) = number.obj.downcast_ref_if_exact::<Self>(vm) {
508+
zelf.to_owned()
509+
} else {
510+
vm.ctx.new_complex(Self::number_downcast(number).value)
511+
}
512+
}
513+
}
514+
422515
#[derive(FromArgs)]
423516
pub struct ComplexArgs {
424517
#[pyarg(any, optional)]

0 commit comments

Comments
 (0)