bpo-37685: Fixed comparisons of datetime.timedelta and datetime.timezone. (GH-14996)

There was a discrepancy between the Python and C implementations.

Add singletons ALWAYS_EQ, LARGEST and SMALLEST in test.support
to test mixed type comparison.
This commit is contained in:
Serhiy Storchaka 2019-08-04 12:38:46 +03:00 committed by GitHub
parent 5c72badd06
commit 17e52649c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 107 additions and 84 deletions

View file

@ -356,11 +356,28 @@ The :mod:`test.support` module defines the following constants:
Check for presence of docstrings. Check for presence of docstrings.
.. data:: TEST_HTTP_URL .. data:: TEST_HTTP_URL
Define the URL of a dedicated HTTP server for the network tests. Define the URL of a dedicated HTTP server for the network tests.
.. data:: ALWAYS_EQ
Object that is equal to anything. Used to test mixed type comparison.
.. data:: LARGEST
Object that is greater than anything (except itself).
Used to test mixed type comparison.
.. data:: SMALLEST
Object that is less than anything (except itself).
Used to test mixed type comparison.
The :mod:`test.support` module defines the following functions: The :mod:`test.support` module defines the following functions:

View file

@ -739,25 +739,25 @@ class timedelta:
if isinstance(other, timedelta): if isinstance(other, timedelta):
return self._cmp(other) <= 0 return self._cmp(other) <= 0
else: else:
_cmperror(self, other) return NotImplemented
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, timedelta): if isinstance(other, timedelta):
return self._cmp(other) < 0 return self._cmp(other) < 0
else: else:
_cmperror(self, other) return NotImplemented
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, timedelta): if isinstance(other, timedelta):
return self._cmp(other) >= 0 return self._cmp(other) >= 0
else: else:
_cmperror(self, other) return NotImplemented
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, timedelta): if isinstance(other, timedelta):
return self._cmp(other) > 0 return self._cmp(other) > 0
else: else:
_cmperror(self, other) return NotImplemented
def _cmp(self, other): def _cmp(self, other):
assert isinstance(other, timedelta) assert isinstance(other, timedelta)
@ -1316,25 +1316,25 @@ class time:
if isinstance(other, time): if isinstance(other, time):
return self._cmp(other) <= 0 return self._cmp(other) <= 0
else: else:
_cmperror(self, other) return NotImplemented
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, time): if isinstance(other, time):
return self._cmp(other) < 0 return self._cmp(other) < 0
else: else:
_cmperror(self, other) return NotImplemented
def __ge__(self, other): def __ge__(self, other):
if isinstance(other, time): if isinstance(other, time):
return self._cmp(other) >= 0 return self._cmp(other) >= 0
else: else:
_cmperror(self, other) return NotImplemented
def __gt__(self, other): def __gt__(self, other):
if isinstance(other, time): if isinstance(other, time):
return self._cmp(other) > 0 return self._cmp(other) > 0
else: else:
_cmperror(self, other) return NotImplemented
def _cmp(self, other, allow_mixed=False): def _cmp(self, other, allow_mixed=False):
assert isinstance(other, time) assert isinstance(other, time)
@ -2210,9 +2210,9 @@ class timezone(tzinfo):
return (self._offset, self._name) return (self._offset, self._name)
def __eq__(self, other): def __eq__(self, other):
if type(other) != timezone: if isinstance(other, timezone):
return False return self._offset == other._offset
return self._offset == other._offset return NotImplemented
def __hash__(self): def __hash__(self):
return hash(self._offset) return hash(self._offset)

View file

