@@ -157,19 +157,47 @@ fn inner_divmod(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
157157 Ok ( vm. new_tuple ( ( div, modulo) ) . into ( ) )
158158}
159159
160- fn inner_shift < F > ( int1 : & BigInt , int2 : & BigInt , shift_op : F , vm : & VirtualMachine ) -> PyResult
160+ fn inner_lshift ( base : & BigInt , bits : & BigInt , vm : & VirtualMachine ) -> PyResult {
161+ inner_shift (
162+ base,
163+ bits,
164+ |base, bits| base << bits,
165+ |bits, vm| {
166+ bits. to_usize ( ) . ok_or_else ( || {
167+ vm. new_overflow_error ( "the number is too large to convert to int" . to_owned ( ) )
168+ } )
169+ } ,
170+ vm,
171+ )
172+ }
173+
174+ fn inner_rshift ( base : & BigInt , bits : & BigInt , vm : & VirtualMachine ) -> PyResult {
175+ inner_shift (
176+ base,
177+ bits,
178+ |base, bits| base >> bits,
179+ |bits, _vm| Ok ( bits. to_usize ( ) . unwrap_or ( usize:: MAX ) ) ,
180+ vm,
181+ )
182+ }
183+
184+ fn inner_shift < F , S > (
185+ base : & BigInt ,
186+ bits : & BigInt ,
187+ shift_op : F ,
188+ shift_bits : S ,
189+ vm : & VirtualMachine ,
190+ ) -> PyResult
161191where
162192 F : Fn ( & BigInt , usize ) -> BigInt ,
193+ S : Fn ( & BigInt , & VirtualMachine ) -> PyResult < usize > ,
163194{
164- if int2 . is_negative ( ) {
195+ if bits . is_negative ( ) {
165196 Err ( vm. new_value_error ( "negative shift count" . to_owned ( ) ) )
166- } else if int1 . is_zero ( ) {
197+ } else if base . is_zero ( ) {
167198 Ok ( vm. ctx . new_int ( 0 ) . into ( ) )
168199 } else {
169- let int2 = int2. to_usize ( ) . ok_or_else ( || {
170- vm. new_overflow_error ( "the number is too large to convert to int" . to_owned ( ) )
171- } ) ?;
172- Ok ( vm. ctx . new_int ( shift_op ( int1, int2) ) . into ( ) )
200+ shift_bits ( bits, vm) . map ( |bits| vm. ctx . new_int ( shift_op ( base, bits) ) . into ( ) )
173201 }
174202}
175203
@@ -361,22 +389,22 @@ impl PyInt {
361389
362390 #[ pymethod( magic) ]
363391 fn lshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
364- self . general_op ( other, |a, b| inner_shift ( a, b , |a , b| a << b, vm) , vm)
392+ self . general_op ( other, |a, b| inner_lshift ( a, b, vm) , vm)
365393 }
366394
367395 #[ pymethod( magic) ]
368396 fn rlshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
369- self . general_op ( other, |a, b| inner_shift ( b, a, |a , b| a << b , vm) , vm)
397+ self . general_op ( other, |a, b| inner_lshift ( b, a, vm) , vm)
370398 }
371399
372400 #[ pymethod( magic) ]
373401 fn rshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
374- self . general_op ( other, |a, b| inner_shift ( a, b , |a , b| a >> b, vm) , vm)
402+ self . general_op ( other, |a, b| inner_rshift ( a, b, vm) , vm)
375403 }
376404
377405 #[ pymethod( magic) ]
378406 fn rrshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
379- self . general_op ( other, |a, b| inner_shift ( b, a, |a , b| a >> b , vm) , vm)
407+ self . general_op ( other, |a, b| inner_rshift ( b, a, vm) , vm)
380408 }
381409
382410 #[ pymethod( name = "__rxor__" ) ]
0 commit comments