Skip to content

Commit cbaed46

Browse files
lgyanfDimitrisJim
andauthored
Implement itertools.accumulate.__reduce__ and __setstate__ (RustPython#4434)
Co-authored-by: DimitrisJim <d.f.hilliard@gmail.com>
1 parent 9cab670 commit cbaed46

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

Lib/test/test_itertools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,8 +1694,6 @@ class TestExamples(unittest.TestCase):
16941694
def test_accumulate(self):
16951695
self.assertEqual(list(accumulate([1,2,3,4,5])), [1, 3, 6, 10, 15])
16961696

1697-
# TODO: RUSTPYTHON
1698-
@unittest.expectedFailure
16991697
def test_accumulate_reducible(self):
17001698
# check copy, deepcopy, pickle
17011699
data = [1, 2, 3, 4, 5]

vm/src/stdlib/itertools.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,54 @@ mod decl {
10921092
}
10931093

10941094
#[pyclass(with(IterNext, Iterable, Constructor))]
1095-
impl PyItertoolsAccumulate {}
1095+
impl PyItertoolsAccumulate {
1096+
#[pymethod(magic)]
1097+
fn setstate(zelf: PyRef<Self>, state: PyObjectRef, _vm: &VirtualMachine) -> PyResult<()> {
1098+
*zelf.acc_value.write() = Some(state);
1099+
Ok(())
1100+
}
1101+
1102+
#[pymethod(magic)]
1103+
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
1104+
let class = zelf.class().to_owned();
1105+
let binop = zelf.binop.clone();
1106+
let it = zelf.iterable.clone();
1107+
let acc_value = zelf.acc_value.read().clone();
1108+
if let Some(initial) = &zelf.initial {
1109+
let chain_args = PyList::from(vec![initial.clone(), it.to_pyobject(vm)]);
1110+
let chain = PyItertoolsChain {
1111+
source: PyRwLock::new(Some(chain_args.to_pyobject(vm).get_iter(vm).unwrap())),
1112+
active: PyRwLock::new(None),
1113+
};
1114+
let tup = vm.new_tuple((chain, binop));
1115+
return vm.new_tuple((class, tup, acc_value));
1116+
}
1117+
match acc_value {
1118+
Some(obj) if obj.is(&vm.ctx.none) => {
1119+
let chain_args = PyList::from(vec![]);
1120+
let chain = PyItertoolsChain {
1121+
source: PyRwLock::new(Some(
1122+
chain_args.to_pyobject(vm).get_iter(vm).unwrap(),
1123+
)),
1124+
active: PyRwLock::new(None),
1125+
}
1126+
.into_pyobject(vm);
1127+
let acc = Self {
1128+
iterable: PyIter::new(chain),
1129+
binop,
1130+
initial: None,
1131+
acc_value: PyRwLock::new(None),
1132+
};
1133+
let tup = vm.new_tuple((acc, 1, None::<PyObjectRef>));
1134+
let islice_cls = PyItertoolsIslice::class(&vm.ctx).to_owned();
1135+
return vm.new_tuple((islice_cls, tup));
1136+
}
1137+
_ => {}
1138+
}
1139+
let tup = vm.new_tuple((it, binop));
1140+
vm.new_tuple((class, tup, acc_value))
1141+
}
1142+
}
10961143

10971144
impl SelfIter for PyItertoolsAccumulate {}
10981145
impl IterNext for PyItertoolsAccumulate {

0 commit comments

Comments
 (0)