diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index f294ec8a..ec58e37a 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -1,6 +1,6 @@ from ._common import NOT_SET, ANY # noqa from ._datatype import FieldsNamespace # noqa -from ._decl import Union, Array, Field # noqa +from ._decl import Enum, Union, Array, Mapping, Field # noqa from ._errors import ( # noqa ArgumentError, ArgMissingError, IncompleteArgError, ArgTypeMismatchError, diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 2fd275d2..049b2b77 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -27,7 +27,9 @@ def _normalize_datatype(datatype): return datatype elif isinstance(datatype, Array): return datatype - elif isinstance(datatype, Field): + elif isinstance(datatype, Array): + return datatype + elif isinstance(datatype, Mapping): return datatype elif isinstance(datatype, Fields): return datatype @@ -49,7 +51,10 @@ def _normalize_datatype(datatype): datatype, = datatype return Array(datatype) elif cls is dict: - raise NotImplementedError + if len(datatype) != 1: + raise NotImplementedError + [keytype, valuetype], = datatype.items() + return Mapping(valuetype, keytype) # fallback: else: try: @@ -215,7 +220,7 @@ class Array(Readonly): ) def __repr__(self): - return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype) + return '{}(itemtype={!r})'.format(type(self).__name__, self.itemtype) def __hash__(self): return hash(self.itemtype) @@ -238,6 +243,55 @@ class Array(Readonly): return self.__class__(datatype, **kwargs) +class Mapping(Readonly): + """Declare a mapping (to a single type).""" + + def __init__(self, valuetype, keytype=str, _normalize=True): + if _normalize: + keytype = _transform_datatype(keytype, _normalize_datatype) + valuetype = _transform_datatype(valuetype, _normalize_datatype) + self._bind_attrs( + keytype=keytype, + valuetype=valuetype, + ) + + def __repr__(self): + if self.keytype is str: + return '{}(valuetype={!r})'.format(type(self).__name__, self.valuetype) + else: + return '{}(keytype={!r}, valuetype={!r})'.format( + type(self).__name__, self.keytype, self.valuetype) + + def __hash__(self): + return hash((self.keytype, self.valuetype)) + + def __eq__(self, other): + try: + other_keytype = other.keytype + other_valuetype = other.valuetype + except AttributeError: + return False + if self.keytype != other_keytype: + return False + if self.valuetype != other_valuetype: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to the item datatype.""" + keytype = op(self.keytype) + valuetype = op(self.valuetype) + if (keytype is self.keytype and + valuetype is self.valuetype and + not kwargs + ): + return self + return self.__class__(valuetype, keytype, **kwargs) + + class Field(namedtuple('Field', 'name datatype default optional')): """Declare a field in a data map param.""" diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 8164c09f..2aa953ee 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -4,7 +4,7 @@ from debugger_protocol.arg import NOT_SET, ANY from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import ( REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, - Enum, Union, Array, Field, Fields) + Enum, Union, Array, Mapping, Field, Fields) from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from debugger_protocol.arg._params import ( SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) @@ -37,6 +37,8 @@ class ModuleTests(unittest.TestCase): (Array(str), NOOP), ([str], Array(str)), ((str,), Array(str)), + (Mapping(str), NOOP), + ({str: str}, Mapping(str)), # others (Field('spam'), NOOP), (Fields(Field('spam')), NOOP), @@ -61,9 +63,6 @@ class ModuleTests(unittest.TestCase): self.assertEqual(datatype, expected) - with self.assertRaises(NotImplementedError): - _normalize_datatype({1: 2}) - def test_transform_datatype_simple(self): datatypes = [ REF, @@ -102,6 +101,7 @@ class ModuleTests(unittest.TestCase): class Spam(FieldsNamespace): FIELDS = [ Field('a'), + Field('b', {str: str}) ] fields = Fields(Field('...')) @@ -145,9 +145,12 @@ class ModuleTests(unittest.TestCase): Union(Array(Spam)), Array(Spam), Spam, - #Fields(Field('a')), Field('a'), str, + Field('b', Mapping(str)), + Mapping(str), + str, + str, # ... field_eggs, Array(TYPE_REFERENCE), @@ -213,6 +216,7 @@ class UnionTests(unittest.TestCase): (frozenset([str, int]), Union(*frozenset([str, int]))), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -221,9 +225,6 @@ class UnionTests(unittest.TestCase): self.assertEqual(union, Union(int, expected, str)) - with self.assertRaises(NotImplementedError): - Union({1: 2}) - def test_traverse_noop(self): calls = [] op = (lambda dt: calls.append(dt) or dt) @@ -260,6 +261,7 @@ class ArrayTests(unittest.TestCase): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -268,9 +270,6 @@ class ArrayTests(unittest.TestCase): self.assertEqual(array, Array(expected)) - with self.assertRaises(NotImplementedError): - Array({1: 2}) - def test_normalized_transformed(self): calls = 0 @@ -311,6 +310,67 @@ class ArrayTests(unittest.TestCase): ]) +class MappingTests(unittest.TestCase): + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + ({str: str}, Mapping(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + mapping = Mapping(datatype) + + self.assertEqual(mapping, Mapping(expected)) + + def test_normalized_transformed(self): + calls = 0 + + class Spam: + @classmethod + def traverse(cls, op): + nonlocal calls + calls += 1 + return cls + + mapping = Mapping(Spam) + + self.assertIs(mapping.keytype, str) + self.assertIs(mapping.valuetype, Spam) + self.assertEqual(calls, 1) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + mapping = Mapping(Union(str, int)) + transformed = mapping.traverse(op) + + self.assertIs(transformed, mapping) + self.assertCountEqual(calls, [ + str, + # Note that it did not recurse into Union(str, int). + Union(str, int), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + mapping = Mapping(ANY) + transformed = mapping.traverse(op) + + self.assertIsNot(transformed, mapping) + self.assertEqual(transformed, Mapping(str)) + self.assertEqual(calls, [ + str, + ANY, + ]) + + class FieldTests(unittest.TestCase): def test_defaults(self): @@ -333,6 +393,7 @@ class FieldTests(unittest.TestCase): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -341,9 +402,6 @@ class FieldTests(unittest.TestCase): self.assertEqual(field, Field('spam', expected)) - with self.assertRaises(NotImplementedError): - Field('spam', {1: 2}) - def test_normalized_transformed(self): calls = 0 @@ -420,6 +478,7 @@ class FieldsTests(unittest.TestCase): (frozenset([str, int]), Union(str, int)), ([str], Array(str)), ((str,), Array(str)), + ({str: str}, Mapping(str)), (None, None), ] for datatype, expected in tests: @@ -432,11 +491,6 @@ class FieldsTests(unittest.TestCase): Field('spam', expected), ]) - with self.assertRaises(NotImplementedError): - Fields( - Field('spam', {1: 2}), - ) - def test_with_START_OPTIONAL(self): fields = Fields( Field('spam'),