Skip to content

Commit 8d6d47a

Browse files
authored
Merge pull request RustPython#4642 from xiaozhiyan/fix-rsub-of-pyset-pyfrozenset
Fix `rsub` of `PySet` and `PyFrozenSet`
2 parents abf850a + 7984012 commit 8d6d47a

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

extra_tests/snippets/builtin_set.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class S(set):
103103
assert set([1,2,3]) - set([1,2,3,4,5]) == set([])
104104
assert_raises(TypeError, lambda: set([1,2,3]) - [1,2,3,4,5])
105105

106+
assert set([1,2]).__sub__(set([2,3])) == set([1])
107+
assert set([1,2]).__rsub__(set([2,3])) == set([3])
108+
106109
assert set([1,2,3]).symmetric_difference(set([1,2])) == set([3])
107110
assert set([1,2,3]).symmetric_difference(set([5,6])) == set([1,2,3,5,6])
108111
assert set([1,2,3]).symmetric_difference([1,2]) == set([3])
@@ -271,6 +274,9 @@ class S(set):
271274
assert frozenset([1,2,3]) - frozenset([1,2,3,4,5]) == frozenset([])
272275
assert_raises(TypeError, lambda: frozenset([1,2,3]) - [1,2,3,4,5])
273276

277+
assert frozenset([1,2]).__sub__(frozenset([2,3])) == frozenset([1])
278+
assert frozenset([1,2]).__rsub__(frozenset([2,3])) == frozenset([3])
279+
274280
assert frozenset([1,2,3]).symmetric_difference(frozenset([1,2])) == frozenset([3])
275281
assert frozenset([1,2,3]).symmetric_difference(frozenset([5,6])) == frozenset([1,2,3,5,6])
276282
assert frozenset([1,2,3]).symmetric_difference([1,2]) == frozenset([3])
@@ -311,6 +317,11 @@ class S(set):
311317
assert frozenset([1,2,3]) - set([4,5]) == frozenset([1,2,3])
312318
assert set([1,2,3]) - frozenset([4,5]) == frozenset([1,2,3])
313319

320+
assert frozenset([1,2]).__sub__(set([2,3])) == frozenset([1])
321+
assert frozenset([1,2]).__rsub__(set([2,3])) == set([3])
322+
assert set([1,2]).__sub__(frozenset([2,3])) == set([1])
323+
assert set([1,2]).__rsub__(frozenset([2,3])) == frozenset([3])
324+
314325
assert frozenset([1,2,3]).symmetric_difference(set([1,2])) == frozenset([3])
315326
assert set([1,2,3]).symmetric_difference(frozenset([1,2])) == set([3])
316327

vm/src/builtins/set.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,20 @@ impl PySet {
604604
}
605605

606606
#[pymethod(magic)]
607-
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
608-
self.sub(other, vm)
607+
fn rsub(
608+
zelf: PyRef<Self>,
609+
other: PyObjectRef,
610+
vm: &VirtualMachine,
611+
) -> PyResult<PyArithmeticValue<Self>> {
612+
if let Ok(other) = AnySet::try_from_object(vm, other) {
613+
Ok(PyArithmeticValue::Implemented(Self {
614+
inner: other
615+
.as_inner()
616+
.difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
617+
}))
618+
} else {
619+
Ok(PyArithmeticValue::NotImplemented)
620+
}
609621
}
610622

611623
#[pymethod(name = "__rxor__")]
@@ -1003,8 +1015,20 @@ impl PyFrozenSet {
10031015
}
10041016

10051017
#[pymethod(magic)]
1006-
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
1007-
self.sub(other, vm)
1018+
fn rsub(
1019+
zelf: PyRef<Self>,
1020+
other: PyObjectRef,
1021+
vm: &VirtualMachine,
1022+
) -> PyResult<PyArithmeticValue<Self>> {
1023+
if let Ok(other) = AnySet::try_from_object(vm, other) {
1024+
Ok(PyArithmeticValue::Implemented(Self {
1025+
inner: other
1026+
.as_inner()
1027+
.difference(ArgIterable::try_from_object(vm, zelf.into())?, vm)?,
1028+
}))
1029+
} else {
1030+
Ok(PyArithmeticValue::NotImplemented)
1031+
}
10081032
}
10091033

10101034
#[pymethod(name = "__rxor__")]

0 commit comments

Comments
 (0)