@@ -7,7 +7,7 @@ mod decl {
77 rc:: PyRc ,
88 } ;
99 use crate :: {
10- builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyTuple , PyTupleRef , PyTypeRef } ,
10+ builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyList , PyTuple , PyTupleRef , PyTypeRef } ,
1111 convert:: ToPyObject ,
1212 function:: { ArgCallable , FuncArgs , OptionalArg , OptionalOption , PosArgs } ,
1313 identifier,
@@ -25,19 +25,18 @@ mod decl {
2525 #[ pyclass( name = "chain" ) ]
2626 #[ derive( Debug , PyPayload ) ]
2727 struct PyItertoolsChain {
28- iterables : Vec < PyObjectRef > ,
29- cur_idx : AtomicCell < usize > ,
30- cached_iter : PyRwLock < Option < PyIter > > ,
28+ source : PyRwLock < Option < PyIter > > ,
29+ active : PyRwLock < Option < PyIter > > ,
3130 }
3231
3332 #[ pyimpl( with( IterNext ) ) ]
3433 impl PyItertoolsChain {
3534 #[ pyslot]
3635 fn slot_new ( cls : PyTypeRef , args : FuncArgs , vm : & VirtualMachine ) -> PyResult {
36+ let args_list = PyList :: from ( args. args ) ;
3737 PyItertoolsChain {
38- iterables : args. args ,
39- cur_idx : AtomicCell :: new ( 0 ) ,
40- cached_iter : PyRwLock :: new ( None ) ,
38+ source : PyRwLock :: new ( Some ( args_list. to_pyobject ( vm) . get_iter ( vm) ?) ) ,
39+ active : PyRwLock :: new ( None ) ,
4140 }
4241 . into_ref_with_type ( vm, cls)
4342 . map ( Into :: into)
@@ -46,13 +45,12 @@ mod decl {
4645 #[ pyclassmethod]
4746 fn from_iterable (
4847 cls : PyTypeRef ,
49- iterable : PyObjectRef ,
48+ source : PyObjectRef ,
5049 vm : & VirtualMachine ,
5150 ) -> PyResult < PyRef < Self > > {
5251 PyItertoolsChain {
53- iterables : iterable. try_to_value ( vm) ?,
54- cur_idx : AtomicCell :: new ( 0 ) ,
55- cached_iter : PyRwLock :: new ( None ) ,
52+ source : PyRwLock :: new ( Some ( source. get_iter ( vm) ?) ) ,
53+ active : PyRwLock :: new ( None ) ,
5654 }
5755 . into_ref_with_type ( vm, cls)
5856 }
@@ -65,37 +63,51 @@ mod decl {
6563 impl IterNextIterable for PyItertoolsChain { }
6664 impl IterNext for PyItertoolsChain {
6765 fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
68- loop {
69- let pos = zelf. cur_idx . load ( ) ;
70- if pos >= zelf. iterables . len ( ) {
71- break ;
72- }
73- let cur_iter = if zelf. cached_iter . read ( ) . is_none ( ) {
74- // We need to call "get_iter" outside of the lock.
75- let iter = zelf. iterables [ pos] . clone ( ) . get_iter ( vm) ?;
76- * zelf. cached_iter . write ( ) = Some ( iter. clone ( ) ) ;
77- iter
78- } else if let Some ( cached_iter) = ( * zelf. cached_iter . read ( ) ) . clone ( ) {
79- cached_iter
80- } else {
81- // Someone changed cached iter to None since we checked.
82- continue ;
83- } ;
84-
85- // We need to call "next" outside of the lock.
86- match cur_iter. next ( vm) {
87- Ok ( PyIterReturn :: Return ( ok) ) => return Ok ( PyIterReturn :: Return ( ok) ) ,
88- Ok ( PyIterReturn :: StopIteration ( _) ) => {
89- zelf. cur_idx . fetch_add ( 1 ) ;
90- * zelf. cached_iter . write ( ) = None ;
66+ let source = if let Some ( source) = zelf. source . read ( ) . clone ( ) {
67+ source
68+ } else {
69+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
70+ } ;
71+ let next = loop {
72+ let maybe_active = zelf. active . read ( ) . clone ( ) ;
73+ if let Some ( active) = maybe_active {
74+ match active. next ( vm) {
75+ Ok ( PyIterReturn :: Return ( ok) ) => {
76+ break Ok ( PyIterReturn :: Return ( ok) ) ;
77+ }
78+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
79+ * zelf. active . write ( ) = None ;
80+ }
81+ Err ( err) => {
82+ break Err ( err) ;
83+ }
9184 }
92- Err ( err) => {
93- return Err ( err) ;
85+ } else {
86+ match source. next ( vm) {
87+ Ok ( PyIterReturn :: Return ( ok) ) => match ok. get_iter ( vm) {
88+ Ok ( iter) => {
89+ * zelf. active . write ( ) = Some ( iter) ;
90+ }
91+ Err ( err) => {
92+ break Err ( err) ;
93+ }
94+ } ,
95+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
96+ break Ok ( PyIterReturn :: StopIteration ( None ) ) ;
97+ }
98+ Err ( err) => {
99+ break Err ( err) ;
100+ }
94101 }
95102 }
96- }
97-
98- Ok ( PyIterReturn :: StopIteration ( None ) )
103+ } ;
104+ match next {
105+ Err ( _) | Ok ( PyIterReturn :: StopIteration ( _) ) => {
106+ * zelf. source . write ( ) = None ;
107+ }
108+ _ => { }
109+ } ;
110+ next
99111 }
100112 }
101113
0 commit comments