@ -2,11 +2,8 @@
See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases
""" """
from test.support import is_resource_enabled
import itertools import itertools
import bisect import bisect
import copy import copy
import decimal import decimal
import sys import sys
@ -22,6 +19,7 @@ from array import array
from operator import lt, le, gt, ge, eq, ne, truediv, floordiv, mod from operator import lt, le, gt, ge, eq, ne, truediv, floordiv, mod
from test import support from test import support
from test.support import is_resource_enabled, ALWAYS_EQ, LARGEST, SMALLEST
import datetime as datetime_module import datetime as datetime_module
from datetime import MINYEAR, MAXYEAR from datetime import MINYEAR, MAXYEAR
@ -54,18 +52,6 @@ INF = float("inf")
NAN = float("nan") NAN = float("nan")
class ComparesEqualClass(object):
"""
A class that is always equal to whatever you compare it to.
"""
def __eq__(self, other):
return True
def __ne__(self, other):
return False
############################################################################# #############################################################################
# module tests # module tests
@ -353,6 +339,18 @@ class TestTimeZone(unittest.TestCase):
self.assertTrue(timezone(ZERO) != None) self.assertTrue(timezone(ZERO) != None)
self.assertFalse(timezone(ZERO) == None) self.assertFalse(timezone(ZERO) == None)
tz = timezone(ZERO)
self.assertTrue(tz == ALWAYS_EQ)
self.assertFalse(tz != ALWAYS_EQ)
self.assertTrue(tz < LARGEST)
self.assertFalse(tz > LARGEST)
self.assertTrue(tz <= LARGEST)
self.assertFalse(tz >= LARGEST)
self.assertFalse(tz < SMALLEST)
self.assertTrue(tz > SMALLEST)
self.assertFalse(tz <= SMALLEST)
self.assertTrue(tz >= SMALLEST)
def test_aware_datetime(self): def test_aware_datetime(self):
# test that timezone instances can be used by datetime # test that timezone instances can be used by datetime
t = datetime(1, 1, 1) t = datetime(1, 1, 1)
@ -414,10 +412,21 @@ class HarmlessMixedComparison:
# Comparison to objects of unsupported types should return # Comparison to objects of unsupported types should return
# NotImplemented which falls back to the right hand side's __eq__ # NotImplemented which falls back to the right hand side's __eq__
# method. In this case, ComparesEqualClass.__eq__ always returns True. # method. In this case, ALWAYS_EQ.__eq__ always returns True.
# ComparesEqualClass.__ne__ always returns False. # ALWAYS_EQ.__ne__ always returns False.
self.assertTrue(me == ComparesEqualClass()) self.assertTrue(me == ALWAYS_EQ)
self.assertFalse(me != ComparesEqualClass()) self.assertFalse(me != ALWAYS_EQ)
# If the other class explicitly defines ordering
# relative to our class, it is allowed to do so
self.assertTrue(me < LARGEST)
self.assertFalse(me > LARGEST)
self.assertTrue(me <= LARGEST)
self.assertFalse(me >= LARGEST)
self.assertFalse(me < SMALLEST)
self.assertTrue(me > SMALLEST)
self.assertFalse(me <= SMALLEST)
self.assertTrue(me >= SMALLEST)
def test_harmful_mixed_comparison(self): def test_harmful_mixed_comparison(self):
me = self.theclass(1, 1, 1) me = self.theclass(1, 1, 1)
@ -1582,29 +1591,6 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
self.assertRaises(TypeError, lambda: our < their) self.assertRaises(TypeError, lambda: our < their)
self.assertRaises(TypeError, lambda: their < our) self.assertRaises(TypeError, lambda: their < our)
# However, if the other class explicitly defines ordering
# relative to our class, it is allowed to do so
class LargerThanAnything:
def __lt__(self, other):
return False
def __le__(self, other):
return isinstance(other, LargerThanAnything)
def __eq__(self, other):
return isinstance(other, LargerThanAnything)
def __gt__(self, other):
return not isinstance(other, LargerThanAnything)
def __ge__(self, other):
return True
their = LargerThanAnything()
self.assertEqual(our == their, False)
self.assertEqual(their == our, False)
self.assertEqual(our != their, True)
self.assertEqual(their != our, True)
self.assertEqual(our < their, True)
self.assertEqual(their < our, False)
def test_bool(self): def test_bool(self):
# All dates are considered true. # All dates are considered true.
self.assertTrue(self.theclass.min) self.assertTrue(self.theclass.min)
@ -3781,8 +3767,8 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
self.assertRaises(ValueError, base.replace, microsecond=1000000) self.assertRaises(ValueError, base.replace, microsecond=1000000)
def test_mixed_compare(self): def test_mixed_compare(self):
t1 = time(1, 2, 3) t1 = self.theclass(1, 2, 3)
t2 = time(1, 2, 3) t2 = self.theclass(1, 2, 3)
self.assertEqual(t1, t2) self.assertEqual(t1, t2)
t2 = t2.replace(tzinfo=None) t2 = t2.replace(tzinfo=None)
self.assertEqual(t1, t2) self.assertEqual(t1, t2)

View file

