Skip to content

Commit 78586f0

Browse files
authored
Chain reduce (RustPython#4232)
1 parent dda6f86 commit 78586f0

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

Lib/test/test_itertools.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,6 @@ def test_chain_from_iterable(self):
181181
self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd'))
182182
self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
183183

184-
# TODO: RUSTPYTHON
185-
@unittest.expectedFailure
186184
def test_chain_reducible(self):
187185
for oper in [copy.deepcopy] + picklecopiers:
188186
it = chain('abc', 'def')
@@ -195,8 +193,7 @@ def test_chain_reducible(self):
195193
self.assertRaises(TypeError, list, oper(chain(2, 3)))
196194
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
197195
self.pickletest(proto, chain('abc', 'def'), compare=list('abcdef'))
198-
# TODO: RUSTPYTHON
199-
@unittest.expectedFailure
196+
200197
def test_chain_setstate(self):
201198
self.assertRaises(TypeError, chain().__setstate__, ())
202199
self.assertRaises(TypeError, chain().__setstate__, [])

vm/src/stdlib/itertools.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ mod decl {
2929
active: PyRwLock<Option<PyIter>>,
3030
}
3131

32-
#[pyclass(with(IterNext), flags(BASETYPE))]
32+
#[pyclass(with(IterNext), flags(BASETYPE, HAS_DICT))]
3333
impl PyItertoolsChain {
3434
#[pyslot]
3535
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
@@ -59,6 +59,53 @@ mod decl {
5959
fn class_getitem(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
6060
PyGenericAlias::new(cls, args, vm)
6161
}
62+
63+
#[pymethod(magic)]
64+
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
65+
let source = zelf.source.read().clone();
66+
let active = zelf.active.read().clone();
67+
let cls = zelf.class().to_owned();
68+
let empty_tuple = vm.ctx.empty_tuple.clone();
69+
let reduced = match source {
70+
Some(source) => match active {
71+
Some(active) => vm.new_tuple((cls, empty_tuple, (source, active))),
72+
None => vm.new_tuple((cls, empty_tuple, (source,))),
73+
},
74+
None => vm.new_tuple((cls, empty_tuple)),
75+
};
76+
Ok(reduced)
77+
}
78+
79+
#[pymethod(magic)]
80+
fn setstate(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
81+
let args = state.as_slice();
82+
if args.is_empty() {
83+
let msg = String::from("function takes at leat 1 arguments (0 given)");
84+
return Err(vm.new_type_error(msg));
85+
}
86+
if args.len() > 2 {
87+
let msg = format!("function takes at most 2 arguments ({} given)", args.len());
88+
return Err(vm.new_type_error(msg));
89+
}
90+
let source = &args[0];
91+
if args.len() == 1 {
92+
if !PyIter::check(source.as_ref()) {
93+
return Err(vm.new_type_error(String::from("Arguments must be iterators.")));
94+
}
95+
*zelf.source.write() = source.to_owned().try_into_value(vm)?;
96+
return Ok(());
97+
}
98+
let active = &args[1];
99+
100+
if !PyIter::check(source.as_ref()) || !PyIter::check(active.as_ref()) {
101+
return Err(vm.new_type_error(String::from("Arguments must be iterators.")));
102+
}
103+
let mut source_lock = zelf.source.write();
104+
let mut active_lock = zelf.active.write();
105+
*source_lock = source.to_owned().try_into_value(vm)?;
106+
*active_lock = active.to_owned().try_into_value(vm)?;
107+
Ok(())
108+
}
62109
}
63110
impl IterNextIterable for PyItertoolsChain {}
64111
impl IterNext for PyItertoolsChain {

0 commit comments

Comments
 (0)