@@ -7,7 +7,7 @@ mod decl {
77 rc:: PyRc ,
88 } ;
99 use crate :: {
10- builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyTuple , PyTupleRef , PyTypeRef } ,
10+ builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyList , PyTuple , PyTupleRef , PyTypeRef } ,
1111 convert:: ToPyObject ,
1212 function:: { ArgCallable , FuncArgs , OptionalArg , OptionalOption , PosArgs } ,
1313 identifier,
@@ -25,19 +25,18 @@ mod decl {
2525 #[ pyclass( name = "chain" ) ]
2626 #[ derive( Debug , PyPayload ) ]
2727 struct PyItertoolsChain {
28- iterables : Vec < PyObjectRef > ,
29- cur_idx : AtomicCell < usize > ,
30- cached_iter : PyRwLock < Option < PyIter > > ,
28+ source : PyRwLock < Option < PyIter > > ,
29+ active : PyRwLock < Option < PyIter > > ,
3130 }
3231
3332 #[ pyimpl( with( IterNext ) ) ]
3433 impl PyItertoolsChain {
3534 #[ pyslot]
3635 fn slot_new ( cls : PyTypeRef , args : FuncArgs , vm : & VirtualMachine ) -> PyResult {
36+ let args_list = PyList :: from ( args. args ) ;
3737 PyItertoolsChain {
38- iterables : args. args ,
39- cur_idx : AtomicCell :: new ( 0 ) ,
40- cached_iter : PyRwLock :: new ( None ) ,
38+ source : PyRwLock :: new ( Some ( args_list. to_pyobject ( vm) . get_iter ( vm) ?) ) ,
39+ active : PyRwLock :: new ( None ) ,
4140 }
4241 . into_ref_with_type ( vm, cls)
4342 . map ( Into :: into)
@@ -46,13 +45,12 @@ mod decl {
4645 #[ pyclassmethod]
4746 fn from_iterable (
4847 cls : PyTypeRef ,
49- iterable : PyObjectRef ,
48+ source : PyObjectRef ,
5049 vm : & VirtualMachine ,
5150 ) -> PyResult < PyRef < Self > > {
5251 PyItertoolsChain {
53- iterables : iterable. try_to_value ( vm) ?,
54- cur_idx : AtomicCell :: new ( 0 ) ,
55- cached_iter : PyRwLock :: new ( None ) ,
52+ source : PyRwLock :: new ( Some ( source. get_iter ( vm) ?) ) ,
53+ active : PyRwLock :: new ( None ) ,
5654 }
5755 . into_ref_with_type ( vm, cls)
5856 }
@@ -65,37 +63,45 @@ mod decl {
6563 impl IterNextIterable for PyItertoolsChain { }
6664 impl IterNext for PyItertoolsChain {
6765 fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
68- loop {
69- let pos = zelf. cur_idx . load ( ) ;
70- if pos >= zelf. iterables . len ( ) {
71- break ;
72- }
73- let cur_iter = if zelf. cached_iter . read ( ) . is_none ( ) {
74- // We need to call "get_iter" outside of the lock.
75- let iter = zelf. iterables [ pos] . clone ( ) . get_iter ( vm) ?;
76- * zelf. cached_iter . write ( ) = Some ( iter. clone ( ) ) ;
77- iter
78- } else if let Some ( cached_iter) = ( * zelf. cached_iter . read ( ) ) . clone ( ) {
79- cached_iter
80- } else {
81- // Someone changed cached iter to None since we checked.
82- continue ;
83- } ;
84-
85- // We need to call "next" outside of the lock.
86- match cur_iter. next ( vm) {
87- Ok ( PyIterReturn :: Return ( ok) ) => return Ok ( PyIterReturn :: Return ( ok) ) ,
88- Ok ( PyIterReturn :: StopIteration ( _) ) => {
89- zelf. cur_idx . fetch_add ( 1 ) ;
90- * zelf. cached_iter . write ( ) = None ;
91- }
92- Err ( err) => {
93- return Err ( err) ;
66+ let next = || {
67+ let source = zelf. source . read ( ) . clone ( ) ;
68+ match source {
69+ None => {
70+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
9471 }
72+ Some ( source) => loop {
73+ let active = zelf. active . read ( ) . clone ( ) ;
74+ match active {
75+ None => match source. next ( vm) {
76+ Ok ( PyIterReturn :: Return ( ok) ) => {
77+ * zelf. active . write ( ) = Some ( ok. get_iter ( vm) ?) ;
78+ }
79+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
80+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
81+ }
82+ Err ( err) => {
83+ return Err ( err) ;
84+ }
85+ } ,
86+ Some ( active) => match active. next ( vm) {
87+ Ok ( PyIterReturn :: Return ( ok) ) => {
88+ return Ok ( PyIterReturn :: Return ( ok) ) ;
89+ }
90+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
91+ * zelf. active . write ( ) = None ;
92+ }
93+ Err ( err) => {
94+ return Err ( err) ;
95+ }
96+ } ,
97+ }
98+ } ,
9599 }
96- }
97-
98- Ok ( PyIterReturn :: StopIteration ( None ) )
100+ } ;
101+ next ( ) . map_err ( |err| {
102+ * zelf. source . write ( ) = None ;
103+ err
104+ } )
99105 }
100106 }
101107
0 commit comments