mirror of
https://github.com/python/cpython.git
synced 2025-07-29 22:24:49 +00:00
Issue 1979: Make Decimal comparisons (other than !=, ==) involving NaN
raise InvalidOperation (and return False if InvalidOperation is trapped).
This commit is contained in:
parent
55b8c3e26f
commit
2fc9263df5
4 changed files with 145 additions and 40 deletions
149
Lib/decimal.py
149
Lib/decimal.py
|
@ -717,6 +717,39 @@ class Decimal(object):
|
|||
return other._fix_nan(context)
|
||||
return 0
|
||||
|
||||
def _compare_check_nans(self, other, context):
|
||||
"""Version of _check_nans used for the signaling comparisons
|
||||
compare_signal, __le__, __lt__, __ge__, __gt__.
|
||||
|
||||
Signal InvalidOperation if either self or other is a (quiet
|
||||
or signaling) NaN. Signaling NaNs take precedence over quiet
|
||||
NaNs.
|
||||
|
||||
Return 0 if neither operand is a NaN.
|
||||
|
||||
"""
|
||||
if context is None:
|
||||
context = getcontext()
|
||||
|
||||
if self._is_special or other._is_special:
|
||||
if self.is_snan():
|
||||
return context._raise_error(InvalidOperation,
|
||||
'comparison involving sNaN',
|
||||
self)
|
||||
elif other.is_snan():
|
||||
return context._raise_error(InvalidOperation,
|
||||
'comparison involving sNaN',
|
||||
other)
|
||||
elif self.is_qnan():
|
||||
return context._raise_error(InvalidOperation,
|
||||
'comparison involving NaN',
|
||||
self)
|
||||
elif other.is_qnan():
|
||||
return context._raise_error(InvalidOperation,
|
||||
'comparison involving NaN',
|
||||
other)
|
||||
return 0
|
||||
|
||||
def __nonzero__(self):
|
||||
"""Return True if self is nonzero; otherwise return False.
|
||||
|
||||
|
@ -724,18 +757,13 @@ class Decimal(object):
|
|||
"""
|
||||
return self._is_special or self._int != '0'
|
||||
|
||||
def __cmp__(self, other):
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
# Never return NotImplemented
|
||||
return 1
|
||||
def _cmp(self, other):
|
||||
"""Compare the two non-NaN decimal instances self and other.
|
||||
|
||||
Returns -1 if self < other, 0 if self == other and 1
|
||||
if self > other. This routine is for internal use only."""
|
||||
|
||||
if self._is_special or other._is_special:
|
||||
# check for nans, without raising on a signaling nan
|
||||
if self._isnan() or other._isnan():
|
||||
return 1 # Comparison involving NaN's always reports self > other
|
||||
|
||||
# INF = INF
|
||||
return cmp(self._isinfinity(), other._isinfinity())
|
||||
|
||||
# check for zeros; note that cmp(0, -0) should return 0
|
||||
|
@ -764,15 +792,71 @@ class Decimal(object):
|
|||
else: # self_adjusted < other_adjusted
|
||||
return -((-1)**self._sign)
|
||||
|
||||
# Note: The Decimal standard doesn't cover rich comparisons for
|
||||
# Decimals. In particular, the specification is silent on the
|
||||
# subject of what should happen for a comparison involving a NaN.
|
||||
# We take the following approach:
|
||||
#
|
||||
# == comparisons involving a NaN always return False
|
||||
# != comparisons involving a NaN always return True
|
||||
# <, >, <= and >= comparisons involving a (quiet or signaling)
|
||||
# NaN signal InvalidOperation, and return False if the
|
||||
# InvalidOperation is trapped.
|
||||
#
|
||||
# This behavior is designed to conform as closely as possible to
|
||||
# that specified by IEEE 754.
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, (Decimal, int, long)):
|
||||
return NotImplemented
|
||||
return self.__cmp__(other) == 0
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
if self.is_nan() or other.is_nan():
|
||||
return False
|
||||
return self._cmp(other) == 0
|
||||
|
||||
def __ne__(self, other):
|
||||
if not isinstance(other, (Decimal, int, long)):
|
||||
return NotImplemented
|
||||
return self.__cmp__(other) != 0
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
if self.is_nan() or other.is_nan():
|
||||
return True
|
||||
return self._cmp(other) != 0
|
||||
|
||||
def __lt__(self, other, context=None):
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
ans = self._compare_check_nans(other, context)
|
||||
if ans:
|
||||
return False
|
||||
return self._cmp(other) < 0
|
||||
|
||||
def __le__(self, other, context=None):
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
ans = self._compare_check_nans(other, context)
|
||||
if ans:
|
||||
return False
|
||||
return self._cmp(other) <= 0
|
||||
|
||||
def __gt__(self, other, context=None):
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
ans = self._compare_check_nans(other, context)
|
||||
if ans:
|
||||
return False
|
||||
return self._cmp(other) > 0
|
||||
|
||||
def __ge__(self, other, context=None):
|
||||
other = _convert_other(other)
|
||||
if other is NotImplemented:
|
||||
return other
|
||||
ans = self._compare_check_nans(other, context)
|
||||
if ans:
|
||||
return False
|
||||
return self._cmp(other) >= 0
|
||||
|
||||
def compare(self, other, context=None):
|
||||
"""Compares one to another.
|
||||
|
@ -791,7 +875,7 @@ class Decimal(object):
|
|||
if ans:
|
||||
return ans
|
||||
|
||||
return Decimal(self.__cmp__(other))
|
||||
return Decimal(self._cmp(other))
|
||||
|
||||
def __hash__(self):
|
||||
"""x.__hash__() <==> hash(x)"""
|
||||
|
@ -2452,7 +2536,7 @@ class Decimal(object):
|
|||
return other._fix_nan(context)
|
||||
return self._check_nans(other, context)
|
||||
|
||||
c = self.__cmp__(other)
|
||||
c = self._cmp(other)
|
||||
if c == 0:
|
||||
# If both operands are finite and equal in numerical value
|
||||
# then an ordering is applied:
|
||||
|
@ -2494,7 +2578,7 @@ class Decimal(object):
|
|||
return other._fix_nan(context)
|
||||
return self._check_nans(other, context)
|
||||
|
||||
c = self.__cmp__(other)
|
||||
c = self._cmp(other)
|
||||
if c == 0:
|
||||
c = self.compare_total(other)
|
||||
|
||||
|
@ -2542,23 +2626,10 @@ class Decimal(object):
|
|||
It's pretty much like compare(), but all NaNs signal, with signaling
|
||||
NaNs taking precedence over quiet NaNs.
|
||||
"""
|
||||
if context is None:
|
||||
context = getcontext()
|
||||
|
||||
self_is_nan = self._isnan()
|
||||
other_is_nan = other._isnan()
|
||||
if self_is_nan == 2:
|
||||
return context._raise_error(InvalidOperation, 'sNaN',
|
||||
self)
|
||||
if other_is_nan == 2:
|
||||
return context._raise_error(InvalidOperation, 'sNaN',
|
||||
other)
|
||||
if self_is_nan:
|
||||
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
|
||||
self)
|
||||
if other_is_nan:
|
||||
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
|
||||
other)
|
||||
other = _convert_other(other, raiseit = True)
|
||||
ans = self._compare_check_nans(other, context)
|
||||
if ans:
|
||||
return ans
|
||||
return self.compare(other, context=context)
|
||||
|
||||
def compare_total(self, other):
|
||||
|
@ -3065,7 +3136,7 @@ class Decimal(object):
|
|||
return other._fix_nan(context)
|
||||
return self._check_nans(other, context)
|
||||
|
||||
c = self.copy_abs().__cmp__(other.copy_abs())
|
||||
c = self.copy_abs()._cmp(other.copy_abs())
|
||||
if c == 0:
|
||||
c = self.compare_total(other)
|
||||
|
||||
|
@ -3095,7 +3166,7 @@ class Decimal(object):
|
|||
return other._fix_nan(context)
|
||||
return self._check_nans(other, context)
|
||||
|
||||
c = self.copy_abs().__cmp__(other.copy_abs())
|
||||
c = self.copy_abs()._cmp(other.copy_abs())
|
||||
if c == 0:
|
||||
c = self.compare_total(other)
|
||||
|
||||
|
@ -3170,7 +3241,7 @@ class Decimal(object):
|
|||
if ans:
|
||||
return ans
|
||||
|
||||
comparison = self.__cmp__(other)
|
||||
comparison = self._cmp(other)
|
||||
if comparison == 0:
|
||||
return self.copy_sign(other)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue