Issue 1979: Make Decimal comparisons (other than !=, ==) involving NaN

raise InvalidOperation (and return False if InvalidOperation is trapped).
This commit is contained in:
Mark Dickinson 2008-02-06 22:10:50 +00:00
parent 55b8c3e26f
commit 2fc9263df5
4 changed files with 145 additions and 40 deletions

View file

@ -717,6 +717,39 @@ class Decimal(object):
return other._fix_nan(context)
return 0
def _compare_check_nans(self, other, context):
"""Version of _check_nans used for the signaling comparisons
compare_signal, __le__, __lt__, __ge__, __gt__.
Signal InvalidOperation if either self or other is a (quiet
or signaling) NaN. Signaling NaNs take precedence over quiet
NaNs.
Return 0 if neither operand is a NaN.
"""
if context is None:
context = getcontext()
if self._is_special or other._is_special:
if self.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
self)
elif other.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
other)
elif self.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
self)
elif other.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
other)
return 0
def __nonzero__(self):
"""Return True if self is nonzero; otherwise return False.
@ -724,18 +757,13 @@ class Decimal(object):
"""
return self._is_special or self._int != '0'
def __cmp__(self, other):
other = _convert_other(other)
if other is NotImplemented:
# Never return NotImplemented
return 1
def _cmp(self, other):
"""Compare the two non-NaN decimal instances self and other.
Returns -1 if self < other, 0 if self == other and 1
if self > other. This routine is for internal use only."""
if self._is_special or other._is_special:
# check for nans, without raising on a signaling nan
if self._isnan() or other._isnan():
return 1 # Comparison involving NaN's always reports self > other
# INF = INF
return cmp(self._isinfinity(), other._isinfinity())
# check for zeros; note that cmp(0, -0) should return 0
@ -764,15 +792,71 @@ class Decimal(object):
else: # self_adjusted < other_adjusted
return -((-1)**self._sign)
# Note: The Decimal standard doesn't cover rich comparisons for
# Decimals. In particular, the specification is silent on the
# subject of what should happen for a comparison involving a NaN.
# We take the following approach:
#
# == comparisons involving a NaN always return False
# != comparisons involving a NaN always return True
# <, >, <= and >= comparisons involving a (quiet or signaling)
# NaN signal InvalidOperation, and return False if the
# InvalidOperation is trapped.
#
# This behavior is designed to conform as closely as possible to
# that specified by IEEE 754.
def __eq__(self, other):
if not isinstance(other, (Decimal, int, long)):
return NotImplemented
return self.__cmp__(other) == 0
other = _convert_other(other)
if other is NotImplemented:
return other
if self.is_nan() or other.is_nan():
return False
return self._cmp(other) == 0
def __ne__(self, other):
if not isinstance(other, (Decimal, int, long)):
return NotImplemented
return self.__cmp__(other) != 0
other = _convert_other(other)
if other is NotImplemented:
return other
if self.is_nan() or other.is_nan():
return True
return self._cmp(other) != 0
def __lt__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) < 0
def __le__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) <= 0
def __gt__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) > 0
def __ge__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) >= 0
def compare(self, other, context=None):
"""Compares one to another.
@ -791,7 +875,7 @@ class Decimal(object):
if ans:
return ans
return Decimal(self.__cmp__(other))
return Decimal(self._cmp(other))
def __hash__(self):
"""x.__hash__() <==> hash(x)"""
@ -2452,7 +2536,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.__cmp__(other)
c = self._cmp(other)
if c == 0:
# If both operands are finite and equal in numerical value
# then an ordering is applied:
@ -2494,7 +2578,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.__cmp__(other)
c = self._cmp(other)
if c == 0:
c = self.compare_total(other)
@ -2542,23 +2626,10 @@ class Decimal(object):
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
"""
if context is None:
context = getcontext()
self_is_nan = self._isnan()
other_is_nan = other._isnan()
if self_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
self)
if other_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
other)
if self_is_nan:
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
self)
if other_is_nan:
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
other)
other = _convert_other(other, raiseit = True)
ans = self._compare_check_nans(other, context)
if ans:
return ans
return self.compare(other, context=context)
def compare_total(self, other):
@ -3065,7 +3136,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs())
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@ -3095,7 +3166,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs())
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@ -3170,7 +3241,7 @@ class Decimal(object):
if ans:
return ans
comparison = self.__cmp__(other)
comparison = self._cmp(other)
if comparison == 0:
return self.copy_sign(other)