Skip to content

Commit 3817246

Browse files
authored
Merge pull request RustPython#4738 from qingshi163/binopfunc
Refactor Number Protocol
2 parents aad9015 + c5ce44e commit 3817246

File tree

20 files changed

+496
-645
lines changed

20 files changed

+496
-645
lines changed

derive-impl/src/pyclass.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,9 @@ where
702702
let slot_name = slot_ident.to_string();
703703
let tokens = {
704704
const NON_ATOMIC_SLOTS: &[&str] = &["as_buffer"];
705-
const POINTER_SLOTS: &[&str] = &["as_number", "as_sequence", "as_mapping"];
705+
const POINTER_SLOTS: &[&str] = &["as_sequence", "as_mapping"];
706+
const STATIC_GEN_SLOTS: &[&str] = &["as_number"];
707+
706708
if NON_ATOMIC_SLOTS.contains(&slot_name.as_str()) {
707709
quote_spanned! { span =>
708710
slots.#slot_ident = Some(Self::#ident as _);
@@ -711,6 +713,10 @@ where
711713
quote_spanned! { span =>
712714
slots.#slot_ident.store(Some(PointerSlot::from(Self::#ident())));
713715
}
716+
} else if STATIC_GEN_SLOTS.contains(&slot_name.as_str()) {
717+
quote_spanned! { span =>
718+
slots.#slot_ident = Self::#ident().into();
719+
}
714720
} else {
715721
quote_spanned! { span =>
716722
slots.#slot_ident.store(Some(Self::#ident as _));

stdlib/src/array.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
88
let array = module
99
.get_attr("array", vm)
1010
.expect("Expect array has array type.");
11-
array.init_builtin_number_slots(&vm.ctx);
1211

1312
let collections_abc = vm
1413
.import("collections.abc", None, 0)
@@ -722,7 +721,7 @@ mod array {
722721

723722
#[pyclass(
724723
flags(BASETYPE),
725-
with(Comparable, AsBuffer, AsMapping, Iterable, Constructor)
724+
with(Comparable, AsBuffer, AsMapping, AsSequence, Iterable, Constructor)
726725
)]
727726
impl PyArray {
728727
fn read(&self) -> PyRwLockReadGuard<'_, ArrayContentType> {

vm/src/builtins/bool.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,9 @@ impl PyBool {
164164
impl AsNumber for PyBool {
165165
fn as_number() -> &'static PyNumberMethods {
166166
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
167-
and: Some(|number, other, vm| {
168-
PyBool::and(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
169-
}),
170-
xor: Some(|number, other, vm| {
171-
PyBool::xor(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
172-
}),
173-
or: Some(|number, other, vm| {
174-
PyBool::or(number.obj.to_owned(), other.to_owned(), vm).to_pyresult(vm)
175-
}),
167+
and: Some(|a, b, vm| PyBool::and(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
168+
xor: Some(|a, b, vm| PyBool::xor(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
169+
or: Some(|a, b, vm| PyBool::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
176170
..PyInt::AS_NUMBER
177171
};
178172
&AS_NUMBER

vm/src/builtins/bytearray.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,9 @@ impl AsSequence for PyByteArray {
843843
impl AsNumber for PyByteArray {
844844
fn as_number() -> &'static PyNumberMethods {
845845
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
846-
remainder: Some(|number, other, vm| {
847-
if let Some(number) = number.obj.downcast_ref::<PyByteArray>() {
848-
number.mod_(other.to_owned(), vm).to_pyresult(vm)
846+
remainder: Some(|a, b, vm| {
847+
if let Some(a) = a.downcast_ref::<PyByteArray>() {
848+
a.mod_(b.to_owned(), vm).to_pyresult(vm)
849849
} else {
850850
Ok(vm.ctx.not_implemented())
851851
}

vm/src/builtins/bytes.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,9 @@ impl AsSequence for PyBytes {
615615
impl AsNumber for PyBytes {
616616
fn as_number() -> &'static PyNumberMethods {
617617
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
618-
remainder: Some(|number, other, vm| {
619-
if let Some(number) = number.obj.downcast_ref::<PyBytes>() {
620-
number.mod_(other.to_owned(), vm).to_pyresult(vm)
618+
remainder: Some(|a, b, vm| {
619+
if let Some(a) = a.downcast_ref::<PyBytes>() {
620+
a.mod_(b.to_owned(), vm).to_pyresult(vm)
621621
} else {
622622
Ok(vm.ctx.not_implemented())
623623
}

vm/src/builtins/complex.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
PyComparisonValue,
99
},
1010
identifier,
11-
protocol::{PyNumber, PyNumberMethods},
11+
protocol::PyNumberMethods,
1212
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
1313
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
1414
};
@@ -418,16 +418,10 @@ impl Hashable for PyComplex {
418418
impl AsNumber for PyComplex {
419419
fn as_number() -> &'static PyNumberMethods {
420420
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
421-
add: Some(|number, other, vm| {
422-
PyComplex::number_op(number, other, |a, b, _vm| a + b, vm)
423-
}),
424-
subtract: Some(|number, other, vm| {
425-
PyComplex::number_op(number, other, |a, b, _vm| a - b, vm)
426-
}),
427-
multiply: Some(|number, other, vm| {
428-
PyComplex::number_op(number, other, |a, b, _vm| a * b, vm)
429-
}),
430-
power: Some(|number, other, vm| PyComplex::number_op(number, other, inner_pow, vm)),
421+
add: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a + b, vm)),
422+
subtract: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a - b, vm)),
423+
multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
424+
power: Some(|a, b, vm| PyComplex::number_op(a, b, inner_pow, vm)),
431425
negative: Some(|number, vm| {
432426
let value = PyComplex::number_downcast(number).value;
433427
(-value).to_pyresult(vm)
@@ -440,9 +434,7 @@ impl AsNumber for PyComplex {
440434
value.norm().to_pyresult(vm)
441435
}),
442436
boolean: Some(|number, _vm| Ok(PyComplex::number_downcast(number).value.is_zero())),
443-
true_divide: Some(|number, other, vm| {
444-
PyComplex::number_op(number, other, inner_div, vm)
445-
}),
437+
true_divide: Some(|a, b, vm| PyComplex::number_op(a, b, inner_div, vm)),
446438
..PyNumberMethods::NOT_IMPLEMENTED
447439
};
448440
&AS_NUMBER
@@ -494,12 +486,12 @@ impl Representable for PyComplex {
494486
}
495487

496488
impl PyComplex {
497-
fn number_op<F, R>(number: PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
489+
fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
498490
where
499491
F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
500492
R: ToPyResult,
501493
{
502-
if let (Some(a), Some(b)) = (to_op_complex(number.obj, vm)?, to_op_complex(other, vm)?) {
494+
if let (Some(a), Some(b)) = (to_op_complex(a, vm)?, to_op_complex(b, vm)?) {
503495
op(a, b, vm).to_pyresult(vm)
504496
} else {
505497
Ok(vm.ctx.not_implemented())

vm/src/builtins/dict.rs

Lines changed: 42 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -472,16 +472,16 @@ impl AsSequence for PyDict {
472472
impl AsNumber for PyDict {
473473
fn as_number() -> &'static PyNumberMethods {
474474
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
475-
or: Some(|num, args, vm| {
476-
if let Some(num) = num.obj.downcast_ref::<PyDict>() {
477-
PyDict::or(num, args.to_pyobject(vm), vm)
475+
or: Some(|a, b, vm| {
476+
if let Some(a) = a.downcast_ref::<PyDict>() {
477+
PyDict::or(a, b.to_pyobject(vm), vm)
478478
} else {
479479
Ok(vm.ctx.not_implemented())
480480
}
481481
}),
482-
inplace_or: Some(|num, args, vm| {
483-
if let Some(num) = num.obj.downcast_ref::<PyDict>() {
484-
PyDict::ior(num.to_owned(), args.to_pyobject(vm), vm).map(|d| d.into())
482+
inplace_or: Some(|a, b, vm| {
483+
if let Some(a) = a.downcast_ref::<PyDict>() {
484+
PyDict::ior(a.to_owned(), b.to_pyobject(vm), vm).map(|d| d.into())
485485
} else {
486486
Ok(vm.ctx.not_implemented())
487487
}
@@ -1169,51 +1169,10 @@ impl AsSequence for PyDictKeys {
11691169
impl AsNumber for PyDictKeys {
11701170
fn as_number() -> &'static PyNumberMethods {
11711171
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1172-
subtract: Some(|num, args, vm| {
1173-
let num = PySetInner::from_iter(
1174-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1175-
vm,
1176-
)?;
1177-
Ok(PySet {
1178-
inner: num
1179-
.difference(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1180-
}
1181-
.into_pyobject(vm))
1182-
}),
1183-
and: Some(|num, args, vm| {
1184-
let num = PySetInner::from_iter(
1185-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1186-
vm,
1187-
)?;
1188-
Ok(PySet {
1189-
inner: num
1190-
.intersection(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1191-
}
1192-
.into_pyobject(vm))
1193-
}),
1194-
xor: Some(|num, args, vm| {
1195-
let num = PySetInner::from_iter(
1196-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1197-
vm,
1198-
)?;
1199-
Ok(PySet {
1200-
inner: num.symmetric_difference(
1201-
ArgIterable::try_from_object(vm, args.to_owned())?,
1202-
vm,
1203-
)?,
1204-
}
1205-
.into_pyobject(vm))
1206-
}),
1207-
or: Some(|num, args, vm| {
1208-
let num = PySetInner::from_iter(
1209-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1210-
vm,
1211-
)?;
1212-
Ok(PySet {
1213-
inner: num.union(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1214-
}
1215-
.into_pyobject(vm))
1216-
}),
1172+
subtract: Some(set_inner_number_subtract),
1173+
and: Some(set_inner_number_and),
1174+
xor: Some(set_inner_number_xor),
1175+
or: Some(set_inner_number_or),
12171176
..PyNumberMethods::NOT_IMPLEMENTED
12181177
};
12191178
&AS_NUMBER
@@ -1288,51 +1247,10 @@ impl AsSequence for PyDictItems {
12881247
impl AsNumber for PyDictItems {
12891248
fn as_number() -> &'static PyNumberMethods {
12901249
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
1291-
subtract: Some(|num, args, vm| {
1292-
let num = PySetInner::from_iter(
1293-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1294-
vm,
1295-
)?;
1296-
Ok(PySet {
1297-
inner: num
1298-
.difference(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1299-
}
1300-
.into_pyobject(vm))
1301-
}),
1302-
and: Some(|num, args, vm| {
1303-
let num = PySetInner::from_iter(
1304-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1305-
vm,
1306-
)?;
1307-
Ok(PySet {
1308-
inner: num
1309-
.intersection(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1310-
}
1311-
.into_pyobject(vm))
1312-
}),
1313-
xor: Some(|num, args, vm| {
1314-
let num = PySetInner::from_iter(
1315-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1316-
vm,
1317-
)?;
1318-
Ok(PySet {
1319-
inner: num.symmetric_difference(
1320-
ArgIterable::try_from_object(vm, args.to_owned())?,
1321-
vm,
1322-
)?,
1323-
}
1324-
.into_pyobject(vm))
1325-
}),
1326-
or: Some(|num, args, vm| {
1327-
let num = PySetInner::from_iter(
1328-
ArgIterable::try_from_object(vm, num.obj.to_owned())?.iter(vm)?,
1329-
vm,
1330-
)?;
1331-
Ok(PySet {
1332-
inner: num.union(ArgIterable::try_from_object(vm, args.to_owned())?, vm)?,
1333-
}
1334-
.into_pyobject(vm))
1335-
}),
1250+
subtract: Some(set_inner_number_subtract),
1251+
and: Some(set_inner_number_and),
1252+
xor: Some(set_inner_number_xor),
1253+
or: Some(set_inner_number_or),
13361254
..PyNumberMethods::NOT_IMPLEMENTED
13371255
};
13381256
&AS_NUMBER
@@ -1358,6 +1276,34 @@ impl AsSequence for PyDictValues {
13581276
}
13591277
}
13601278

1279+
fn set_inner_number_op<F>(a: &PyObject, b: &PyObject, f: F, vm: &VirtualMachine) -> PyResult
1280+
where
1281+
F: FnOnce(PySetInner, ArgIterable) -> PyResult<PySetInner>,
1282+
{
1283+
let a = PySetInner::from_iter(
1284+
ArgIterable::try_from_object(vm, a.to_owned())?.iter(vm)?,
1285+
vm,
1286+
)?;
1287+
let b = ArgIterable::try_from_object(vm, b.to_owned())?;
1288+
Ok(PySet { inner: f(a, b)? }.into_pyobject(vm))
1289+
}
1290+
1291+
fn set_inner_number_subtract(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1292+
set_inner_number_op(a, b, |a, b| a.difference(b, vm), vm)
1293+
}
1294+
1295+
fn set_inner_number_and(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1296+
set_inner_number_op(a, b, |a, b| a.intersection(b, vm), vm)
1297+
}
1298+
1299+
fn set_inner_number_xor(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1300+
set_inner_number_op(a, b, |a, b| a.symmetric_difference(b, vm), vm)
1301+
}
1302+
1303+
fn set_inner_number_or(a: &PyObject, b: &PyObject, vm: &VirtualMachine) -> PyResult {
1304+
set_inner_number_op(a, b, |a, b| a.union(b, vm), vm)
1305+
}
1306+
13611307
pub(crate) fn init(context: &Context) {
13621308
PyDict::extend_class(context, context.types.dict_type);
13631309
PyDictKeys::extend_class(context, context.types.dict_keys_type);

vm/src/builtins/float.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
PyArithmeticValue::{self, *},
1414
PyComparisonValue,
1515
},
16-
protocol::{PyNumber, PyNumberMethods},
16+
protocol::PyNumberMethods,
1717
types::{AsNumber, Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
1818
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
1919
TryFromBorrowedObject, TryFromObject, VirtualMachine,
@@ -544,12 +544,12 @@ impl Hashable for PyFloat {
544544
impl AsNumber for PyFloat {
545545
fn as_number() -> &'static PyNumberMethods {
546546
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
547-
add: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a + b, vm)),
548-
subtract: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a - b, vm)),
549-
multiply: Some(|num, other, vm| PyFloat::number_op(num, other, |a, b, _vm| a * b, vm)),
550-
remainder: Some(|num, other, vm| PyFloat::number_op(num, other, inner_mod, vm)),
551-
divmod: Some(|num, other, vm| PyFloat::number_op(num, other, inner_divmod, vm)),
552-
power: Some(|num, other, vm| PyFloat::number_op(num, other, float_pow, vm)),
547+
add: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a + b, vm)),
548+
subtract: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a - b, vm)),
549+
multiply: Some(|a, b, vm| PyFloat::number_op(a, b, |a, b, _vm| a * b, vm)),
550+
remainder: Some(|a, b, vm| PyFloat::number_op(a, b, inner_mod, vm)),
551+
divmod: Some(|a, b, vm| PyFloat::number_op(a, b, inner_divmod, vm)),
552+
power: Some(|a, b, vm| PyFloat::number_op(a, b, float_pow, vm)),
553553
negative: Some(|num, vm| {
554554
let value = PyFloat::number_downcast(num).value;
555555
(-value).to_pyresult(vm)
@@ -565,8 +565,8 @@ impl AsNumber for PyFloat {
565565
try_to_bigint(value, vm).map(|x| vm.ctx.new_int(x))
566566
}),
567567
float: Some(|num, vm| Ok(PyFloat::number_downcast_exact(num, vm))),
568-
floor_divide: Some(|num, other, vm| PyFloat::number_op(num, other, inner_floordiv, vm)),
569-
true_divide: Some(|num, other, vm| PyFloat::number_op(num, other, inner_div, vm)),
568+
floor_divide: Some(|a, b, vm| PyFloat::number_op(a, b, inner_floordiv, vm)),
569+
true_divide: Some(|a, b, vm| PyFloat::number_op(a, b, inner_div, vm)),
570570
..PyNumberMethods::NOT_IMPLEMENTED
571571
};
572572
&AS_NUMBER
@@ -586,12 +586,12 @@ impl Representable for PyFloat {
586586
}
587587

588588
impl PyFloat {
589-
fn number_op<F, R>(number: PyNumber, other: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
589+
fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
590590
where
591591
F: FnOnce(f64, f64, &VirtualMachine) -> R,
592592
R: ToPyResult,
593593
{
594-
if let (Some(a), Some(b)) = (to_op_float(number.obj, vm)?, to_op_float(other, vm)?) {
594+
if let (Some(a), Some(b)) = (to_op_float(a, vm)?, to_op_float(b, vm)?) {
595595
op(a, b, vm).to_pyresult(vm)
596596
} else {
597597
Ok(vm.ctx.not_implemented())

0 commit comments

Comments
 (0)