diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 4f0c2ff9..3b62e91c 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -9,6 +9,17 @@ REF = '' TYPE_REFERENCE = sentinel('TYPE_REFERENCE') +def _is_simple(datatype): + if datatype is ANY: + return True + elif datatype in list(SIMPLE_TYPES): + return True + elif isinstance(datatype, Enum): + return True + else: + return False + + def _normalize_datatype(datatype): cls = type(datatype) if datatype == REF or datatype is TYPE_REFERENCE: @@ -51,31 +62,31 @@ def _replace_ref(datatype, target): return datatype -class Enum(namedtuple('Enum', 'datatype choices')): +class Enum(namedtuple('Enum', 'datatype choice')): """A simple type with a limited set of allowed values.""" @classmethod - def _check_choices(cls, datatype, choices, strict=True): - if callable(choices): - return choices + def _check_choice(cls, datatype, choice, strict=True): + if callable(choice): + return choice - if isinstance(choices, str): - msg = 'bad choices (expected {!r} values, got {!r})' - raise ValueError(msg.format(datatype, choices)) + if isinstance(choice, str): + msg = 'bad choice (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choice)) - choices = frozenset(choices) - if not choices: - raise TypeError('missing choices') + choice = frozenset(choice) + if not choice: + raise TypeError('missing choice') if not strict: - return choices + return choice - for value in choices: + for value in choice: if type(value) is not datatype: - msg = 'bad choices (expected {!r} values, got {!r})' - raise ValueError(msg.format(datatype, choices)) - return choices + msg = 'bad choice (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choice)) + return choice - def __new__(cls, datatype, choices, **kwargs): + def __new__(cls, datatype, choice, **kwargs): strict = kwargs.pop('strict', True) normalize = kwargs.pop('_normalize', True) (lambda: None)(**kwargs) # Make sure there aren't any other kwargs. @@ -88,18 +99,19 @@ class Enum(namedtuple('Enum', 'datatype choices')): if normalize: # There's no need to normalize datatype (it's a simple type). pass - choices = cls._check_choices(datatype, choices, strict=strict) + choice = cls._check_choice(datatype, choice, strict=strict) - self = super(Enum, cls).__new__(cls, datatype, choices) + self = super(Enum, cls).__new__(cls, datatype, choice) return self -class Union(frozenset): +class Union(tuple): """Declare a union of different types. - Sets and frozensets are treated equivalently in declarations. + The declared order is preserved and respected. + + Sets and frozensets are treated Unions in declarations. """ - __slots__ = () @classmethod def _traverse(cls, datatypes, op): @@ -122,11 +134,31 @@ class Union(frozenset): datatypes, lambda dt: _transform_datatype(dt, _normalize_datatype), ) - return super(Union, cls).__new__(cls, datatypes) + self = super(Union, cls).__new__(cls, datatypes) + self._simple = all(_is_simple(dt) for dt in datatypes) + return self def __repr__(self): return '{}{}'.format(type(self).__name__, tuple(self)) + def __hash__(self): + return super(Union, self).__hash__() + + def __eq__(self, other): # honors order + if not isinstance(other, Union): + return NotImplemented + if super(Union, self).__eq__(other): + return True + if set(self) != set(other): + return False + if self._simple and other._simple: + return True + return NotImplemented + + + def __ne__(self, other): + return not (self == other) + @property def datatypes(self): return set(self) diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index ec296402..c3dfd2ee 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -37,8 +37,8 @@ class ModuleTests(unittest.TestCase): (ParameterImplBase(str), NOOP), (Arg(object(), object()), NOOP), (SimpleParameter(str), NOOP), - (UnionParameter(str), NOOP), - (ArrayParameter(str), NOOP), + (UnionParameter(Union(str)), NOOP), + (ArrayParameter(Array(str)), NOOP), (ComplexParameter(Fields()), NOOP), (NOT_SET, NOOP), (object(), NOOP), @@ -73,8 +73,8 @@ class ModuleTests(unittest.TestCase): ParameterImplBase(str), Arg(object(), object()), SimpleParameter(str), - UnionParameter(str, int), - ArrayParameter(str), + UnionParameter(Union(str, int)), + ArrayParameter(Array(str)), ComplexParameter(Fields()), NOT_SET, object(), @@ -204,8 +204,8 @@ class UnionTests(unittest.TestCase): def test_normalized(self): tests = [ (REF, TYPE_REFERENCE), - ({str, int}, Union(str, int)), - (frozenset([str, int]), Union(str, int)), + ({str, int}, Union(*{str, int})), + (frozenset([str, int]), Union(*frozenset([str, int]))), ([str], Array(str)), ((str,), Array(str)), (None, None),