Skip to content

Commit 36a2700

Browse files
authored
Fix cmath to handle some edge cases in complex number calculation (RustPython#3987)
* Fix: handle edge case for complex comparison, when both is both NaN+NaNj * Fix: handle f64 overflow complex __abs__
1 parent e8ed8aa commit 36a2700

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

Lib/test/test_cmath.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def test_infinity_and_nan_constants(self):
173173
self.assertEqual(repr(cmath.nan), "nan")
174174
self.assertEqual(repr(cmath.nanj), "nanj")
175175

176-
# TODO: RUSTPYTHON see TODO in cmath_log.
177-
@unittest.expectedFailure
178176
def test_user_object(self):
179177
# Test automatic calling of __complex__ and __float__ by cmath
180178
# functions
@@ -536,8 +534,6 @@ def test_abs(self):
536534
self.assertEqual(abs(complex(INF, NAN)), INF)
537535
self.assertTrue(math.isnan(abs(complex(NAN, NAN))))
538536

539-
# TODO: RUSTPYTHON
540-
@unittest.expectedFailure
541537
@requires_IEEE_754
542538
def test_abs_overflows(self):
543539
# result overflows

vm/src/builtins/complex.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,15 @@ impl PyComplex {
228228
}
229229

230230
#[pymethod(magic)]
231-
fn abs(&self) -> f64 {
231+
fn abs(&self, vm: &VirtualMachine) -> PyResult<f64> {
232232
let Complex64 { im, re } = self.value;
233-
re.hypot(im)
233+
let is_finite = im.is_finite() && re.is_finite();
234+
let abs_result = re.hypot(im);
235+
if is_finite && abs_result.is_infinite() {
236+
Err(vm.new_overflow_error("absolute value too large".to_string()))
237+
} else {
238+
Ok(abs_result)
239+
}
234240
}
235241

236242
#[inline]
@@ -402,7 +408,15 @@ impl Comparable for PyComplex {
402408
) -> PyResult<PyComparisonValue> {
403409
op.eq_only(|| {
404410
let result = if let Some(other) = other.payload_if_subclass::<PyComplex>(vm) {
405-
zelf.value == other.value
411+
if zelf.value.re.is_nan()
412+
&& zelf.value.im.is_nan()
413+
&& other.value.re.is_nan()
414+
&& other.value.im.is_nan()
415+
{
416+
true
417+
} else {
418+
zelf.value == other.value
419+
}
406420
} else {
407421
match float::to_op_float(other, vm) {
408422
Ok(Some(other)) => zelf.value == other.into(),

0 commit comments

Comments
 (0)