Skip to content

Commit 283c97c

Browse files
committed
cut down modpow to keep fast path smaller
1 parent c303127 commit 283c97c

File tree

1 file changed

+51
-50
lines changed

1 file changed

+51
-50
lines changed

vm/src/builtins/int.rs

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -425,62 +425,63 @@ impl PyInt {
425425
self.int_op(other, |a, b| a & b, vm)
426426
}
427427

428+
fn modpow(&self, other: PyObjectRef, modulus: PyObjectRef, vm: &VirtualMachine) -> PyResult {
429+
let modulus = match modulus.payload_if_subclass::<PyInt>(vm) {
430+
Some(val) => val.as_bigint(),
431+
None => return Ok(vm.ctx.not_implemented()),
432+
};
433+
if modulus.is_zero() {
434+
return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned()));
435+
}
436+
437+
self.general_op(
438+
other,
439+
|a, b| {
440+
let i = if b.is_negative() {
441+
// modular multiplicative inverse
442+
// based on rust-num/num-integer#10, should hopefully be published soon
443+
fn normalize(a: BigInt, n: &BigInt) -> BigInt {
444+
let a = a % n;
445+
if a.is_negative() {
446+
a + n
447+
} else {
448+
a
449+
}
450+
}
451+
fn inverse(a: BigInt, n: &BigInt) -> Option<BigInt> {
452+
use num_integer::*;
453+
let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n);
454+
if gcd.is_one() {
455+
Some(normalize(c, n))
456+
} else {
457+
None
458+
}
459+
}
460+
let a = inverse(a % modulus, modulus).ok_or_else(|| {
461+
vm.new_value_error(
462+
"base is not invertible for the given modulus".to_owned(),
463+
)
464+
})?;
465+
let b = -b;
466+
a.modpow(&b, modulus)
467+
} else {
468+
a.modpow(b, modulus)
469+
};
470+
Ok(vm.ctx.new_int(i).into())
471+
},
472+
vm,
473+
)
474+
}
475+
428476
#[pymethod(magic)]
429477
fn pow(
430478
&self,
431479
other: PyObjectRef,
432-
mod_val: OptionalOption<PyObjectRef>,
480+
r#mod: OptionalOption<PyObjectRef>,
433481
vm: &VirtualMachine,
434482
) -> PyResult {
435-
match mod_val.flatten() {
436-
Some(int_ref) => {
437-
let int = match int_ref.payload_if_subclass::<PyInt>(vm) {
438-
Some(val) => val,
439-
None => return Ok(vm.ctx.not_implemented()),
440-
};
441-
442-
let modulus = int.as_bigint();
443-
if modulus.is_zero() {
444-
return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned()));
445-
}
446-
self.general_op(
447-
other,
448-
|a, b| {
449-
let i = if b.is_negative() {
450-
// modular multiplicative inverse
451-
// based on rust-num/num-integer#10, should hopefully be published soon
452-
fn normalize(a: BigInt, n: &BigInt) -> BigInt {
453-
let a = a % n;
454-
if a.is_negative() {
455-
a + n
456-
} else {
457-
a
458-
}
459-
}
460-
fn inverse(a: BigInt, n: &BigInt) -> Option<BigInt> {
461-
use num_integer::*;
462-
let ExtendedGcd { gcd, x: c, .. } = a.extended_gcd(n);
463-
if gcd.is_one() {
464-
Some(normalize(c, n))
465-
} else {
466-
None
467-
}
468-
}
469-
let a = inverse(a % modulus, modulus).ok_or_else(|| {
470-
vm.new_value_error(
471-
"base is not invertible for the given modulus".to_owned(),
472-
)
473-
})?;
474-
let b = -b;
475-
a.modpow(&b, modulus)
476-
} else {
477-
a.modpow(b, modulus)
478-
};
479-
Ok(vm.ctx.new_int(i).into())
480-
},
481-
vm,
482-
)
483-
}
483+
match r#mod.flatten() {
484+
Some(modulus) => self.modpow(other, modulus, vm),
484485
None => self.general_op(other, |a, b| inner_pow(a, b, vm), vm),
485486
}
486487
}

0 commit comments

Comments
 (0)