Skip to content

Commit 4c11720

Browse files
committed
feat: itertools.chain evaluate lazily
1 parent a9d16e2 commit 4c11720

File tree

1 file changed

+46
-40
lines changed

1 file changed

+46
-40
lines changed

vm/src/stdlib/itertools.rs

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mod decl {
77
rc::PyRc,
88
};
99
use crate::{
10-
builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyTuple, PyTupleRef, PyTypeRef},
10+
builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef},
1111
convert::ToPyObject,
1212
function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs},
1313
identifier,
@@ -25,19 +25,18 @@ mod decl {
2525
#[pyclass(name = "chain")]
2626
#[derive(Debug, PyPayload)]
2727
struct PyItertoolsChain {
28-
iterables: Vec<PyObjectRef>,
29-
cur_idx: AtomicCell<usize>,
30-
cached_iter: PyRwLock<Option<PyIter>>,
28+
source: PyRwLock<Option<PyIter>>,
29+
active: PyRwLock<Option<PyIter>>,
3130
}
3231

3332
#[pyimpl(with(IterNext))]
3433
impl PyItertoolsChain {
3534
#[pyslot]
3635
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
36+
let args_list = PyList::from(args.args);
3737
PyItertoolsChain {
38-
iterables: args.args,
39-
cur_idx: AtomicCell::new(0),
40-
cached_iter: PyRwLock::new(None),
38+
source: PyRwLock::new(Some(args_list.to_pyobject(vm).get_iter(vm)?)),
39+
active: PyRwLock::new(None),
4140
}
4241
.into_ref_with_type(vm, cls)
4342
.map(Into::into)
@@ -46,13 +45,12 @@ mod decl {
4645
#[pyclassmethod]
4746
fn from_iterable(
4847
cls: PyTypeRef,
49-
iterable: PyObjectRef,
48+
source: PyObjectRef,
5049
vm: &VirtualMachine,
5150
) -> PyResult<PyRef<Self>> {
5251
PyItertoolsChain {
53-
iterables: iterable.try_to_value(vm)?,
54-
cur_idx: AtomicCell::new(0),
55-
cached_iter: PyRwLock::new(None),
52+
source: PyRwLock::new(Some(source.get_iter(vm)?)),
53+
active: PyRwLock::new(None),
5654
}
5755
.into_ref_with_type(vm, cls)
5856
}
@@ -65,37 +63,45 @@ mod decl {
6563
impl IterNextIterable for PyItertoolsChain {}
6664
impl IterNext for PyItertoolsChain {
6765
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
68-
loop {
69-
let pos = zelf.cur_idx.load();
70-
if pos >= zelf.iterables.len() {
71-
break;
72-
}
73-
let cur_iter = if zelf.cached_iter.read().is_none() {
74-
// We need to call "get_iter" outside of the lock.
75-
let iter = zelf.iterables[pos].clone().get_iter(vm)?;
76-
*zelf.cached_iter.write() = Some(iter.clone());
77-
iter
78-
} else if let Some(cached_iter) = (*zelf.cached_iter.read()).clone() {
79-
cached_iter
80-
} else {
81-
// Someone changed cached iter to None since we checked.
82-
continue;
83-
};
84-
85-
// We need to call "next" outside of the lock.
86-
match cur_iter.next(vm) {
87-
Ok(PyIterReturn::Return(ok)) => return Ok(PyIterReturn::Return(ok)),
88-
Ok(PyIterReturn::StopIteration(_)) => {
89-
zelf.cur_idx.fetch_add(1);
90-
*zelf.cached_iter.write() = None;
91-
}
92-
Err(err) => {
93-
return Err(err);
66+
let next = || {
67+
let source = zelf.source.read().clone();
68+
match source {
69+
None => {
70+
return Ok(PyIterReturn::StopIteration(None));
9471
}
72+
Some(source) => loop {
73+
let active = zelf.active.read().clone();
74+
match active {
75+
None => match source.next(vm) {
76+
Ok(PyIterReturn::Return(ok)) => {
77+
*zelf.active.write() = Some(ok.get_iter(vm)?);
78+
}
79+
Ok(PyIterReturn::StopIteration(_)) => {
80+
return Ok(PyIterReturn::StopIteration(None));
81+
}
82+
Err(err) => {
83+
return Err(err);
84+
}
85+
},
86+
Some(active) => match active.next(vm) {
87+
Ok(PyIterReturn::Return(ok)) => {
88+
return Ok(PyIterReturn::Return(ok));
89+
}
90+
Ok(PyIterReturn::StopIteration(_)) => {
91+
*zelf.active.write() = None;
92+
}
93+
Err(err) => {
94+
return Err(err);
95+
}
96+
},
97+
}
98+
},
9599
}
96-
}
97-
98-
Ok(PyIterReturn::StopIteration(None))
100+
};
101+
next().map_err(|err| {
102+
*zelf.source.write() = None;
103+
err
104+
})
99105
}
100106
}
101107

0 commit comments

Comments
 (0)