Preserve order in Union.

This commit is contained in:
Eric Snow 2018-01-31 22:57:07 +00:00
parent 6b4e02ef45
commit ecd482345c
2 changed files with 60 additions and 28 deletions

View file

@ -9,6 +9,17 @@ REF = '<ref>'
TYPE_REFERENCE = sentinel('TYPE_REFERENCE')
def _is_simple(datatype):
if datatype is ANY:
return True
elif datatype in list(SIMPLE_TYPES):
return True
elif isinstance(datatype, Enum):
return True
else:
return False
def _normalize_datatype(datatype):
cls = type(datatype)
if datatype == REF or datatype is TYPE_REFERENCE:
@ -51,31 +62,31 @@ def _replace_ref(datatype, target):
return datatype
class Enum(namedtuple('Enum', 'datatype choices')):
class Enum(namedtuple('Enum', 'datatype choice')):
"""A simple type with a limited set of allowed values."""
@classmethod
def _check_choices(cls, datatype, choices, strict=True):
if callable(choices):
return choices
def _check_choice(cls, datatype, choice, strict=True):
if callable(choice):
return choice
if isinstance(choices, str):
msg = 'bad choices (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choices))
if isinstance(choice, str):
msg = 'bad choice (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choice))
choices = frozenset(choices)
if not choices:
raise TypeError('missing choices')
choice = frozenset(choice)
if not choice:
raise TypeError('missing choice')
if not strict:
return choices
return choice
for value in choices:
for value in choice:
if type(value) is not datatype:
msg = 'bad choices (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choices))
return choices
msg = 'bad choice (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choice))
return choice
def __new__(cls, datatype, choices, **kwargs):
def __new__(cls, datatype, choice, **kwargs):
strict = kwargs.pop('strict', True)
normalize = kwargs.pop('_normalize', True)
(lambda: None)(**kwargs) # Make sure there aren't any other kwargs.
@ -88,18 +99,19 @@ class Enum(namedtuple('Enum', 'datatype choices')):
if normalize:
# There's no need to normalize datatype (it's a simple type).
pass
choices = cls._check_choices(datatype, choices, strict=strict)
choice = cls._check_choice(datatype, choice, strict=strict)
self = super(Enum, cls).__new__(cls, datatype, choices)
self = super(Enum, cls).__new__(cls, datatype, choice)
return self
class Union(frozenset):
class Union(tuple):
"""Declare a union of different types.
Sets and frozensets are treated equivalently in declarations.
The declared order is preserved and respected.
Sets and frozensets are treated Unions in declarations.
"""
__slots__ = ()
@classmethod
def _traverse(cls, datatypes, op):
@ -122,11 +134,31 @@ class Union(frozenset):
datatypes,
lambda dt: _transform_datatype(dt, _normalize_datatype),
)
return super(Union, cls).__new__(cls, datatypes)
self = super(Union, cls).__new__(cls, datatypes)
self._simple = all(_is_simple(dt) for dt in datatypes)
return self
def __repr__(self):
return '{}{}'.format(type(self).__name__, tuple(self))
def __hash__(self):
return super(Union, self).__hash__()
def __eq__(self, other): # honors order
if not isinstance(other, Union):
return NotImplemented
if super(Union, self).__eq__(other):
return True
if set(self) != set(other):
return False
if self._simple and other._simple:
return True
return NotImplemented
def __ne__(self, other):
return not (self == other)
@property
def datatypes(self):
return set(self)

View file

@ -37,8 +37,8 @@ class ModuleTests(unittest.TestCase):
(ParameterImplBase(str), NOOP),
(Arg(object(), object()), NOOP),
(SimpleParameter(str), NOOP),
(UnionParameter(str), NOOP),
(ArrayParameter(str), NOOP),
(UnionParameter(Union(str)), NOOP),
(ArrayParameter(Array(str)), NOOP),
(ComplexParameter(Fields()), NOOP),
(NOT_SET, NOOP),
(object(), NOOP),
@ -73,8 +73,8 @@ class ModuleTests(unittest.TestCase):
ParameterImplBase(str),
Arg(object(), object()),
SimpleParameter(str),
UnionParameter(str, int),
ArrayParameter(str),
UnionParameter(Union(str, int)),
ArrayParameter(Array(str)),
ComplexParameter(Fields()),
NOT_SET,
object(),
@ -204,8 +204,8 @@ class UnionTests(unittest.TestCase):
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
({str, int}, Union(*{str, int})),
(frozenset([str, int]), Union(*frozenset([str, int]))),
([str], Array(str)),
((str,), Array(str)),
(None, None),