@ -113,6 +113,7 @@ __all__ = [
"run_with_locale", "swap_item", "run_with_locale", "swap_item",
"swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict", "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict",
"run_with_tz", "PGO", "missing_compiler_executable", "fd_count", "run_with_tz", "PGO", "missing_compiler_executable", "fd_count",
"ALWAYS_EQ", "LARGEST", "SMALLEST"
] ]
class Error(Exception): class Error(Exception):
@ -3103,6 +3104,41 @@ class FakePath:
return self.path return self.path
class _ALWAYS_EQ:
"""
Object that is equal to anything.
"""
def __eq__(self, other):
return True
def __ne__(self, other):
return False
ALWAYS_EQ = _ALWAYS_EQ()
@functools.total_ordering
class _LARGEST:
"""
Object that is greater than anything (except itself).
"""
def __eq__(self, other):
return isinstance(other, _LARGEST)
def __lt__(self, other):
return False
LARGEST = _LARGEST()
@functools.total_ordering
class _SMALLEST:
"""
Object that is less than anything (except itself).
"""
def __eq__(self, other):
return isinstance(other, _SMALLEST)
def __gt__(self, other):
return False
SMALLEST = _SMALLEST()
def maybe_get_event_loop_policy(): def maybe_get_event_loop_policy():
"""Return the global event loop policy if one is set, else return None.""" """Return the global event loop policy if one is set, else return None."""
return asyncio.events._event_loop_policy return asyncio.events._event_loop_policy

View file

@ -12,6 +12,7 @@ import operator
import pickle import pickle
import ipaddress import ipaddress
import weakref import weakref
from test.support import LARGEST, SMALLEST
class BaseTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
@ -673,20 +674,6 @@ class FactoryFunctionErrors(BaseTestCase):
self.assertFactoryError(ipaddress.ip_network, "network") self.assertFactoryError(ipaddress.ip_network, "network")
@functools.total_ordering
class LargestObject:
def __eq__(self, other):
return isinstance(other, LargestObject)
def __lt__(self, other):
return False
@functools.total_ordering
class SmallestObject:
def __eq__(self, other):
return isinstance(other, SmallestObject)
def __gt__(self, other):
return False
class ComparisonTests(unittest.TestCase): class ComparisonTests(unittest.TestCase):
v4addr = ipaddress.IPv4Address(1) v4addr = ipaddress.IPv4Address(1)
@ -775,8 +762,6 @@ class ComparisonTests(unittest.TestCase):
def test_foreign_type_ordering(self): def test_foreign_type_ordering(self):
other = object() other = object()
smallest = SmallestObject()
largest = LargestObject()
for obj in self.objects: for obj in self.objects:
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
obj < other obj < other
@ -786,14 +771,14 @@ class ComparisonTests(unittest.TestCase):
obj <= other obj <= other
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
obj >= other obj >= other
self.assertTrue(obj < largest) self.assertTrue(obj < LARGEST)
self.assertFalse(obj > largest) self.assertFalse(obj > LARGEST)
self.assertTrue(obj <= largest) self.assertTrue(obj <= LARGEST)
self.assertFalse(obj >= largest) self.assertFalse(obj >= LARGEST)
self.assertFalse(obj < smallest) self.assertFalse(obj < SMALLEST)
self.assertTrue(obj > smallest) self.assertTrue(obj > SMALLEST)
self.assertFalse(obj <= smallest) self.assertFalse(obj <= SMALLEST)
self.assertTrue(obj >= smallest) self.assertTrue(obj >= SMALLEST)
def test_mixed_type_key(self): def test_mixed_type_key(self):
# with get_mixed_type_key, you can sort addresses and network. # with get_mixed_type_key, you can sort addresses and network.

View file

@ -0,0 +1,2 @@
Fixed comparisons of :class:`datetime.timedelta` and
:class:`datetime.timezone`.

View file

@ -3741,11 +3741,8 @@ timezone_richcompare(PyDateTime_TimeZone *self,
{ {
if (op != Py_EQ && op != Py_NE) if (op != Py_EQ && op != Py_NE)
Py_RETURN_NOTIMPLEMENTED; Py_RETURN_NOTIMPLEMENTED;
if (Py_TYPE(other) != &PyDateTime_TimeZoneType) { if (!PyTZInfo_Check(other)) {
if (op == Py_EQ) Py_RETURN_NOTIMPLEMENTED;
Py_RETURN_FALSE;
else
Py_RETURN_TRUE;
} }
return delta_richcompare(self->offset, other->offset, op); return delta_richcompare(self->offset, other->offset, op);
} }