Add a Mapping declaration type.

This commit is contained in:
Eric Snow 2018-02-06 08:37:33 +00:00
parent 215f91d044
commit 800e05cd7a
3 changed files with 131 additions and 23 deletions

View file

@ -1,6 +1,6 @@
from ._common import NOT_SET, ANY # noqa
from ._datatype import FieldsNamespace # noqa
from ._decl import Union, Array, Field # noqa
from ._decl import Enum, Union, Array, Mapping, Field # noqa
from ._errors import ( # noqa
ArgumentError,
ArgMissingError, IncompleteArgError, ArgTypeMismatchError,

View file

@ -27,7 +27,9 @@ def _normalize_datatype(datatype):
return datatype
elif isinstance(datatype, Array):
return datatype
elif isinstance(datatype, Field):
elif isinstance(datatype, Array):
return datatype
elif isinstance(datatype, Mapping):
return datatype
elif isinstance(datatype, Fields):
return datatype
@ -49,7 +51,10 @@ def _normalize_datatype(datatype):
datatype, = datatype
return Array(datatype)
elif cls is dict:
raise NotImplementedError
if len(datatype) != 1:
raise NotImplementedError
[keytype, valuetype], = datatype.items()
return Mapping(valuetype, keytype)
# fallback:
else:
try:
@ -215,7 +220,7 @@ class Array(Readonly):
)
def __repr__(self):
return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype)
return '{}(itemtype={!r})'.format(type(self).__name__, self.itemtype)
def __hash__(self):
return hash(self.itemtype)
@ -238,6 +243,55 @@ class Array(Readonly):
return self.__class__(datatype, **kwargs)
class Mapping(Readonly):
"""Declare a mapping (to a single type)."""
def __init__(self, valuetype, keytype=str, _normalize=True):
if _normalize:
keytype = _transform_datatype(keytype, _normalize_datatype)
valuetype = _transform_datatype(valuetype, _normalize_datatype)
self._bind_attrs(
keytype=keytype,
valuetype=valuetype,
)
def __repr__(self):
if self.keytype is str:
return '{}(valuetype={!r})'.format(type(self).__name__, self.valuetype)
else:
return '{}(keytype={!r}, valuetype={!r})'.format(
type(self).__name__, self.keytype, self.valuetype)
def __hash__(self):
return hash((self.keytype, self.valuetype))
def __eq__(self, other):
try:
other_keytype = other.keytype
other_valuetype = other.valuetype
except AttributeError:
return False
if self.keytype != other_keytype:
return False
if self.valuetype != other_valuetype:
return False
return True
def __ne__(self, other):
return not (self == other)
def traverse(self, op, **kwargs):
"""Return a copy with op applied to the item datatype."""
keytype = op(self.keytype)
valuetype = op(self.valuetype)
if (keytype is self.keytype and
valuetype is self.valuetype and
not kwargs
):
return self
return self.__class__(valuetype, keytype, **kwargs)
class Field(namedtuple('Field', 'name datatype default optional')):
"""Declare a field in a data map param."""

View file

@ -4,7 +4,7 @@ from debugger_protocol.arg import NOT_SET, ANY
from debugger_protocol.arg._datatype import FieldsNamespace
from debugger_protocol.arg._decl import (
REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype,
Enum, Union, Array, Field, Fields)
Enum, Union, Array, Mapping, Field, Fields)
from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg
from debugger_protocol.arg._params import (
SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter)
@ -37,6 +37,8 @@ class ModuleTests(unittest.TestCase):
(Array(str), NOOP),
([str], Array(str)),
((str,), Array(str)),
(Mapping(str), NOOP),
({str: str}, Mapping(str)),
# others
(Field('spam'), NOOP),
(Fields(Field('spam')), NOOP),
@ -61,9 +63,6 @@ class ModuleTests(unittest.TestCase):
self.assertEqual(datatype, expected)
with self.assertRaises(NotImplementedError):
_normalize_datatype({1: 2})
def test_transform_datatype_simple(self):
datatypes = [
REF,
@ -102,6 +101,7 @@ class ModuleTests(unittest.TestCase):
class Spam(FieldsNamespace):
FIELDS = [
Field('a'),
Field('b', {str: str})
]
fields = Fields(Field('...'))
@ -145,9 +145,12 @@ class ModuleTests(unittest.TestCase):
Union(Array(Spam)),
Array(Spam),
Spam,
#Fields(Field('a')),
Field('a'),
str,
Field('b', Mapping(str)),
Mapping(str),
str,
str,
# ...
field_eggs,
Array(TYPE_REFERENCE),
@ -213,6 +216,7 @@ class UnionTests(unittest.TestCase):
(frozenset([str, int]), Union(*frozenset([str, int]))),
([str], Array(str)),
((str,), Array(str)),
({str: str}, Mapping(str)),
(None, None),
]
for datatype, expected in tests:
@ -221,9 +225,6 @@ class UnionTests(unittest.TestCase):
self.assertEqual(union, Union(int, expected, str))
with self.assertRaises(NotImplementedError):
Union({1: 2})
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
@ -260,6 +261,7 @@ class ArrayTests(unittest.TestCase):
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
({str: str}, Mapping(str)),
(None, None),
]
for datatype, expected in tests:
@ -268,9 +270,6 @@ class ArrayTests(unittest.TestCase):
self.assertEqual(array, Array(expected))
with self.assertRaises(NotImplementedError):
Array({1: 2})
def test_normalized_transformed(self):
calls = 0
@ -311,6 +310,67 @@ class ArrayTests(unittest.TestCase):
])
class MappingTests(unittest.TestCase):
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
({str: str}, Mapping(str)),
(None, None),
]
for datatype, expected in tests:
with self.subTest(datatype):
mapping = Mapping(datatype)
self.assertEqual(mapping, Mapping(expected))
def test_normalized_transformed(self):
calls = 0
class Spam:
@classmethod
def traverse(cls, op):
nonlocal calls
calls += 1
return cls
mapping = Mapping(Spam)
self.assertIs(mapping.keytype, str)
self.assertIs(mapping.valuetype, Spam)
self.assertEqual(calls, 1)
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
mapping = Mapping(Union(str, int))
transformed = mapping.traverse(op)
self.assertIs(transformed, mapping)
self.assertCountEqual(calls, [
str,
# Note that it did not recurse into Union(str, int).
Union(str, int),
])
def test_traverse_changed(self):
calls = []
op = (lambda dt: calls.append(dt) or str)
mapping = Mapping(ANY)
transformed = mapping.traverse(op)
self.assertIsNot(transformed, mapping)
self.assertEqual(transformed, Mapping(str))
self.assertEqual(calls, [
str,
ANY,
])
class FieldTests(unittest.TestCase):
def test_defaults(self):
@ -333,6 +393,7 @@ class FieldTests(unittest.TestCase):
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
({str: str}, Mapping(str)),
(None, None),
]
for datatype, expected in tests:
@ -341,9 +402,6 @@ class FieldTests(unittest.TestCase):
self.assertEqual(field, Field('spam', expected))
with self.assertRaises(NotImplementedError):
Field('spam', {1: 2})
def test_normalized_transformed(self):
calls = 0
@ -420,6 +478,7 @@ class FieldsTests(unittest.TestCase):
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
({str: str}, Mapping(str)),
(None, None),
]
for datatype, expected in tests:
@ -432,11 +491,6 @@ class FieldsTests(unittest.TestCase):
Field('spam', expected),
])
with self.assertRaises(NotImplementedError):
Fields(
Field('spam', {1: 2}),
)
def test_with_START_OPTIONAL(self):
fields = Fields(
Field('spam'),