@@ -7,7 +7,7 @@ mod decl {
7
7
rc:: PyRc ,
8
8
} ;
9
9
use crate :: {
10
- builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyTuple , PyTupleRef , PyTypeRef } ,
10
+ builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyList , PyTuple , PyTupleRef , PyTypeRef } ,
11
11
convert:: ToPyObject ,
12
12
function:: { ArgCallable , FuncArgs , OptionalArg , OptionalOption , PosArgs } ,
13
13
identifier,
@@ -25,19 +25,18 @@ mod decl {
25
25
#[ pyclass( name = "chain" ) ]
26
26
#[ derive( Debug , PyPayload ) ]
27
27
struct PyItertoolsChain {
28
- iterables : Vec < PyObjectRef > ,
29
- cur_idx : AtomicCell < usize > ,
30
- cached_iter : PyRwLock < Option < PyIter > > ,
28
+ source : PyRwLock < Option < PyIter > > ,
29
+ active : PyRwLock < Option < PyIter > > ,
31
30
}
32
31
33
32
#[ pyimpl( with( IterNext ) ) ]
34
33
impl PyItertoolsChain {
35
34
#[ pyslot]
36
35
fn slot_new ( cls : PyTypeRef , args : FuncArgs , vm : & VirtualMachine ) -> PyResult {
36
+ let args_list = PyList :: from ( args. args ) ;
37
37
PyItertoolsChain {
38
- iterables : args. args ,
39
- cur_idx : AtomicCell :: new ( 0 ) ,
40
- cached_iter : PyRwLock :: new ( None ) ,
38
+ source : PyRwLock :: new ( Some ( args_list. to_pyobject ( vm) . get_iter ( vm) ?) ) ,
39
+ active : PyRwLock :: new ( None ) ,
41
40
}
42
41
. into_ref_with_type ( vm, cls)
43
42
. map ( Into :: into)
@@ -46,13 +45,12 @@ mod decl {
46
45
#[ pyclassmethod]
47
46
fn from_iterable (
48
47
cls : PyTypeRef ,
49
- iterable : PyObjectRef ,
48
+ source : PyObjectRef ,
50
49
vm : & VirtualMachine ,
51
50
) -> PyResult < PyRef < Self > > {
52
51
PyItertoolsChain {
53
- iterables : iterable. try_to_value ( vm) ?,
54
- cur_idx : AtomicCell :: new ( 0 ) ,
55
- cached_iter : PyRwLock :: new ( None ) ,
52
+ source : PyRwLock :: new ( Some ( source. get_iter ( vm) ?) ) ,
53
+ active : PyRwLock :: new ( None ) ,
56
54
}
57
55
. into_ref_with_type ( vm, cls)
58
56
}
@@ -65,37 +63,51 @@ mod decl {
65
63
impl IterNextIterable for PyItertoolsChain { }
66
64
impl IterNext for PyItertoolsChain {
67
65
fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
68
- loop {
69
- let pos = zelf. cur_idx . load ( ) ;
70
- if pos >= zelf. iterables . len ( ) {
71
- break ;
72
- }
73
- let cur_iter = if zelf. cached_iter . read ( ) . is_none ( ) {
74
- // We need to call "get_iter" outside of the lock.
75
- let iter = zelf. iterables [ pos] . clone ( ) . get_iter ( vm) ?;
76
- * zelf. cached_iter . write ( ) = Some ( iter. clone ( ) ) ;
77
- iter
78
- } else if let Some ( cached_iter) = ( * zelf. cached_iter . read ( ) ) . clone ( ) {
79
- cached_iter
80
- } else {
81
- // Someone changed cached iter to None since we checked.
82
- continue ;
83
- } ;
84
-
85
- // We need to call "next" outside of the lock.
86
- match cur_iter. next ( vm) {
87
- Ok ( PyIterReturn :: Return ( ok) ) => return Ok ( PyIterReturn :: Return ( ok) ) ,
88
- Ok ( PyIterReturn :: StopIteration ( _) ) => {
89
- zelf. cur_idx . fetch_add ( 1 ) ;
90
- * zelf. cached_iter . write ( ) = None ;
66
+ let source = if let Some ( source) = zelf. source . read ( ) . clone ( ) {
67
+ source
68
+ } else {
69
+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
70
+ } ;
71
+ let next = loop {
72
+ let maybe_active = zelf. active . read ( ) . clone ( ) ;
73
+ if let Some ( active) = maybe_active {
74
+ match active. next ( vm) {
75
+ Ok ( PyIterReturn :: Return ( ok) ) => {
76
+ break Ok ( PyIterReturn :: Return ( ok) ) ;
77
+ }
78
+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
79
+ * zelf. active . write ( ) = None ;
80
+ }
81
+ Err ( err) => {
82
+ break Err ( err) ;
83
+ }
91
84
}
92
- Err ( err) => {
93
- return Err ( err) ;
85
+ } else {
86
+ match source. next ( vm) {
87
+ Ok ( PyIterReturn :: Return ( ok) ) => match ok. get_iter ( vm) {
88
+ Ok ( iter) => {
89
+ * zelf. active . write ( ) = Some ( iter) ;
90
+ }
91
+ Err ( err) => {
92
+ break Err ( err) ;
93
+ }
94
+ } ,
95
+ Ok ( PyIterReturn :: StopIteration ( _) ) => {
96
+ break Ok ( PyIterReturn :: StopIteration ( None ) ) ;
97
+ }
98
+ Err ( err) => {
99
+ break Err ( err) ;
100
+ }
94
101
}
95
102
}
96
- }
97
-
98
- Ok ( PyIterReturn :: StopIteration ( None ) )
103
+ } ;
104
+ match next {
105
+ Err ( _) | Ok ( PyIterReturn :: StopIteration ( _) ) => {
106
+ * zelf. source . write ( ) = None ;
107
+ }
108
+ _ => { }
109
+ } ;
110
+ next
99
111
}
100
112
}
101
113
0 commit comments