Skip to content

Commit e50ee53

Browse files
hrchuyouknowone
authored andcommitted
Fix lshift overflow handling
1 parent 8f0e40f commit e50ee53

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

vm/src/builtins/int.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,33 @@ fn inner_divmod(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
157157
Ok(vm.new_tuple((div, modulo)).into())
158158
}
159159

160-
fn inner_shift<F>(int1: &BigInt, int2: &BigInt, shift_op: F, vm: &VirtualMachine) -> PyResult
160+
fn overflow_shift(int2: &BigInt, vm: &VirtualMachine) -> PyResult<usize> {
161+
int2.to_usize().ok_or_else(|| {
162+
vm.new_overflow_error("the number is too large to convert to int".to_owned())
163+
})
164+
}
165+
166+
fn inner_shift<F, S>(
167+
int1: &BigInt,
168+
int2: &BigInt,
169+
shift_op: F,
170+
overflow_shift: S,
171+
vm: &VirtualMachine,
172+
) -> PyResult
161173
where
162174
F: Fn(&BigInt, usize) -> BigInt,
175+
S: Fn(&BigInt, &VirtualMachine) -> PyResult<usize>,
163176
{
164177
if int2.is_negative() {
165178
Err(vm.new_value_error("negative shift count".to_owned()))
166179
} else if int1.is_zero() {
167180
Ok(vm.ctx.new_int(0).into())
168181
} else {
169-
let int2 = int2.min(&BigInt::from(usize::MAX)).to_usize().unwrap();
170-
Ok(vm.ctx.new_int(shift_op(int1, int2)).into())
182+
overflow_shift(int2, vm).map(|v| vm.ctx.new_int(shift_op(int1, v)).into())
171183
}
172184
}
173185

186+
#[inline]
174187
fn inner_truediv(i1: &BigInt, i2: &BigInt, vm: &VirtualMachine) -> PyResult {
175188
if i2.is_zero() {
176189
return Err(vm.new_zero_division_error("integer division by zero".to_owned()));
@@ -359,22 +372,54 @@ impl PyInt {
359372

360373
#[pymethod(magic)]
361374
fn lshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
362-
self.general_op(other, |a, b| inner_shift(a, b, |a, b| a << b, vm), vm)
375+
self.general_op(
376+
other,
377+
|a, b| inner_shift(a, b, |a, b| a << b, overflow_shift, vm),
378+
vm,
379+
)
363380
}
364381

365382
#[pymethod(magic)]
366383
fn rlshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
367-
self.general_op(other, |a, b| inner_shift(b, a, |a, b| a << b, vm), vm)
384+
self.general_op(
385+
other,
386+
|a, b| inner_shift(b, a, |a, b| a << b, overflow_shift, vm),
387+
vm,
388+
)
368389
}
369390

370391
#[pymethod(magic)]
371392
fn rshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
372-
self.general_op(other, |a, b| inner_shift(a, b, |a, b| a >> b, vm), vm)
393+
self.general_op(
394+
other,
395+
|a, b| {
396+
inner_shift(
397+
a,
398+
b,
399+
|a, b| a >> b,
400+
|a, _vm| Ok(a.to_usize().unwrap_or(usize::MAX)),
401+
vm,
402+
)
403+
},
404+
vm,
405+
)
373406
}
374407

375408
#[pymethod(magic)]
376409
fn rrshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
377-
self.general_op(other, |a, b| inner_shift(b, a, |a, b| a >> b, vm), vm)
410+
self.general_op(
411+
other,
412+
|a, b| {
413+
inner_shift(
414+
b,
415+
a,
416+
|a, b| a >> b,
417+
|a, _vm| Ok(a.to_usize().unwrap_or(usize::MAX)),
418+
vm,
419+
)
420+
},
421+
vm,
422+
)
378423
}
379424

380425
#[pymethod(name = "__rxor__")]

0 commit comments

Comments
 (0)