@@ -11,14 +11,13 @@ mod decl {
1111 convert:: ToPyObject ,
1212 function:: { ArgCallable , FuncArgs , OptionalArg , OptionalOption , PosArgs } ,
1313 identifier,
14- protocol:: { PyIter , PyIterReturn } ,
14+ protocol:: { PyIter , PyIterReturn , PyNumber } ,
1515 stdlib:: sys,
1616 types:: { Constructor , IterNext , IterNextIterable } ,
1717 AsObject , Py , PyObjectRef , PyPayload , PyRef , PyResult , PyWeakRef , VirtualMachine ,
1818 } ;
1919 use crossbeam_utils:: atomic:: AtomicCell ;
20- use num_bigint:: BigInt ;
21- use num_traits:: { One , Signed , ToPrimitive , Zero } ;
20+ use num_traits:: { Signed , ToPrimitive } ;
2221 use std:: fmt;
2322
2423 #[ pyattr]
@@ -174,14 +173,14 @@ mod decl {
174173 #[ pyclass( name = "count" ) ]
175174 #[ derive( Debug , PyPayload ) ]
176175 struct PyItertoolsCount {
177- cur : PyRwLock < BigInt > ,
178- step : BigInt ,
176+ cur : PyRwLock < PyObjectRef > ,
177+ step : PyIntRef ,
179178 }
180179
181180 #[ derive( FromArgs ) ]
182181 struct CountNewArgs {
183182 #[ pyarg( positional, optional) ]
184- start : OptionalArg < PyIntRef > ,
183+ start : OptionalArg < PyObjectRef > ,
185184
186185 #[ pyarg( positional, optional) ]
187186 step : OptionalArg < PyIntRef > ,
@@ -195,14 +194,11 @@ mod decl {
195194 Self :: Args { start, step } : Self :: Args ,
196195 vm : & VirtualMachine ,
197196 ) -> PyResult {
198- let start = match start. into_option ( ) {
199- Some ( int) => int. as_bigint ( ) . clone ( ) ,
200- None => BigInt :: zero ( ) ,
201- } ;
202- let step = match step. into_option ( ) {
203- Some ( int) => int. as_bigint ( ) . clone ( ) ,
204- None => BigInt :: one ( ) ,
205- } ;
197+ let start: PyObjectRef = start. into_option ( ) . unwrap_or_else ( || vm. new_pyobj ( 0 ) ) ;
198+ let step: PyIntRef = step. into_option ( ) . unwrap_or_else ( || vm. new_pyref ( 1 ) ) ;
199+ if !PyNumber :: check ( & start, vm) {
200+ return Err ( vm. new_value_error ( "a number is require" . to_owned ( ) ) ) ;
201+ }
206202
207203 PyItertoolsCount {
208204 cur : PyRwLock :: new ( start) ,
@@ -219,7 +215,7 @@ mod decl {
219215 // if (lz->cnt == PY_SSIZE_T_MAX)
220216 // return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step);
221217 #[ pymethod( magic) ]
222- fn reduce ( zelf : PyRef < Self > ) -> ( PyTypeRef , ( BigInt , ) ) {
218+ fn reduce ( zelf : PyRef < Self > ) -> ( PyTypeRef , ( PyObjectRef , ) ) {
223219 ( zelf. class ( ) . clone ( ) , ( zelf. cur . read ( ) . clone ( ) , ) )
224220 }
225221
@@ -234,8 +230,9 @@ mod decl {
234230 impl IterNext for PyItertoolsCount {
235231 fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
236232 let mut cur = zelf. cur . write ( ) ;
233+ let step = zelf. step . clone ( ) ;
237234 let result = cur. clone ( ) ;
238- * cur += & zelf . step ;
235+ * cur = vm . _iadd ( & * cur , step. as_object ( ) ) ? ;
239236 Ok ( PyIterReturn :: Return ( result. to_pyobject ( vm) ) )
240237 }
241238 }
0 commit comments