|
2 | 2 | * Builtin set type with a sequence of unique items. |
3 | 3 | */ |
4 | 4 | use super::{ |
5 | | - builtins_iter, IterStatus, PositionIterInternal, PyDictRef, PyGenericAlias, PyTupleRef, PyType, |
6 | | - PyTypeRef, |
| 5 | + builtins_iter, IterStatus, PositionIterInternal, PyDict, PyDictRef, PyGenericAlias, PyTupleRef, |
| 6 | + PyType, PyTypeRef, |
7 | 7 | }; |
8 | 8 | use crate::common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc}; |
9 | 9 | use crate::{ |
@@ -353,6 +353,36 @@ impl PySetInner { |
353 | 353 | Ok(()) |
354 | 354 | } |
355 | 355 |
|
| 356 | + fn update_internal(&self, iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { |
| 357 | + // check AnySet |
| 358 | + if let Ok(any_set) = AnySet::try_from_object(vm, iterable.to_owned()) { |
| 359 | + self.merge_set(any_set, vm) |
| 360 | + // check Dict |
| 361 | + } else if let Ok(dict) = iterable.to_owned().downcast_exact::<PyDict>(vm) { |
| 362 | + self.merge_dict(dict, vm) |
| 363 | + } else { |
| 364 | + // add iterable that is not AnySet or Dict |
| 365 | + for item in iterable.try_into_value::<ArgIterable>(vm)?.iter(vm)? { |
| 366 | + self.add(item?, vm)?; |
| 367 | + } |
| 368 | + Ok(()) |
| 369 | + } |
| 370 | + } |
| 371 | + |
| 372 | + fn merge_set(&self, any_set: AnySet, vm: &VirtualMachine) -> PyResult<()> { |
| 373 | + for item in any_set.as_inner().elements() { |
| 374 | + self.add(item, vm)?; |
| 375 | + } |
| 376 | + Ok(()) |
| 377 | + } |
| 378 | + |
| 379 | + fn merge_dict(&self, dict: PyDictRef, vm: &VirtualMachine) -> PyResult<()> { |
| 380 | + for (key, _value) in dict { |
| 381 | + self.add(key, vm)?; |
| 382 | + } |
| 383 | + Ok(()) |
| 384 | + } |
| 385 | + |
356 | 386 | fn intersection_update( |
357 | 387 | &self, |
358 | 388 | others: impl std::iter::Iterator<Item = ArgIterable>, |
@@ -642,8 +672,10 @@ impl PySet { |
642 | 672 | } |
643 | 673 |
|
644 | 674 | #[pymethod] |
645 | | - fn update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> { |
646 | | - self.inner.update(others.into_iter(), vm)?; |
| 675 | + fn update(&self, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> { |
| 676 | + for iterable in others { |
| 677 | + self.inner.update_internal(iterable, vm)?; |
| 678 | + } |
647 | 679 | Ok(()) |
648 | 680 | } |
649 | 681 |
|
@@ -718,7 +750,7 @@ impl Constructor for PySet { |
718 | 750 | } |
719 | 751 |
|
720 | 752 | impl Initializer for PySet { |
721 | | - type Args = OptionalArg<ArgIterable>; |
| 753 | + type Args = OptionalArg<PyObjectRef>; |
722 | 754 |
|
723 | 755 | fn init(zelf: PyRef<Self>, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> { |
724 | 756 | if zelf.len() > 0 { |
@@ -997,6 +1029,14 @@ impl AnySet { |
997 | 1029 | ) -> PyResult<impl std::iter::Iterator<Item = ArgIterable>> { |
998 | 1030 | Ok(std::iter::once(self.into_iterable(vm)?)) |
999 | 1031 | } |
| 1032 | + |
| 1033 | + fn as_inner(&self) -> &PySetInner { |
| 1034 | + match_class!(match self.object.as_object() { |
| 1035 | + ref set @ PySet => &set.inner, |
| 1036 | + ref frozen @ PyFrozenSet => &frozen.inner, |
| 1037 | + _ => unreachable!("AnySet is always PySet or PyFrozenSet"), // should not be called. |
| 1038 | + }) |
| 1039 | + } |
1000 | 1040 | } |
1001 | 1041 |
|
1002 | 1042 | impl TryFromObject for AnySet { |
|
0 commit comments