Skip to content

Commit bc3101e

Browse files
gh-143006: Fix and optimize mixed comparison of float and int
When comparing negative non-integer float and int with the same number of bits in the integer part, __neg__() in the int subclass returning not an int caused an assertion error. Now the integer is no longer negated. Also, reduced the number of temporary created Python objects.
1 parent a88d1b8 commit bc3101e

File tree

3 files changed

+52
-45
lines changed

3 files changed

+52
-45
lines changed

Lib/test/test_float.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,24 @@ class F(float, H):
651651
value = F('nan')
652652
self.assertEqual(hash(value), object.__hash__(value))
653653

654+
def test_issue_gh143006(self):
655+
# When comparing negative non-integer float and int with the
656+
# same number of bits in the integer part, __neg__() in the
657+
# int subclass returning not an int caused an assertion error.
658+
class EvilInt(int):
659+
def __neg__(self):
660+
return ""
661+
662+
i = -1 << 50
663+
f = float(i) - 0.5
664+
i = EvilInt(i)
665+
self.assertFalse(f == i)
666+
self.assertTrue(f != i)
667+
self.assertTrue(f < i)
668+
self.assertTrue(f <= i)
669+
self.assertFalse(f > i)
670+
self.assertFalse(f >= i)
671+
654672

655673
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
656674
class FormatFunctionsTestCase(unittest.TestCase):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix a possible assertion error when comparing negative non-integer ``float``
2+
and ``int`` with the same number of bits in the integer part.

Objects/floatobject.c

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -435,27 +435,17 @@ float_richcompare(PyObject *v, PyObject *w, int op)
435435
assert(vsign != 0); /* if vsign were 0, then since wsign is
436436
* not 0, we would have taken the
437437
* vsign != wsign branch at the start */
438-
/* We want to work with non-negative numbers. */
439-
if (vsign < 0) {
440-
/* "Multiply both sides" by -1; this also swaps the
441-
* comparator.
442-
*/
443-
i = -i;
444-
op = _Py_SwappedOp[op];
445-
}
446-
assert(i > 0.0);
447438
(void) frexp(i, &exponent);
448439
/* exponent is the # of bits in v before the radix point;
449440
* we know that nbits (the # of bits in w) > 48 at this point
450441
*/
451442
if (exponent < nbits) {
452-
i = 1.0;
453-
j = 2.0;
443+
j = i;
444+
i = 0.0;
454445
goto Compare;
455446
}
456447
if (exponent > nbits) {
457-
i = 2.0;
458-
j = 1.0;
448+
j = 0.0;
459449
goto Compare;
460450
}
461451
/* v and w have the same number of bits before the radix
@@ -467,50 +457,47 @@ float_richcompare(PyObject *v, PyObject *w, int op)
467457
double intpart;
468458
PyObject *result = NULL;
469459
PyObject *vv = NULL;
470-
PyObject *ww = w;
471460

472-
if (wsign < 0) {
473-
ww = PyNumber_Negative(w);
474-
if (ww == NULL)
475-
goto Error;
461+
fracpart = modf(i, &intpart);
462+
if (fracpart != 0.0) {
463+
switch (op) {
464+
case Py_EQ:
465+
Py_RETURN_FALSE;
466+
case Py_NE:
467+
Py_RETURN_TRUE;
468+
case Py_LE:
469+
if (vsign > 0) {
470+
op = Py_LT;
471+
}
472+
break;
473+
case Py_GE:
474+
if (vsign < 0) {
475+
op = Py_GT;
476+
}
477+
break;
478+
case Py_LT:
479+
if (vsign < 0) {
480+
op = Py_LE;
481+
}
482+
break;
483+
case Py_GT:
484+
if (vsign > 0) {
485+
op = Py_GE;
486+
}
487+
break;
488+
}
476489
}
477-
else
478-
Py_INCREF(ww);
479490

480-
fracpart = modf(i, &intpart);
481491
vv = PyLong_FromDouble(intpart);
482492
if (vv == NULL)
483493
goto Error;
484494

485-
if (fracpart != 0.0) {
486-
/* Shift left, and or a 1 bit into vv
487-
* to represent the lost fraction.
488-
*/
489-
PyObject *temp;
490-
491-
temp = _PyLong_Lshift(ww, 1);
492-
if (temp == NULL)
493-
goto Error;
494-
Py_SETREF(ww, temp);
495-
496-
temp = _PyLong_Lshift(vv, 1);
497-
if (temp == NULL)
498-
goto Error;
499-
Py_SETREF(vv, temp);
500-
501-
temp = PyNumber_Or(vv, _PyLong_GetOne());
502-
if (temp == NULL)
503-
goto Error;
504-
Py_SETREF(vv, temp);
505-
}
506-
507-
r = PyObject_RichCompareBool(vv, ww, op);
495+
r = PyObject_RichCompareBool(vv, w, op);
508496
if (r < 0)
509497
goto Error;
510498
result = PyBool_FromLong(r);
511499
Error:
512500
Py_XDECREF(vv);
513-
Py_XDECREF(ww);
514501
return result;
515502
}
516503
} /* else if (PyLong_Check(w)) */

0 commit comments

Comments
 (0)