From 57e388bb5fb8f28effd8570f2a81995d49b4d53c Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Thu, 1 Feb 2018 16:43:05 +0000 Subject: [PATCH] Add parameter types. --- debugger_protocol/arg/__init__.py | 1 + debugger_protocol/arg/_decl.py | 1 - debugger_protocol/arg/_param.py | 2 +- debugger_protocol/arg/_params.py | 380 +++++++++++ tests/debugger_protocol/arg/test__params.py | 698 ++++++++++++++++++++ 5 files changed, 1080 insertions(+), 2 deletions(-) create mode 100644 debugger_protocol/arg/_params.py create mode 100644 tests/debugger_protocol/arg/test__params.py diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index f99af01c..f294ec8a 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -5,3 +5,4 @@ from ._errors import ( # noqa ArgumentError, ArgMissingError, IncompleteArgError, ArgTypeMismatchError, ) +from ._params import param_from_datatype # noqa diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 3b62e91c..f7fea26f 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -301,7 +301,6 @@ class Fields(Readonly, Sequence): def __getitem__(self, index): return self._fields[index] - @property def as_dict(self): return {field.name: field for field in self._fields} diff --git a/debugger_protocol/arg/_param.py b/debugger_protocol/arg/_param.py index 33e22199..bc87c124 100644 --- a/debugger_protocol/arg/_param.py +++ b/debugger_protocol/arg/_param.py @@ -17,7 +17,7 @@ class _ParameterBase(WithRepr): def __eq__(self, other): if type(self) is not type(other): - return False + return NotImplemented return self._datatype == other._datatype def __ne__(self, other): diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py new file mode 100644 index 00000000..f988539a --- /dev/null +++ b/debugger_protocol/arg/_params.py @@ -0,0 +1,380 @@ +from ._common import ANY, SIMPLE_TYPES +from ._datatype import FieldsNamespace +from ._decl import Enum, Union, Array, Field, Fields +from ._errors import ArgTypeMismatchError +from ._param import Parameter, DatatypeHandler + + +#def as_parameter(cls): +# """Return a parameter that wraps the given FieldsNamespace subclass.""" +# # XXX inject_params +# cls.normalize(_inject_params) +# param = param_from_datatype(cls) +## cls.PARAM = param +# return param +# +# +#def _inject_params(datatype): +# return param_from_datatype(datatype) + + +def param_from_datatype(datatype, **kwargs): + """Return a parameter for the given datatype.""" + if isinstance(datatype, Parameter): + return datatype + + if isinstance(datatype, DatatypeHandler): + return Parameter(datatype.datatype, datatype, **kwargs) + elif isinstance(datatype, Fields): + return ComplexParameter(datatype, **kwargs) + elif isinstance(datatype, Field): + return param_from_datatype(datatype.datatype, **kwargs) + elif datatype is ANY: + return NoopParameter() + elif datatype is None: + return SingletonParameter(None) + elif datatype in list(SIMPLE_TYPES): + return SimpleParameter(datatype, **kwargs) + elif isinstance(datatype, Enum): + return EnumParameter(datatype.datatype, datatype.choice, **kwargs) + elif isinstance(datatype, Union): + return UnionParameter(datatype, **kwargs) + elif isinstance(datatype, (set, frozenset)): + return UnionParameter(Union(*datatype), **kwargs) + elif isinstance(datatype, Array): + return ArrayParameter(datatype, **kwargs) + elif isinstance(datatype, (list, tuple)): + datatype, = datatype + return ArrayParameter(Array(datatype), **kwargs) + elif not isinstance(datatype, type): + raise NotImplementedError + elif issubclass(datatype, FieldsNamespace): + return ComplexParameter(datatype, **kwargs) + else: + raise NotImplementedError + + +######################## +# param types + +class NoopParameter(Parameter): + """A parameter that treats any value as-is.""" + def __init__(self): + handler = DatatypeHandler(ANY) + super(NoopParameter, self).__init__(ANY, handler) + + +NOOP = NoopParameter() + + +class SingletonParameter(Parameter): + """A parameter that works only for the given value.""" + + class HANDLER(DatatypeHandler): + def validate(self, coerced): + if coerced is not self.datatype: + raise ValueError( + 'expected {!r}, got {!r}'.format(self.datatype, coerced)) + + def __init__(self, obj): + handler = self.HANDLER(obj) + super(SingletonParameter, self).__init__(obj, handler) + + def match_type(self, raw): + # Note we do not check equality for singletons. + if raw is not self.datatype: + return None + return super(SingletonParameter, self).match_type(raw) + + +class SimpleHandler(DatatypeHandler): + """A datatype handler for basic value types.""" + + def __init__(self, cls): + if not isinstance(cls, type): + raise ValueError('expected a class, got {!r}'.format(cls)) + super(SimpleHandler, self).__init__(cls) + + def coerce(self, raw): + if type(raw) is self.datatype: + return raw + return self.datatype(raw) + + def validate(self, coerced): + if type(coerced) is not self.datatype: + raise ValueError( + 'expected {!r}, got {!r}'.format(self.datatype, coerced)) + + +class SimpleParameter(Parameter): + """A parameter for basic value types.""" + + HANDLER = SimpleHandler + + def __init__(self, cls, strict=True): + handler = self.HANDLER(cls) + super(SimpleParameter, self).__init__(cls, handler) + self._strict = strict + + def match_type(self, raw): + if self._strict: + if type(raw) is not self.datatype: + return None + elif not isinstance(raw, self.datatype): + return None + return super(SimpleParameter, self).match_type(raw) + + +class EnumParameter(Parameter): + """A parameter for enums of basic value types.""" + + class HANDLER(SimpleHandler): + + def __init__(self, cls, enum): + if not enum: + raise TypeError('missing enum') + super(EnumParameter.HANDLER, self).__init__(cls) + if not callable(enum): + enum = set(enum) + self.enum = enum + + def validate(self, coerced): + super(EnumParameter.HANDLER, self).validate(coerced) + + if not self._match_enum(coerced): + msg = 'expected one of {!r}, got {!r}' + raise ValueError(msg.format(self.enum, coerced)) + + def _match_enum(self, coerced): + if callable(self.enum): + if not self.enum(coerced): + return False + elif coerced not in self.enum: + return False + return True + + def __init__(self, cls, enum): + handler = self.HANDLER(cls, enum) + super(EnumParameter, self).__init__(cls, handler) + self._match_enum = handler._match_enum + + def match_type(self, raw): + if type(raw) is not self.datatype: + return None + if not self._match_enum(raw): + return None + return super(EnumParameter, self).match_type(raw) + + +class UnionParameter(Parameter): + """A parameter that supports multiple different types.""" + + HANDLER = None # no handler + + @classmethod + def from_datatypes(cls, *datatypes, **kwargs): + datatype = Union(*datatypes) + return cls(datatype, **kwargs) + + def __init__(self, datatype, **kwargs): + if not isinstance(datatype, Union): + raise ValueError('expected Union, got {!r}'.format(datatype)) + super(UnionParameter, self).__init__(datatype) + + choice = [] + for dt in datatype: + param = param_from_datatype(dt) + choice.append(param) + self.choice = choice + + def __eq__(self, other): + if type(self) is not type(other): + return False + return set(self.datatype) == set(other.datatype) + + def match_type(self, raw): + for param in self.choice: + handler = param.match_type(raw) + if handler is not None: + return handler + return None + + +class ArrayParameter(Parameter): + """A parameter that is a list of some fixed type.""" + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None, itemparam=None): + if not isinstance(datatype, Array): + raise ValueError( + 'expected an Array, got {!r}'.format(datatype)) + super(ArrayParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + self.itemparam = itemparam + + def coerce(self, raw): + if self.handlers is None: + if self.itemparam is None: + itemtype = self.datatype.itemtype + self.itemparam = param_from_datatype(itemtype) + handlers = [] + for item in raw: + handler = self.itemparam.match_type(item) + if handler is None: + raise ArgTypeMismatchError(item) + handlers.append(handler) + self.handlers = handlers + + result = [] + for i, item in enumerate(raw): + handler = self.handlers[i] + item = handler.coerce(item) + result.append(item) + return result + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for i, item in enumerate(coerced): + handler = self.handlers[i] + handler.validate(item) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = [] + for i, item in enumerate(coerced): + handler = self.handlers[i] + datum = handler.as_data(item) + data.append(datum) + return data + + @classmethod + def from_itemtype(cls, itemtype, **kwargs): + datatype = Array(itemtype) + return cls(datatype, **kwargs) + + def __init__(self, datatype): + if not isinstance(datatype, Array): + raise ValueError('expected Array, got {!r}'.format(datatype)) + itemparam = param_from_datatype(datatype.itemtype) + handler = self.HANDLER(datatype, None, itemparam) + super(ArrayParameter, self).__init__(datatype, handler) + + self.itemparam = itemparam + + def match_type(self, raw): + if not isinstance(raw, list): + return None + handlers = [] + for item in raw: + handler = self.itemparam.match_type(item) + if handler is None: + return None + handlers.append(handler) + return self.HANDLER(self.datatype, handlers) + + +class ComplexParameter(Parameter): + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None): + if (type(datatype) is not type or + not issubclass(datatype, FieldsNamespace) + ): + msg = 'expected FieldsNamespace, got {!r}' + raise ValueError(msg.format(datatype)) + super(ComplexParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + + def coerce(self, raw): + if self.handlers is None: + fields = self.datatype.FIELDS.as_dict() + handlers = {} + for name, value in raw.items(): + param = param_from_datatype(fields[name]) + handler = param.match_type(value) + if handler is None: + raise ArgTypeMismatchError((name, value)) + handlers[name] = handler + self.handlers = handlers + + result = {} + for name, value in raw.items(): + handler = self.handlers[name] + value = handler.coerce(value) + result[name] = value + return self.datatype(**result) + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for field in self.datatype.FIELDS: + try: + value = getattr(coerced, field.name) + except AttributeError: + continue + handler = self.handlers[field.name] + handler.validate(value) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = {} + for field in self.datatype.FIELDS: + try: + value = getattr(coerced, field.name) + except AttributeError: + continue + handler = self.handlers[field.name] + datum = handler.as_data(value) + data[field.name] = datum + return data + + def __init__(self, datatype): + if isinstance(datatype, Fields): + class ArgNamespace(FieldsNamespace): + FIELDS = datatype + + datatype = ArgNamespace + elif (type(datatype) is not type or + not issubclass(datatype, FieldsNamespace)): + msg = 'expected Fields or FieldsNamespace, got {!r}' + raise ValueError(msg.format(datatype)) + datatype.normalize() + # We set handler later in match_type(). + super(ComplexParameter, self).__init__(datatype) + + self.params = {field.name: param_from_datatype(field) + for field in datatype.FIELDS} + + def __eq__(self, other): + if super(ComplexParameter, self).__eq__(other): + return True + try: + fields = self._datatype.FIELDS + other_fields = other._datatype.FIELDS + except AttributeError: + return NotImplemented + else: + return fields == other_fields + + def match_type(self, raw): + if not isinstance(raw, dict): + return None + handlers = {} + for field in self.datatype.FIELDS: + try: + value = raw[field.name] + except KeyError: + if not field.optional: + return None + value = field.default + param = self.params[field.name] + handler = param.match_type(value) + if handler is None: + return None + handlers[field.name] = handler + return self.HANDLER(self.datatype, handlers) diff --git a/tests/debugger_protocol/arg/test__params.py b/tests/debugger_protocol/arg/test__params.py new file mode 100644 index 00000000..2a05d69f --- /dev/null +++ b/tests/debugger_protocol/arg/test__params.py @@ -0,0 +1,698 @@ +import unittest + +from debugger_protocol.arg._common import NOT_SET, ANY +from debugger_protocol.arg._decl import Enum, Union, Array, Field, Fields +from debugger_protocol.arg._param import Parameter, DatatypeHandler +from debugger_protocol.arg._params import ( + param_from_datatype, + NoopParameter, SingletonParameter, + SimpleParameter, EnumParameter, + UnionParameter, ArrayParameter, ComplexParameter) + +from ._common import FIELDS_BASIC, BASIC_FULL, Basic + + +class String(str): + pass + + +class Integer(int): + pass + + +class ParamFromDatatypeTest(unittest.TestCase): + + def test_supported(self): + handler = DatatypeHandler(str) + tests = [ + (Parameter(str), Parameter(str)), + (handler, Parameter(str, handler)), + (Fields(Field('spam')), ComplexParameter(Fields(Field('spam')))), + (Field('spam'), SimpleParameter(str)), + (Field('spam', str, enum={'a'}), EnumParameter(str, {'a'})), + (ANY, NoopParameter()), + (None, SingletonParameter(None)), + (str, SimpleParameter(str)), + (int, SimpleParameter(int)), + (bool, SimpleParameter(bool)), + (Enum(str, {'a'}), EnumParameter(str, {'a'})), + (Union(str, int), UnionParameter(Union(str, int))), + ({str, int}, UnionParameter(Union(str, int))), + (frozenset([str, int]), UnionParameter(Union(str, int))), + (Array(str), ArrayParameter(Array(str))), + ([str], ArrayParameter(Array(str))), + ((str,), ArrayParameter(Array(str))), + (Basic, ComplexParameter(Basic)), + ] + for datatype, expected in tests: + with self.subTest(datatype): + param = param_from_datatype(datatype) + + self.assertEqual(param, expected) + + def test_not_supported(self): + datatypes = [ + String('spam'), + ..., + ] + for datatype in datatypes: + with self.subTest(datatype): + with self.assertRaises(NotImplementedError): + param_from_datatype(datatype) + + +class NoopParameterTests(unittest.TestCase): + + VALUES = [ + object(), + 'spam', + 10, + ['spam'], + {'spam': 42}, + True, + None, + NOT_SET, + ] + + def test_match_type(self): + values = [ + object(), + '', + 'spam', + b'spam', + 0, + 10, + 10.0, + 10+0j, + ('spam',), + (), + ['spam'], + [], + {'spam': 42}, + {}, + {'spam'}, + set(), + object, + type, + NoopParameterTests, + True, + None, + ..., + NotImplemented, + NOT_SET, + ANY, + Union(str, int), + Union(), + Array(str), + Field('spam'), + Fields(Field('spam')), + Fields(), + Basic, + ] + for value in values: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + + self.assertIs(type(handler), DatatypeHandler) + self.assertIs(handler.datatype, ANY) + + def test_coerce(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertIs(coerced, value) + + def test_validate(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + handler.validate(value) + + def test_as_data(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class SingletonParameterTests(unittest.TestCase): + + def test_match_type_matched(self): + param = SingletonParameter(None) + handler = param.match_type(None) + + self.assertIs(handler.datatype, None) + + def test_match_type_no_match(self): + tests = [ + # same type, different value + ('spam', 'eggs'), + (10, 11), + (True, False), + # different type but equivalent + ('spam', b'spam'), + (10, 10.0), + (10, 10+0j), + (10, '10'), + (10, b'\10'), + ] + for singleton, value in tests: + with self.subTest((singleton, value)): + param = SingletonParameter(singleton) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + param = SingletonParameter(None) + handler = param.match_type(None) + value = handler.coerce(None) + + self.assertIs(value, None) + + def test_validate_valid(self): + param = SingletonParameter(None) + handler = param.match_type(None) + handler.validate(None) + + def test_validate_wrong_type(self): + tests = [ + (None, True), + (True, None), + ('spam', 10), + (10, 'spam'), + ] + for singleton, value in tests: + with self.subTest(singleton): + param = SingletonParameter(singleton) + handler = param.match_type(singleton) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_validate_same_type_wrong_value(self): + tests = [ + ('spam', 'eggs'), + (True, False), + (10, 11), + ] + for singleton, value in tests: + with self.subTest(singleton): + param = SingletonParameter(singleton) + handler = param.match_type(singleton) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + param = SingletonParameter(None) + handler = param.match_type(None) + data = handler.as_data(None) + + self.assertIs(data, None) + + +class SimpleParameterTests(unittest.TestCase): + + def test_match_type_match(self): + tests = [ + (str, 'spam'), + (str, String('spam')), + (int, 10), + (bool, True), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + param = SimpleParameter(datatype, strict=False) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_no_match(self): + tests = [ + (int, 'spam'), + # coercible + (str, 10), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + param = SimpleParameter(datatype, strict=False) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_match_type_strict_match(self): + tests = { + str: 'spam', + int: 10, + bool: True, + } + for datatype, value in tests.items(): + with self.subTest(datatype): + param = SimpleParameter(datatype, strict=True) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_strict_no_match(self): + tests = { + str: String('spam'), + int: Integer(10), + } + for datatype, value in tests.items(): + with self.subTest(datatype): + param = SimpleParameter(datatype, strict=True) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + tests = [ + (str, 'spam', 'spam'), + (str, String('spam'), 'spam'), + (int, 10, 10), + (bool, True, True), + # did not match, but still coercible + (str, 10, '10'), + (str, str, ""), + (int, 10.0, 10), + (int, '10', 10), + (bool, 1, True), + ] + for datatype, value, expected in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + coerced = handler.coerce(value) + + self.assertEqual(coerced, expected) + + def test_validate_valid(self): + tests = { + str: 'spam', + int: 10, + bool: True, + } + for datatype, value in tests.items(): + with self.subTest(datatype): + handler = SimpleParameter.HANDLER(datatype) + handler.validate(value) + + def test_validate_invalid(self): + tests = [ + (int, 'spam'), + # coercible + (str, String('spam')), + (str, 10), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + tests = [ + (str, 'spam'), + (int, 10), + (bool, True), + # did not match, but still coercible + (str, String('spam')), + (str, 10), + (str, str), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class EnumParameterTests(unittest.TestCase): + + def test_match_type_match(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True,), True), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum)): + param = EnumParameter(datatype, enum) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_no_match(self): + tests = [ + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + param = EnumParameter(datatype, enum) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + tests = [ + (str, 'spam', 'spam'), + (int, 10, 10), + (bool, True, True), + # did not match, but still coercible + (str, String('spam'), 'spam'), + (str, 10, '10'), + (str, str, ""), + (int, 10.0, 10), + (int, '10', 10), + (bool, 1, True), + ] + for datatype, value, expected in tests: + with self.subTest((datatype, value)): + enum = (expected,) + handler = EnumParameter.HANDLER(datatype, enum) + coerced = handler.coerce(value) + + self.assertEqual(coerced, expected) + + def test_coerce_enum_mismatch(self): + enum = ('spam', 'eggs') + handler = EnumParameter.HANDLER(str, enum) + coerced = handler.coerce('ham') + + # It still works. + self.assertEqual(coerced, 'ham') + + def test_validate_valid(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True, False), True), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum)): + handler = EnumParameter.HANDLER(datatype, enum) + handler.validate(value) + + def test_validate_invalid(self): + tests = [ + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + handler = EnumParameter.HANDLER(datatype, enum) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True,), True), + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + handler = EnumParameter.HANDLER(datatype, enum) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class UnionParameterTests(unittest.TestCase): + + def test_match_type_all_simple(self): + tests = [ + 'spam', + 10, + True, + ] + datatype = Union(str, int, bool) + param = UnionParameter(datatype) + for value in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(type(handler), SimpleParameter.HANDLER) + self.assertIs(handler.datatype, type(value)) + + def test_match_type_mixed(self): + datatype = Union( + str, + # XXX add dedicated enums + Enum(int, (1, 2, 3)), + Basic, + Array(str), + Array(int), + Union(int, bool), + ) + param = UnionParameter(datatype) + + tests = [ + ('spam', SimpleParameter.HANDLER(str)), + (2, EnumParameter.HANDLER(int, (1, 2, 3))), + (BASIC_FULL, ComplexParameter(Basic).match_type(BASIC_FULL)), + (['spam'], ArrayParameter.HANDLER(Array(str))), + ([], ArrayParameter.HANDLER(Array(str))), + ([10], ArrayParameter.HANDLER(Array(int))), + (10, SimpleParameter.HANDLER(int)), + (True, SimpleParameter.HANDLER(bool)), + # no match + (Integer(2), None), + ([True], None), + ({}, None), + ] + for value, expected in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_catchall(self): + NOOP = DatatypeHandler(ANY) + param = UnionParameter(Union(int, str, ANY)) + tests = [ + ('spam', SimpleParameter.HANDLER(str)), + (10, SimpleParameter.HANDLER(int)), + # catchall + (BASIC_FULL, NOOP), + (['spam'], NOOP), + (True, NOOP), + (Integer(2), NOOP), + ([10], NOOP), + ({}, NOOP), + ] + for value, expected in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = UnionParameter(Union(int, str)) + values = [ + BASIC_FULL, + ['spam'], + True, + Integer(2), + [10], + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + +class ArrayParameterTests(unittest.TestCase): + + def test_match_type_match(self): + param = ArrayParameter(Array(str)) + expected = ArrayParameter.HANDLER(Array(str)) + values = [ + ['a', 'b', 'c'], + [], + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = ArrayParameter(Array(str)) + values = [ + ['a', 1, 'c'], + ('a', 'b', 'c'), + 'spam', + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce_simple(self): + param = ArrayParameter(Array(str)) + values = [ + ['a', 'b', 'c'], + [], + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, value) + + def test_coerce_complicated(self): + param = ArrayParameter(Array(Union(str, Basic))) + value = [ + 'a', + BASIC_FULL, + 'c', + ] + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, [ + 'a', + Basic(name='spam', value='eggs'), + 'c', + ]) + + def test_validate(self): + param = ArrayParameter(Array(str)) + handler = param.match_type(['a', 'b', 'c']) + handler.validate(['a', 'b', 'c']) + + def test_as_data_simple(self): + param = ArrayParameter(Array(str)) + handler = param.match_type(['a', 'b', 'c']) + data = handler.as_data(['a', 'b', 'c']) + + self.assertEqual(data, ['a', 'b', 'c']) + + def test_as_data_complicated(self): + param = ArrayParameter(Array(Union(str, Basic))) + value = [ + 'a', + BASIC_FULL, + 'c', + ] + handler = param.match_type(value) + data = handler.as_data([ + 'a', + Basic(name='spam', value='eggs'), + 'c', + ]) + + self.assertEqual(data, value) + + +class ComplexParameterTests(unittest.TestCase): + + def test_match_type(self): + fields = Fields(*FIELDS_BASIC) + param = ComplexParameter(fields) + handler = param.match_type(BASIC_FULL) + + self.assertIs(type(handler), ComplexParameter.HANDLER) + self.assertEqual(handler.datatype.FIELDS, fields) + + def test_coerce(self): + handler = ComplexParameter.HANDLER(Basic) + coerced = handler.coerce(BASIC_FULL) + + self.assertEqual(coerced, Basic(**BASIC_FULL)) + + def test_validate(self): + handler = ComplexParameter.HANDLER(Basic) + handler.coerce(BASIC_FULL) + coerced = Basic(**BASIC_FULL) + handler.validate(coerced) + + def test_as_data(self): + handler = ComplexParameter.HANDLER(Basic) + handler.coerce(BASIC_FULL) + coerced = Basic(**BASIC_FULL) + data = handler.as_data(coerced) + + self.assertEqual(data, BASIC_FULL)