Skip to content

Commit e782c21

Browse files
committed
Fix test_merge_and_mutate
1 parent c5d6ef1 commit e782c21

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

Lib/test/test_set.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,8 +1763,6 @@ def test_iter_and_mutate(self):
17631763
s.update(range(100))
17641764
list(si)
17651765

1766-
# TODO: RUSTPYTHON
1767-
@unittest.expectedFailure
17681766
def test_merge_and_mutate(self):
17691767
class X:
17701768
def __hash__(self):

vm/src/builtins/set.rs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
* Builtin set type with a sequence of unique items.
33
*/
44
use super::{
5-
builtins_iter, IterStatus, PositionIterInternal, PyDictRef, PyGenericAlias, PyTupleRef, PyType,
6-
PyTypeRef,
5+
builtins_iter, IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef,
6+
PyType, PyTypeRef,
77
};
88
use crate::common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc};
99
use crate::{
@@ -353,6 +353,36 @@ impl PySetInner {
353353
Ok(())
354354
}
355355

356+
fn update_internal(&self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
357+
// check AnySet
358+
if let Ok(any_set) = AnySet::try_from_object(vm, iterable.to_owned()) {
359+
self.merge_set(any_set, vm)
360+
// check Dict
361+
} else if let Ok(dict) = iterable.to_owned().downcast_exact::<PyDict>(vm) {
362+
self.merge_dict(dict, vm)
363+
} else {
364+
// add iterable that is not AnySet or Dict
365+
for item in iterable.try_into_value::<ArgIterable>(vm)?.iter(vm)? {
366+
self.add(item?, vm)?;
367+
}
368+
Ok(())
369+
}
370+
}
371+
372+
fn merge_set(&self, any_set: AnySet, vm: &VirtualMachine) -> PyResult<()> {
373+
for item in any_set.as_inner().elements() {
374+
self.add(item, vm)?;
375+
}
376+
Ok(())
377+
}
378+
379+
fn merge_dict(&self, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
380+
for (key, _value) in dict {
381+
self.add(key, vm)?;
382+
}
383+
Ok(())
384+
}
385+
356386
fn intersection_update(
357387
&self,
358388
others: impl std::iter::Iterator<Item = ArgIterable>,
@@ -642,8 +672,10 @@ impl PySet {
642672
}
643673

644674
#[pymethod]
645-
fn update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
646-
self.inner.update(others.into_iter(), vm)?;
675+
fn update(&self, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
676+
for iterable in others {
677+
self.inner.update_internal(iterable, vm)?;
678+
}
647679
Ok(())
648680
}
649681

@@ -718,7 +750,7 @@ impl Constructor for PySet {
718750
}
719751

720752
impl Initializer for PySet {
721-
type Args = OptionalArg<ArgIterable>;
753+
type Args = OptionalArg<PyObjectRef>;
722754

723755
fn init(zelf: PyRef<Self>, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
724756
if zelf.len() > 0 {
@@ -997,6 +1029,14 @@ impl AnySet {
9971029
) -> PyResult<impl std::iter::Iterator<Item = ArgIterable>> {
9981030
Ok(std::iter::once(self.into_iterable(vm)?))
9991031
}
1032+
1033+
fn as_inner(&self) -> &PySetInner {
1034+
match_class!(match self.object.as_object() {
1035+
ref set @ PySet => &set.inner,
1036+
ref frozen @ PyFrozenSet => &frozen.inner,
1037+
_ => unreachable!("AnySet is always PySet or PyFrozenSet"), // should not be called.
1038+
})
1039+
}
10001040
}
10011041

10021042
impl TryFromObject for AnySet {

0 commit comments

Comments
 (0)