From 2fc452e3d832f20f1d359ea392625846d8225c82 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Thu, 18 Jan 2018 17:35:36 +0000 Subject: [PATCH 01/10] Add Readonly and WithRepr. --- debugger_protocol/_base.py | 28 ++++++++++++++++++++++++++++ debugger_protocol/messages/_base.py | 22 ++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 debugger_protocol/_base.py create mode 100644 debugger_protocol/messages/_base.py diff --git a/debugger_protocol/_base.py b/debugger_protocol/_base.py new file mode 100644 index 00000000..f0f8aff9 --- /dev/null +++ b/debugger_protocol/_base.py @@ -0,0 +1,28 @@ + + +class Readonly(object): + """For read-only instances.""" + + def __setattr__(self, name, value): + raise AttributeError( + '{} objects are read-only'.format(type(self).__name__)) + + def __delattr__(self, name): + raise AttributeError( + '{} objects are read-only'.format(type(self).__name__)) + + def _bind_attrs(self, **attrs): + for name, value in attrs.items(): + object.__setattr__(self, name, value) + + +class WithRepr(object): + + def _init_args(self): + # XXX Extract from __init__()... + return () + + def __repr__(self): + args = ', '.join('{}={!r}'.format(arg, value) + for arg, value in self._init_args()) + return '{}({})'.format(type(self).__name__, args) diff --git a/debugger_protocol/messages/_base.py b/debugger_protocol/messages/_base.py new file mode 100644 index 00000000..7efc3e3b --- /dev/null +++ b/debugger_protocol/messages/_base.py @@ -0,0 +1,22 @@ +from debugger_protocol._base import Readonly, WithRepr + + +class Base(Readonly, WithRepr): + """Base class for message-related types.""" + + _INIT_ARGS = None + + @classmethod + def from_data(cls, **kwargs): + """Return an instance based on the given raw data.""" + return cls(**kwargs) + + def __init__(self): + self._validate() + + def _validate(self): + pass + + def as_data(self): + """Return serializable data for the instance.""" + return {} From 66224aa0856bcb04063a984c003a462deecfd809 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 30 Jan 2018 19:08:42 +0000 Subject: [PATCH 02/10] Add a Stub testing helper. --- tests/helpers/stub.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/helpers/stub.py diff --git a/tests/helpers/stub.py b/tests/helpers/stub.py new file mode 100644 index 00000000..c51eff90 --- /dev/null +++ b/tests/helpers/stub.py @@ -0,0 +1,25 @@ + + +class Stub(object): + """A testing double that tracks calls.""" + + def __init__(self): + self.calls = [] + self._exceptions = [] + + def set_exceptions(self, *exceptions): + self._exceptions = list(exceptions) + + def add_call(self, name, *args, **kwargs): + self.add_call_exact(name, args, kwargs) + + def add_call_exact(self, name, args, kwargs): + self.calls.append((name, args, kwargs)) + + def maybe_raise(self): + if not self._exceptions: + return + exc = self._exceptions.pop(0) + if exc is None: + return + raise exc From 21d161113724d83dd3439415124474b6798809a4 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 29 Jan 2018 18:12:36 +0000 Subject: [PATCH 03/10] Add the debugger_protocol.arg packge. --- debugger_protocol/arg/__init__.py | 1 + debugger_protocol/arg/_common.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 debugger_protocol/arg/__init__.py create mode 100644 debugger_protocol/arg/_common.py diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py new file mode 100644 index 00000000..fc4abbf4 --- /dev/null +++ b/debugger_protocol/arg/__init__.py @@ -0,0 +1 @@ +from ._common import NOT_SET, ANY # noqa diff --git a/debugger_protocol/arg/_common.py b/debugger_protocol/arg/_common.py new file mode 100644 index 00000000..c36af8e7 --- /dev/null +++ b/debugger_protocol/arg/_common.py @@ -0,0 +1,17 @@ + +def sentinel(name): + """Return a named value to use as a sentinel.""" + class Sentinel(object): + def __repr__(self): + return name + + return Sentinel() + + +# NOT_SET indicates that an arg was not provided. +NOT_SET = sentinel('NOT_SET') + +# ANY is a datatype surrogate indicating that any value is okay. +ANY = sentinel('ANY') + +SIMPLE_TYPES = {None, bool, int, str} From 414e6209748f941e65f9940e3e0c7d7dffee6091 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 29 Jan 2018 18:15:41 +0000 Subject: [PATCH 04/10] Add declarative field classes. --- debugger_protocol/arg/__init__.py | 1 + debugger_protocol/arg/_decl.py | 239 ++++++++++++ tests/debugger_protocol/arg/__init__.py | 0 tests/debugger_protocol/arg/test__decl.py | 441 ++++++++++++++++++++++ 4 files changed, 681 insertions(+) create mode 100644 debugger_protocol/arg/_decl.py create mode 100644 tests/debugger_protocol/arg/__init__.py create mode 100644 tests/debugger_protocol/arg/test__decl.py diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index fc4abbf4..e50d0c51 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -1 +1,2 @@ from ._common import NOT_SET, ANY # noqa +from ._decl import Union, Array, Field # noqa diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py new file mode 100644 index 00000000..20743c8a --- /dev/null +++ b/debugger_protocol/arg/_decl.py @@ -0,0 +1,239 @@ +from collections import namedtuple +from collections.abc import Sequence + +from debugger_protocol._base import Readonly +from ._common import sentinel, NOT_SET, ANY, SIMPLE_TYPES + + +REF = '' +TYPE_REFERENCE = sentinel('TYPE_REFERENCE') + + +def _normalize_datatype(datatype): + cls = type(datatype) + if datatype == REF or datatype is TYPE_REFERENCE: + return TYPE_REFERENCE + elif datatype is ANY: + return ANY + elif datatype in list(SIMPLE_TYPES): + return datatype + elif isinstance(datatype, Union): + return datatype + elif isinstance(datatype, Array): + return datatype + elif cls is set or cls is frozenset: + return Union(*datatype) + elif cls is list or cls is tuple: + datatype, = datatype + return Array(datatype) + elif cls is dict: + raise NotImplementedError + else: + return datatype + + +def _transform_datatype(datatype, op): + try: + dt_traverse = datatype.traverse + except AttributeError: + pass + else: + datatype = dt_traverse(lambda dt: _transform_datatype(dt, op)) + return op(datatype) + + +def _replace_ref(datatype, target): + if datatype is TYPE_REFERENCE: + return target + else: + return datatype + + +class Union(frozenset): + """Declare a union of different types. + + Sets and frozensets are treated equivalently in declarations. + """ + __slots__ = () + + @classmethod + def _traverse(cls, datatypes, op): + changed = False + result = [] + for datatype in datatypes: + transformed = op(datatype) + if transformed is not datatype: + changed = True + result.append(transformed) + return result, changed + + def __new__(cls, *datatypes, **kwargs): + normalize = kwargs.pop('_normalize', True) + (lambda: None)(**kwargs) # Make sure there aren't any other kwargs. + + datatypes = list(datatypes) + if normalize: + datatypes, _ = cls._traverse( + datatypes, + lambda dt: _transform_datatype(dt, _normalize_datatype), + ) + return super(Union, cls).__new__(cls, datatypes) + + def __repr__(self): + return '{}{}'.format(type(self).__name__, tuple(self)) + + @property + def datatypes(self): + return set(self) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to each contained datatype.""" + datatypes, changed = self._traverse(self, op) + if not changed and not kwargs: + return self + return self.__class__(*datatypes, **kwargs) + + +class Array(Readonly): + """Declare an array (of a single type). + + Lists and tuples (single-item) are treated equivalently + in declarations. + """ + + def __init__(self, itemtype, _normalize=True): + if _normalize: + itemtype = _normalize_datatype(itemtype) + self._bind_attrs( + itemtype=itemtype, + ) + + def __repr__(self): + return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype) + + def __hash__(self): + return hash(self.itemtype) + + def __eq__(self, other): + try: + other_itemtype = other.itemtype + except AttributeError: + return False + return self.itemtype == other_itemtype + + def __ne__(self, other): + return not (self == other) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to the item datatype.""" + datatype = op(self.itemtype) + if datatype is self.itemtype and not kwargs: + return self + return self.__class__(datatype, **kwargs) + + +class Field(namedtuple('Field', 'name datatype default optional')): + """Declare a field in a data map param.""" + + START_OPTIONAL = sentinel('START_OPTIONAL') + + def __new__(cls, name, datatype=str, default=NOT_SET, optional=False, + _normalize=True, **kwargs): + if _normalize: + datatype = _normalize_datatype(datatype) + self = super(Field, cls).__new__( + cls, + name=str(name) if name else None, + datatype=datatype, + default=default, + optional=bool(optional), + ) + self._kwargs = kwargs.items() + return self + + @property + def kwargs(self): + return dict(self._kwargs) + + def traverse(self, op, **kwargs): + """Return a copy with op applied to the datatype.""" + datatype = op(self.datatype) + if datatype is self.datatype and not kwargs: + return self + kwargs.setdefault('default', self.default) + kwargs.setdefault('optional', self.optional) + return self.__class__(self.name, datatype, **kwargs) + + +class Fields(Readonly, Sequence): + """Declare a set of fields.""" + + @classmethod + def _iter_fixed(cls, fields, _normalize=True): + optional = None + for field in fields or (): + if field is Field.START_OPTIONAL: + if optional is not None: + raise RuntimeError('START_OPTIONAL used more than once') + optional = True + continue + + if not isinstance(field, Field): + raise TypeError('got non-field {!r}'.format(field)) + if _normalize: + field = _transform_datatype(field, _normalize_datatype) + if optional is not None and field.optional is not optional: + field = field._replace(optional=optional) + yield field + + def __init__(self, *fields, **kwargs): + fields = list(self._iter_fixed(fields, **kwargs)) + self._bind_attrs( + _fields=fields, + ) + + def __repr__(self): + return '{}(*{})'.format(type(self).__name__, self._fields) + + def __hash__(self): + return hash(tuple(self)) + + def __eq__(self, other): + try: + other_len = len(other) + other_iter = iter(other) + except TypeError: + return False + if len(self) != other_len: + return False + for i, item in enumerate(other_iter): + if self[i] != item: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def __len__(self): + return len(self._fields) + + def __getitem__(self, index): + return self._fields[index] + + @property + def as_dict(self): + return {field.name: field for field in self._fields} + + def traverse(self, op, **kwargs): + """Return a copy with op applied to each field.""" + changed = False + updated = [] + for field in self._fields: + transformed = op(field) + if transformed is not field: + changed = True + updated.append(transformed) + + if not changed and not kwargs: + return self + return self.__class__(*updated, **kwargs) diff --git a/tests/debugger_protocol/arg/__init__.py b/tests/debugger_protocol/arg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py new file mode 100644 index 00000000..59ee2744 --- /dev/null +++ b/tests/debugger_protocol/arg/test__decl.py @@ -0,0 +1,441 @@ +import unittest + +from debugger_protocol.arg import NOT_SET, ANY +from debugger_protocol.arg._datatype import FieldsNamespace +from debugger_protocol.arg._decl import ( + REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, + Union, Array, Field, Fields) +from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg +from debugger_protocol.arg._params import ( + SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) + + +class ModuleTests(unittest.TestCase): + + def test_normalize_datatype(self): + NOOP = object() + tests = [ + # explicitly handled + (REF, TYPE_REFERENCE), + (TYPE_REFERENCE, NOOP), + (ANY, NOOP), + (None, NOOP), + (int, NOOP), + (str, NOOP), + (bool, NOOP), + (Union(str, int), NOOP), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + (Array(str), NOOP), + ([str], Array(str)), + ((str,), Array(str)), + # others + (Field('spam'), NOOP), + (Fields(Field('spam')), NOOP), + (Parameter(object()), NOOP), + (ParameterImplBase(str), NOOP), + (Arg(object(), object()), NOOP), + (SimpleParameter(str), NOOP), + (UnionParameter(str), NOOP), + (ArrayParameter(str), NOOP), + (ComplexParameter(Fields()), NOOP), + (NOT_SET, NOOP), + (object(), NOOP), + (object, NOOP), + (type, NOOP), + ] + for datatype, expected in tests: + if expected is NOOP: + expected = datatype + with self.subTest(datatype): + datatype = _normalize_datatype(datatype) + + self.assertEqual(datatype, expected) + + with self.assertRaises(NotImplementedError): + _normalize_datatype({1: 2}) + + def test_transform_datatype_simple(self): + datatypes = [ + REF, + TYPE_REFERENCE, + ANY, + None, + int, + str, + bool, + {str, int}, + frozenset([str, int]), + [str], + (str,), + Parameter(object()), + ParameterImplBase(str), + Arg(object(), object()), + SimpleParameter(str), + UnionParameter(str, int), + ArrayParameter(str), + ComplexParameter(Fields()), + NOT_SET, + object(), + object, + type, + ] + for expected in datatypes: + transformed = [] + op = (lambda dt: transformed.append(dt) or dt) + with self.subTest(expected): + datatype = _transform_datatype(expected, op) + + self.assertIs(datatype, expected) + self.assertEqual(transformed, [expected]) + + def test_transform_datatype_container(self): + class Spam(FieldsNamespace): + FIELDS = [ + Field('a'), + ] + + Spam.normalize() + + fields = Fields(Field('...')) + field_spam = Field('spam', ANY) + field_ham = Field('ham', Union( + Array(Spam), + )) + field_eggs = Field('eggs', Array(TYPE_REFERENCE)) + nested = Fields( + Field('???', fields), + field_spam, + field_ham, + field_eggs, + ) + tests = { + Array(str): [ + str, + Array(str), + ], + Field('...'): [ + str, + Field('...'), + ], + fields: [ + str, + Field('...'), + fields, + ], + nested: [ + str, + Field('...'), + fields, + Field('???', fields), + # ... + ANY, + Field('spam', ANY), + # ... + str, + Field('a'), + Fields(Field('a')), + Spam, + Array(Spam), + Union(Array(Spam)), + field_ham, + # ... + TYPE_REFERENCE, + Array(TYPE_REFERENCE), + field_eggs, + # ... + nested, + ], + } + for datatype, expected in tests.items(): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + with self.subTest(datatype): + transformed = _transform_datatype(datatype, op) + + self.assertIs(transformed, datatype) + self.assertEqual(calls, expected) + + # Check Union separately due to set iteration order. + calls = [] + op = (lambda dt: calls.append(dt) or dt) + datatype = Union(str, int) + transformed = _transform_datatype(datatype, op) + + self.assertIs(transformed, datatype) + self.assertEqual(set(calls[:2]), {str, int}) + self.assertEqual(calls[2:], [ + Union(str, int), + ]) + + +class UnionTests(unittest.TestCase): + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + union = Union(int, datatype, str) + + self.assertEqual(union, Union(int, expected, str)) + + with self.assertRaises(NotImplementedError): + Union({1: 2}) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + union = Union(str, Array(int), int) + transformed = union.traverse(op) + + self.assertIs(transformed, union) + self.assertCountEqual(calls, [ + str, + # Note that it did not recurse into Array(int). + Array(int), + int, + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + union = Union(ANY) + transformed = union.traverse(op) + + self.assertIsNot(transformed, union) + self.assertEqual(transformed, Union(str)) + self.assertEqual(calls, [ + ANY, + ]) + + +class ArrayTests(unittest.TestCase): + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + array = Array(datatype) + + self.assertEqual(array, Array(expected)) + + with self.assertRaises(NotImplementedError): + Array({1: 2}) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + array = Array(Union(str, int)) + transformed = array.traverse(op) + + self.assertIs(transformed, array) + self.assertCountEqual(calls, [ + # Note that it did not recurse into Union(str, int). + Union(str, int), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + array = Array(ANY) + transformed = array.traverse(op) + + self.assertIsNot(transformed, array) + self.assertEqual(transformed, Array(str)) + self.assertEqual(calls, [ + ANY, + ]) + + +class FieldTests(unittest.TestCase): + + def test_defaults(self): + field = Field('spam') + + self.assertEqual(field.name, 'spam') + self.assertIs(field.datatype, str) + self.assertIs(field.default, NOT_SET) + self.assertFalse(field.optional) + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + field = Field('spam', datatype) + + self.assertEqual(field, Field('spam', expected)) + + with self.assertRaises(NotImplementedError): + Field('spam', {1: 2}) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + field = Field('spam', Union(str, int)) + transformed = field.traverse(op) + + self.assertIs(transformed, field) + self.assertCountEqual(calls, [ + # Note that it did not recurse into Union(str, int). + Union(str, int), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or str) + field = Field('spam', ANY) + transformed = field.traverse(op) + + self.assertIsNot(transformed, field) + self.assertEqual(transformed, Field('spam', str)) + self.assertEqual(calls, [ + ANY, + ]) + + +class FieldsTests(unittest.TestCase): + + def test_single(self): + fields = Fields( + Field('spam'), + ) + + self.assertEqual(fields, [ + Field('spam'), + ]) + + def test_multiple(self): + fields = Fields( + Field('spam'), + Field('ham'), + Field('eggs'), + ) + + self.assertEqual(fields, [ + Field('spam'), + Field('ham'), + Field('eggs'), + ]) + + def test_empty(self): + fields = Fields() + + self.assertCountEqual(fields, []) + + def test_normalized(self): + tests = [ + (REF, TYPE_REFERENCE), + ({str, int}, Union(str, int)), + (frozenset([str, int]), Union(str, int)), + ([str], Array(str)), + ((str,), Array(str)), + (None, None), + ] + for datatype, expected in tests: + with self.subTest(datatype): + fields = Fields( + Field('spam', datatype), + ) + + self.assertEqual(fields, [ + Field('spam', expected), + ]) + + with self.assertRaises(NotImplementedError): + Fields( + Field('spam', {1: 2}), + ) + + def test_with_START_OPTIONAL(self): + fields = Fields( + Field('spam'), + Field('ham', optional=True), + Field('eggs'), + Field.START_OPTIONAL, + Field('a'), + Field('b', optional=False), + ) + + self.assertEqual(fields, [ + Field('spam'), + Field('ham', optional=True), + Field('eggs'), + Field('a', optional=True), + Field('b', optional=True), + ]) + + def test_non_field(self): + with self.assertRaises(TypeError): + Fields(str) + + def test_as_dict(self): + fields = Fields( + Field('spam', int), + Field('ham'), + Field('eggs', Array(str)), + ) + result = fields.as_dict + + self.assertEqual(result, { + 'spam': fields[0], + 'ham': fields[1], + 'eggs': fields[2], + }) + + def test_traverse_noop(self): + calls = [] + op = (lambda dt: calls.append(dt) or dt) + fields = Fields( + Field('spam'), + Field('ham'), + Field('eggs'), + ) + transformed = fields.traverse(op) + + self.assertIs(transformed, fields) + self.assertCountEqual(calls, [ + # Note that it did not recurse into the fields. + Field('spam'), + Field('ham'), + Field('eggs'), + ]) + + def test_traverse_changed(self): + calls = [] + op = (lambda dt: calls.append(dt) or Field(dt.name, str)) + fields = Fields( + Field('spam', ANY), + Field('eggs', None), + ) + transformed = fields.traverse(op) + + self.assertIsNot(transformed, fields) + self.assertEqual(transformed, Fields( + Field('spam', str), + Field('eggs', str), + )) + self.assertEqual(calls, [ + Field('spam', ANY), + Field('eggs', None), + ]) From 6c097a4f549a3e078f654c3021b295b3ea1c5895 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 30 Jan 2018 00:40:10 +0000 Subject: [PATCH 05/10] Add FieldsNamespace. --- debugger_protocol/arg/__init__.py | 5 + debugger_protocol/arg/_datatype.py | 277 ++++++++++++++++++ debugger_protocol/arg/_errors.py | 32 ++ tests/debugger_protocol/arg/_common.py | 50 ++++ tests/debugger_protocol/arg/test__datatype.py | 255 ++++++++++++++++ 5 files changed, 619 insertions(+) create mode 100644 debugger_protocol/arg/_datatype.py create mode 100644 debugger_protocol/arg/_errors.py create mode 100644 tests/debugger_protocol/arg/_common.py create mode 100644 tests/debugger_protocol/arg/test__datatype.py diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index e50d0c51..f99af01c 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -1,2 +1,7 @@ from ._common import NOT_SET, ANY # noqa +from ._datatype import FieldsNamespace # noqa from ._decl import Union, Array, Field # noqa +from ._errors import ( # noqa + ArgumentError, + ArgMissingError, IncompleteArgError, ArgTypeMismatchError, +) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py new file mode 100644 index 00000000..ee5b76ff --- /dev/null +++ b/debugger_protocol/arg/_datatype.py @@ -0,0 +1,277 @@ +from debugger_protocol._base import Readonly, WithRepr +from ._common import NOT_SET, ANY +from ._decl import ( + _transform_datatype, _replace_ref, + Union, Array, Field, Fields) +from ._errors import ArgTypeMismatchError, ArgMissingError, IncompleteArgError + + +def _coerce(datatype, value, call=True): + if datatype is ANY: + return value + elif type(value) is datatype: + return value + elif value is datatype: + return value + + # decl types + elif isinstance(datatype, Union): + for dt in datatype: + try: + return _coerce(dt, value, call=False) + except ArgTypeMismatchError: + continue + else: + raise ArgTypeMismatchError(value) + elif isinstance(datatype, Array): + try: + values = iter(value) + except TypeError: + raise ArgTypeMismatchError(value) + return [_coerce(datatype.itemtype, v, call=False) + for v in values] + elif isinstance(datatype, Field): + return _coerce(datatype.datatype, value) + elif isinstance(datatype, Fields): + class ArgNamespace(FieldsNamespace): + FIELDS = datatype + + return _coerce(ArgNamespace, value) + elif issubclass(datatype, FieldsNamespace): + arg = datatype.bind(value) + try: + arg_coerce = arg.coerce + except AttributeError: + return arg + else: + return arg_coerce() + + # fallbacks + elif callable(datatype) and call: + try: + return datatype(value) + except ArgTypeMismatchError: + raise + except (TypeError, ValueError): + raise ArgTypeMismatchError(value) + elif value == datatype: + return value + + raise ArgTypeMismatchError(value) + + +######################## +# fields + +class FieldsNamespace(Readonly, WithRepr): + """A namespace of field values exposed via attributes.""" + + FIELDS = None + PARAM_TYPE = None + PARAM = None + + @classmethod + def traverse(cls, op, **kwargs): + """Apply op to each field in cls.FIELDS.""" + fields = cls._normalize(cls.FIELDS) + fields = fields.traverse(op) + cls.FIELDS = cls._normalize(fields) + return cls + + @classmethod + def normalize(cls, *transforms): + """Normalize FIELDS and apply the given ops.""" + fields = cls._normalize(cls.FIELDS) + if not isinstance(fields, Fields): + fields = Fields(*fields) + for transform in transforms: + fields = _transform_datatype(fields, transform) + fields = cls._normalize(fields) + cls.FIELDS = fields + + @classmethod + def _normalize(cls, fields): + if fields is None: + raise TypeError('missing FIELDS') + if isinstance(fields, Fields): + try: + normalized = cls._normalized + except AttributeError: + normalized = cls._normalized = False + else: + fields = Fields(*fields) + normalized = cls._normalized = False + if not normalized: + fields = _transform_datatype(fields, + lambda dt: _replace_ref(dt, cls)) + return fields + + @classmethod + def bind(cls, ns, **kwargs): + param = cls.PARAM + if param is None: + if cls.PARAM_TYPE is None: + return cls(**ns) + param = cls.PARAM_TYPE(cls.FIELDS, cls) + return param.bind(ns, **kwargs) + + @classmethod + def _bind(cls, kwargs): + cls.FIELDS = cls._normalize(cls.FIELDS) + bound, missing = _fields_bind(cls.FIELDS, kwargs) + if missing: + raise IncompleteArgError(cls.FIELDS, missing) + + values = {} + validators = [] + serializers = {} + for field, arg in bound.items(): + if arg is NOT_SET: + continue + + try: + coerce = arg.coerce + except AttributeError: + value = arg + else: + value = coerce(arg) + values[field.name] = value + + try: + validate = arg.validate + validate = value.validate + except AttributeError: + pass + else: + validators.append(validate) + + try: + as_data = arg.as_data + as_data = value.as_data + except AttributeError: + pass + else: + serializers[field.name] = as_data + values['_validators'] = validators + values['_serializers'] = serializers + return values + + def __init__(self, **kwargs): + super(FieldsNamespace, self).__init__() + validate = kwargs.pop('_validate', True) + + kwargs = self._bind(kwargs) + self._bind_attrs(**kwargs) + if validate: + self.validate() + + def _init_args(self): + if self.FIELDS is not None: + for field in self.FIELDS: + try: + value = getattr(self, field.name) + except AttributeError: + continue + yield (field.name, value) + else: + for item in sorted(vars(self).items()): + yield item + + def __eq__(self, other): + try: + other_as_data = other.as_data + except AttributeError: + other_data = other + else: + other_data = other_as_data() + + return self.as_data() == other_data + + def __ne__(self, other): + return not (self == other) + + def validate(self): + """Ensure that the field values are valid.""" + for validate in self._validators: + validate() + + def as_data(self): + """Return serializable data for the instance.""" + data = {name: as_data() + for name, as_data in self._serializers.items()} + for field in self.FIELDS: + if field.name in data: + continue + try: + data[field.name] = getattr(self, field.name) + except AttributeError: + pass + return data + + +def _field_missing(field, value): + if value is NOT_SET: + return True + + try: + missing = field.datatype.missing + except AttributeError: + return None + else: + return missing(value) + + +def _field_bind(field, value, applydefaults=True): + missing = _field_missing(field, value) + if missing: + if field.optional: + if applydefaults: + return field.default + return NOT_SET + raise ArgMissingError(field, missing) + + try: + bind = field.datatype.bind + except AttributeError: + bind = (lambda v: _coerce(field.datatype, v)) + return bind(value) + + +def _fields_iter_values(fields, remainder): + for field in fields or (): + value = remainder.pop(field.name, NOT_SET) + yield field, value + + +def _fields_iter_bound(fields, remainder, applydefaults=True): + for field, value in _fields_iter_values(fields, remainder): + try: + arg = _field_bind(field, value, applydefaults=applydefaults) + except ArgMissingError as exc: + yield field, value, exc, False +# except ArgTypeMismatchError as exc: +# yield field, value, None, exc + else: + yield field, arg, False, False + + +def _fields_bind(fields, kwargs, applydefaults=True): + bound = {} + missing = {} + mismatched = {} + remainder = dict(kwargs) + bound_iter = _fields_iter_bound(fields, remainder, + applydefaults=applydefaults) + for field, arg, missed, mismatch in bound_iter: + if missed: + missing[field.name] = missed + elif mismatch: + mismatched[field.name] = arg + else: + bound[field] = arg + if remainder: + remainder = ', '.join(sorted(remainder)) + raise TypeError('got extra fields: {}'.format(remainder)) + if mismatched: + raise ArgTypeMismatchError(mismatched) + return bound, missing diff --git a/debugger_protocol/arg/_errors.py b/debugger_protocol/arg/_errors.py new file mode 100644 index 00000000..03b600f5 --- /dev/null +++ b/debugger_protocol/arg/_errors.py @@ -0,0 +1,32 @@ + +class ArgumentError(TypeError): + """The base class for argument-related exceptions.""" + + +class ArgMissingError(ArgumentError): + """Indicates that the argument for the field is missing.""" + + def __init__(self, field): + super(ArgMissingError, self).__init__( + 'missing arg {!r}'.format(field.name)) + self.field = field + + +class IncompleteArgError(ArgumentError): + """Indicates that the "complex" arg has missing fields.""" + + def __init__(self, fields, missing): + msg = 'incomplete arg (missing or incomplete fields: {})' + super(IncompleteArgError, self).__init__( + msg.format(', '.join(sorted(missing)))) + self.fields = fields + self.missing = missing + + +class ArgTypeMismatchError(ArgumentError): + """Indicates that the arg did not have the expected type.""" + + def __init__(self, value): + super(ArgTypeMismatchError, self).__init__( + 'bad value {!r} (unsupported type)'.format(value)) + self.value = value diff --git a/tests/debugger_protocol/arg/_common.py b/tests/debugger_protocol/arg/_common.py new file mode 100644 index 00000000..638d0bba --- /dev/null +++ b/tests/debugger_protocol/arg/_common.py @@ -0,0 +1,50 @@ +from debugger_protocol.arg import ANY, FieldsNamespace, Field + + +FIELDS_BASIC = [ + Field('name'), + Field.START_OPTIONAL, + Field('value'), +] + +BASIC_FULL = { + 'name': 'spam', + 'value': 'eggs', +} + +BASIC_MIN = { + 'name': 'spam', +} + + +class Basic(FieldsNamespace): + FIELDS = FIELDS_BASIC + + +FIELDS_EXTENDED = [ + Field('name', datatype=str, optional=False), + Field('valid', datatype=bool, optional=True), + Field('id', datatype=int, optional=False), + Field('value', datatype=ANY, optional=True), + Field('x', datatype=Basic, optional=True), + Field('y', datatype={int, str}, optional=True), + Field('z', datatype=[Basic], optional=True), +] + +EXTENDED_FULL = { + 'name': 'spam', + 'valid': True, + 'id': 10, + 'value': None, + 'x': BASIC_FULL, + 'y': 11, + 'z': [ + BASIC_FULL, + BASIC_MIN, + ], +} + +EXTENDED_MIN = { + 'name': 'spam', + 'id': 10, +} diff --git a/tests/debugger_protocol/arg/test__datatype.py b/tests/debugger_protocol/arg/test__datatype.py new file mode 100644 index 00000000..3ddfa6af --- /dev/null +++ b/tests/debugger_protocol/arg/test__datatype.py @@ -0,0 +1,255 @@ +import itertools +import unittest + +from debugger_protocol.arg._common import ANY +from debugger_protocol.arg._datatype import FieldsNamespace +from debugger_protocol.arg._decl import Array, Field, Fields + +from ._common import ( + BASIC_FULL, BASIC_MIN, Basic, + FIELDS_EXTENDED, EXTENDED_FULL, EXTENDED_MIN) + + +class FieldsNamespaceTests(unittest.TestCase): + + def test_traverse_noop(self): + fields = [ + Field('spam'), + Field('ham'), + Field('eggs'), + ] + + class Spam(FieldsNamespace): + FIELDS = Fields(*fields) + + calls = [] + op = (lambda dt: calls.append(dt) or dt) + transformed = Spam.traverse(op) + + self.assertIs(transformed, Spam) + self.assertIs(transformed.FIELDS, Spam.FIELDS) + for i, field in enumerate(Spam.FIELDS): + self.assertIs(field, fields[i]) + self.assertCountEqual(calls, [ + # Note that it did not recurse into the fields. + Field('spam'), + Field('ham'), + Field('eggs'), + ]) + + def test_traverse_unnormalized(self): + fields = [ + Field('spam'), + Field('ham'), + Field('eggs'), + ] + + class Spam(FieldsNamespace): + FIELDS = fields + + calls = [] + op = (lambda dt: calls.append(dt) or dt) + transformed = Spam.traverse(op) + + self.assertIs(transformed, Spam) + self.assertIsInstance(transformed.FIELDS, Fields) + for i, field in enumerate(Spam.FIELDS): + self.assertIs(field, fields[i]) + self.assertCountEqual(calls, [ + Field('spam'), + Field('ham'), + Field('eggs'), + ]) + + def test_traverse_changed(self): + class Spam(FieldsNamespace): + FIELDS = Fields( + Field('spam', ANY), + Field('eggs', None), + ) + + calls = [] + op = (lambda dt: calls.append(dt) or Field(dt.name, str)) + transformed = Spam.traverse(op) + + self.assertIs(transformed, Spam) + self.assertEqual(transformed.FIELDS, Fields( + Field('spam', str), + Field('eggs', str), + )) + self.assertEqual(calls, [ + Field('spam', ANY), + Field('eggs', None), + ]) + + def test_normalize_without_ops(self): + fieldlist = [ + Field('spam'), + Field('ham'), + Field('eggs'), + ] + fields = Fields(*fieldlist) + + class Spam(FieldsNamespace): + FIELDS = fields + + Spam.normalize() + + self.assertIs(Spam.FIELDS, fields) + for i, field in enumerate(Spam.FIELDS): + self.assertIs(field, fieldlist[i]) + + def test_normalize_unnormalized(self): + fieldlist = [ + Field('spam'), + Field('ham'), + Field('eggs'), + ] + + class Spam(FieldsNamespace): + FIELDS = fieldlist + + Spam.normalize() + + self.assertIsInstance(Spam.FIELDS, Fields) + for i, field in enumerate(Spam.FIELDS): + self.assertIs(field, fieldlist[i]) + + def test_normalize_with_ops_noop(self): + fieldlist = [ + Field('spam'), + Field('ham', int), + Field('eggs', Array(ANY)), + ] + fields = Fields(*fieldlist) + + class Spam(FieldsNamespace): + FIELDS = fields + + calls = [] + op1 = (lambda dt: calls.append((op1, dt)) or dt) + op2 = (lambda dt: calls.append((op2, dt)) or dt) + Spam.normalize(op1, op2) + + self.assertIs(Spam.FIELDS, fields) + for i, field in enumerate(Spam.FIELDS): + self.assertIs(field, fieldlist[i]) + self.assertEqual(calls, [ + (op1, str), + (op1, Field('spam')), + (op1, int), + (op1, Field('ham', int)), + (op1, ANY), + (op1, Array(ANY)), + (op1, Field('eggs', Array(ANY))), + (op1, fields), + (op2, str), + (op2, Field('spam')), + (op2, int), + (op2, Field('ham', int)), + (op2, ANY), + (op2, Array(ANY)), + (op2, Field('eggs', Array(ANY))), + (op2, fields), + ]) + + def test_normalize_with_op_changed(self): + class Spam(FieldsNamespace): + FIELDS = Fields( + Field('spam', Array(ANY)), + ) + + op = (lambda dt: int if dt is ANY else dt) + Spam.normalize(op) + + self.assertEqual(Spam.FIELDS, Fields( + Field('spam', Array(int)), + )) + + def test_normalize_missing(self): + with self.assertRaises(TypeError): + FieldsNamespace.normalize() + + ####### + + def test_fields_full(self): + class Spam(FieldsNamespace): + FIELDS = FIELDS_EXTENDED + + spam = Spam(**EXTENDED_FULL) + ns = vars(spam) + del ns['_validators'] + del ns['_serializers'] + + self.assertEqual(ns, { + 'name': 'spam', + 'valid': True, + 'id': 10, + 'value': None, + 'x': Basic(**BASIC_FULL), + 'y': 11, + 'z': [ + Basic(**BASIC_FULL), + Basic(**BASIC_MIN), + ], + }) + + def test_fields_min(self): + class Spam(FieldsNamespace): + FIELDS = FIELDS_EXTENDED + + spam = Spam(**EXTENDED_MIN) + ns = vars(spam) + del ns['_validators'] + del ns['_serializers'] + + self.assertEqual(ns, { + 'name': 'spam', + 'id': 10, + }) + + def test_no_fields(self): + with self.assertRaises(TypeError): + FieldsNamespace( + x='spam', + y=42, + z=None, + ) + + def test_attrs(self): + ns = Basic(name='', value='') + + self.assertEqual(ns.name, '') + self.assertEqual(ns.value, '') + + def test_equality(self): + ns1 = Basic(name='', value='') + ns2 = Basic(name='', value='') + + self.assertTrue(ns1 == ns1) + self.assertTrue(ns1 == ns2) + + def test_inequality(self): + p = [Basic(name=n, value=v) + for n in ['<>', ''] + for v in ['<>', '']] + for basic1, basic2 in itertools.combinations(p, 2): + with self.subTest((basic1, basic2)): + self.assertTrue(basic1 != basic2) + + @unittest.skip('not ready') + def test_validate(self): + # TODO: finish + raise NotImplementedError + + def test_as_data(self): + class Spam(FieldsNamespace): + FIELDS = FIELDS_EXTENDED + + spam = Spam(**EXTENDED_FULL) + sdata = spam.as_data() + basic = Basic(**BASIC_FULL) + bdata = basic.as_data() + + self.assertEqual(sdata, EXTENDED_FULL) + self.assertEqual(bdata, BASIC_FULL) From 6b4e02ef45ef929d44f4caf2f57d89bb7f79d295 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Wed, 31 Jan 2018 20:59:19 +0000 Subject: [PATCH 06/10] Add _decl.Enum. --- debugger_protocol/arg/_datatype.py | 13 +++++- debugger_protocol/arg/_decl.py | 53 ++++++++++++++++++++++- tests/debugger_protocol/arg/test__decl.py | 39 ++++++++++++++++- 3 files changed, 99 insertions(+), 6 deletions(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index ee5b76ff..575eb8d9 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -1,8 +1,8 @@ from debugger_protocol._base import Readonly, WithRepr -from ._common import NOT_SET, ANY +from ._common import NOT_SET, ANY, SIMPLE_TYPES from ._decl import ( _transform_datatype, _replace_ref, - Union, Array, Field, Fields) + Enum, Union, Array, Field, Fields) from ._errors import ArgTypeMismatchError, ArgMissingError, IncompleteArgError @@ -13,8 +13,17 @@ def _coerce(datatype, value, call=True): return value elif value is datatype: return value + elif datatype is None: + pass # fail below + elif datatype in SIMPLE_TYPES: + # We already checked for exact type match above. + pass # fail below # decl types + elif isinstance(datatype, Enum): + value = _coerce(datatype.datatype, value, call=False) + if value in datatype.choices: + return value elif isinstance(datatype, Union): for dt in datatype: try: diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 20743c8a..4f0c2ff9 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -17,6 +17,8 @@ def _normalize_datatype(datatype): return ANY elif datatype in list(SIMPLE_TYPES): return datatype + elif isinstance(datatype, Enum): + return datatype elif isinstance(datatype, Union): return datatype elif isinstance(datatype, Array): @@ -49,6 +51,49 @@ def _replace_ref(datatype, target): return datatype +class Enum(namedtuple('Enum', 'datatype choices')): + """A simple type with a limited set of allowed values.""" + + @classmethod + def _check_choices(cls, datatype, choices, strict=True): + if callable(choices): + return choices + + if isinstance(choices, str): + msg = 'bad choices (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choices)) + + choices = frozenset(choices) + if not choices: + raise TypeError('missing choices') + if not strict: + return choices + + for value in choices: + if type(value) is not datatype: + msg = 'bad choices (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choices)) + return choices + + def __new__(cls, datatype, choices, **kwargs): + strict = kwargs.pop('strict', True) + normalize = kwargs.pop('_normalize', True) + (lambda: None)(**kwargs) # Make sure there aren't any other kwargs. + + if not isinstance(datatype, type): + raise ValueError('expected a class, got {!r}'.format(datatype)) + if datatype not in list(SIMPLE_TYPES): + msg = 'only simple datatypes are supported, got {!r}' + raise ValueError(msg.format(datatype)) + if normalize: + # There's no need to normalize datatype (it's a simple type). + pass + choices = cls._check_choices(datatype, choices, strict=strict) + + self = super(Enum, cls).__new__(cls, datatype, choices) + return self + + class Union(frozenset): """Declare a union of different types. @@ -137,8 +182,12 @@ class Field(namedtuple('Field', 'name datatype default optional')): START_OPTIONAL = sentinel('START_OPTIONAL') - def __new__(cls, name, datatype=str, default=NOT_SET, optional=False, - _normalize=True, **kwargs): + def __new__(cls, name, datatype=str, enum=None, default=NOT_SET, + optional=False, _normalize=True, **kwargs): + if enum is not None and not isinstance(enum, Enum): + datatype = Enum(datatype, enum) + enum = None + if _normalize: datatype = _normalize_datatype(datatype) self = super(Field, cls).__new__( diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 59ee2744..ec296402 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -4,8 +4,8 @@ from debugger_protocol.arg import NOT_SET, ANY from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import ( REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, - Union, Array, Field, Fields) -from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg + Enum, Union, Array, Field, Fields) +from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from debugger_protocol.arg._params import ( SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) @@ -23,6 +23,7 @@ class ModuleTests(unittest.TestCase): (int, NOOP), (str, NOOP), (bool, NOOP), + (Enum(str, ('spam',)), NOOP), (Union(str, int), NOOP), ({str, int}, Union(str, int)), (frozenset([str, int]), Union(str, int)), @@ -169,6 +170,35 @@ class ModuleTests(unittest.TestCase): ]) +class EnumTests(unittest.TestCase): + + def test_attrs(self): + enum = Enum(str, ('spam', 'eggs')) + datatype, choices = enum + + self.assertIs(datatype, str) + self.assertEqual(choices, frozenset(['spam', 'eggs'])) + + def test_bad_datatype(self): + with self.assertRaises(ValueError): + Enum('spam', ('spam', 'eggs')) + with self.assertRaises(ValueError): + Enum(dict, ('spam', 'eggs')) + + def test_bad_choices(self): + class String(str): + pass + + with self.assertRaises(ValueError): + Enum(str, 'spam') + with self.assertRaises(TypeError): + Enum(str, ()) + with self.assertRaises(ValueError): + Enum(str, ('spam', 10)) + with self.assertRaises(ValueError): + Enum(str, ('spam', String)) + + class UnionTests(unittest.TestCase): def test_normalized(self): @@ -271,6 +301,11 @@ class FieldTests(unittest.TestCase): self.assertIs(field.default, NOT_SET) self.assertFalse(field.optional) + def test_enum(self): + field = Field('spam', str, enum=('a', 'b', 'c')) + + self.assertEqual(field.datatype, Enum(str, ('a', 'b', 'c'))) + def test_normalized(self): tests = [ (REF, TYPE_REFERENCE), From ecd482345cbb2f7b7c645a45cf68a9a809461b20 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Wed, 31 Jan 2018 22:57:07 +0000 Subject: [PATCH 07/10] Preserve order in Union. --- debugger_protocol/arg/_decl.py | 76 ++++++++++++++++------- tests/debugger_protocol/arg/test__decl.py | 12 ++-- 2 files changed, 60 insertions(+), 28 deletions(-) diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 4f0c2ff9..3b62e91c 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -9,6 +9,17 @@ REF = '' TYPE_REFERENCE = sentinel('TYPE_REFERENCE') +def _is_simple(datatype): + if datatype is ANY: + return True + elif datatype in list(SIMPLE_TYPES): + return True + elif isinstance(datatype, Enum): + return True + else: + return False + + def _normalize_datatype(datatype): cls = type(datatype) if datatype == REF or datatype is TYPE_REFERENCE: @@ -51,31 +62,31 @@ def _replace_ref(datatype, target): return datatype -class Enum(namedtuple('Enum', 'datatype choices')): +class Enum(namedtuple('Enum', 'datatype choice')): """A simple type with a limited set of allowed values.""" @classmethod - def _check_choices(cls, datatype, choices, strict=True): - if callable(choices): - return choices + def _check_choice(cls, datatype, choice, strict=True): + if callable(choice): + return choice - if isinstance(choices, str): - msg = 'bad choices (expected {!r} values, got {!r})' - raise ValueError(msg.format(datatype, choices)) + if isinstance(choice, str): + msg = 'bad choice (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choice)) - choices = frozenset(choices) - if not choices: - raise TypeError('missing choices') + choice = frozenset(choice) + if not choice: + raise TypeError('missing choice') if not strict: - return choices + return choice - for value in choices: + for value in choice: if type(value) is not datatype: - msg = 'bad choices (expected {!r} values, got {!r})' - raise ValueError(msg.format(datatype, choices)) - return choices + msg = 'bad choice (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choice)) + return choice - def __new__(cls, datatype, choices, **kwargs): + def __new__(cls, datatype, choice, **kwargs): strict = kwargs.pop('strict', True) normalize = kwargs.pop('_normalize', True) (lambda: None)(**kwargs) # Make sure there aren't any other kwargs. @@ -88,18 +99,19 @@ class Enum(namedtuple('Enum', 'datatype choices')): if normalize: # There's no need to normalize datatype (it's a simple type). pass - choices = cls._check_choices(datatype, choices, strict=strict) + choice = cls._check_choice(datatype, choice, strict=strict) - self = super(Enum, cls).__new__(cls, datatype, choices) + self = super(Enum, cls).__new__(cls, datatype, choice) return self -class Union(frozenset): +class Union(tuple): """Declare a union of different types. - Sets and frozensets are treated equivalently in declarations. + The declared order is preserved and respected. + + Sets and frozensets are treated Unions in declarations. """ - __slots__ = () @classmethod def _traverse(cls, datatypes, op): @@ -122,11 +134,31 @@ class Union(frozenset): datatypes, lambda dt: _transform_datatype(dt, _normalize_datatype), ) - return super(Union, cls).__new__(cls, datatypes) + self = super(Union, cls).__new__(cls, datatypes) + self._simple = all(_is_simple(dt) for dt in datatypes) + return self def __repr__(self): return '{}{}'.format(type(self).__name__, tuple(self)) + def __hash__(self): + return super(Union, self).__hash__() + + def __eq__(self, other): # honors order + if not isinstance(other, Union): + return NotImplemented + if super(Union, self).__eq__(other): + return True + if set(self) != set(other): + return False + if self._simple and other._simple: + return True + return NotImplemented + + + def __ne__(self, other): + return not (self == other) + @property def datatypes(self): return set(self) diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index ec296402..c3dfd2ee 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -37,8 +37,8 @@ class ModuleTests(unittest.TestCase): (ParameterImplBase(str), NOOP), (Arg(object(), object()), NOOP), (SimpleParameter(str), NOOP), - (UnionParameter(str), NOOP), - (ArrayParameter(str), NOOP), + (UnionParameter(Union(str)), NOOP), + (ArrayParameter(Array(str)), NOOP), (ComplexParameter(Fields()), NOOP), (NOT_SET, NOOP), (object(), NOOP), @@ -73,8 +73,8 @@ class ModuleTests(unittest.TestCase): ParameterImplBase(str), Arg(object(), object()), SimpleParameter(str), - UnionParameter(str, int), - ArrayParameter(str), + UnionParameter(Union(str, int)), + ArrayParameter(Array(str)), ComplexParameter(Fields()), NOT_SET, object(), @@ -204,8 +204,8 @@ class UnionTests(unittest.TestCase): def test_normalized(self): tests = [ (REF, TYPE_REFERENCE), - ({str, int}, Union(str, int)), - (frozenset([str, int]), Union(str, int)), + ({str, int}, Union(*{str, int})), + (frozenset([str, int]), Union(*frozenset([str, int]))), ([str], Array(str)), ((str,), Array(str)), (None, None), From 555b8d1accbff50008f0c271fb0234d97d376a85 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 30 Jan 2018 19:10:50 +0000 Subject: [PATCH 08/10] Add Parameter, ParameterImpl, and Arg. --- debugger_protocol/arg/_param.py | 280 ++++++++++++++++ tests/debugger_protocol/arg/test__param.py | 355 +++++++++++++++++++++ 2 files changed, 635 insertions(+) create mode 100644 debugger_protocol/arg/_param.py create mode 100644 tests/debugger_protocol/arg/test__param.py diff --git a/debugger_protocol/arg/_param.py b/debugger_protocol/arg/_param.py new file mode 100644 index 00000000..d3c123da --- /dev/null +++ b/debugger_protocol/arg/_param.py @@ -0,0 +1,280 @@ +from debugger_protocol._base import Readonly, WithRepr +from ._common import NOT_SET + + +#def arg_missing(param, raw): +# """Return True if the value is "missing" relative to the parameter. +# +# The result is based on the result of calling param.missing(), with +# the exception of if the value is NOT_SET (which always means +# "missing"). +# """ +# if raw is NOT_SET: +# return True +# param = param.match_type(raw) +# missing = param.missing(raw) +# if missing: +# return missing +# elif missing is None: +# return True +# else: +# return False + + +class Parameter(object): + """Effectively a serializer for a "class" of values. + + The parameter is backed by one or more data classes to which + raw values are de-serialized (and which serialize to the + corresponding raw values). + """ + + def __init__(self, impl): + if not isinstance(impl, ParameterImplBase): + raise TypeError('bad impl') + self._impl = impl + + def __repr__(self): + return '<{} wrapping {!r}>'.format(type(self).__name__, self._impl) + + def __hash__(self): + return hash(self._impl) + + def __eq__(self, other): + if type(self) is not type(other): + return False + return self._impl == other._impl + + def __ne__(self, other): + return not (self == other) + + def bind(self, raw): + """Return an Arg for the given raw value. + + As with match_type(), if the value is not supported by this + parameter return None. + """ + param = self.match_type(raw) + return Arg(param, raw) + + def match_type(self, raw): + """Return the parameter to use for the given raw value. + + If the value does not match then return None. + + Normally this method returns self or None. For some parameters + the method may return other parameters to use. In fact, for + some (e.g. unions) it only returns other parameters (never + returns self). + """ + param = self._impl.match_type(raw) + if param is None: + return None + elif param is self._impl: + return self + elif isinstance(param, Parameter): + return param + else: + return self.__class__(param) + + def missing(self, raw): + """Return True if the raw value should be treated as NOT_SET. + + A True result corresponds to raising ArgMissingError. A result + of None means defer to other parameters, much as NotImplemented + works. If every parameter returns None then the value should + be treated as missing. + + In addition to True/False, for "complex" values missing() may + also return a mapping of names to the portions of the raw value + that are missing. In that case the result corresponds instead + to raising IncompleteArgError. + """ + return self._impl.missing(raw) + + def coerce(self, raw): + # XXX + """Return the deserialized equivalent of the given raw value. + + XXX + + If the parameter's underlying data class + """ + return self._impl.coerce(raw) + + def validate(self, coerced): + """Ensure that the already-deserialized value is correct. + + If the value has a "validate()" method then it gets called. + Otherwise it's up to the parameter. + """ + try: + validate = coerced.validate + except AttributeError: + self._impl.validate(coerced) + else: + validate() + + def as_data(self, coerced): + """Return a serialized equivalent of the given value. + + This method round-trips with the "coerce()" method. + """ + try: + as_data = coerced.as_data + except AttributeError: + return self._impl.as_data(coerced) + else: + return as_data(coerced) + + +class ParameterImplBase(Readonly): + """The base class for low-level Parameter implementations. + + The default methods are essentially noops. + + See corresponding Parameter methods. + """ + + def __init__(self, datatype=NOT_SET): + self._bind_attrs(datatype=datatype) + + def __repr__(self): + if self.datatype is NOT_SET: + return '{}()'.format(type(self).__name__) + else: + return '{}({!r})'.format(type(self).__name__, self.datatype) + + def __hash__(self): + try: + return hash(self.datatype) + except TypeError: + return hash(id(self)) + + def __eq__(self, other): + if type(self) is not type(other): + return False + return self.datatype == other.datatype + + def __ne__(self, other): + return not (self == other) + + def match_type(self, raw): + return self + + def missing(self, raw): + return False + + def coerce(self, raw): + return raw + + def validate(self, coerced): + return + + def as_data(self, coerced): + return coerced + + +class Arg(Readonly, WithRepr): + """The bridge between a raw value and a deserialized one. + + This is primarily the product of Parameter.bind(). + """ + # The value of this type lies in encapsulating intermediate state + # and caching data. + + def __init__(self, param, value, israw=True): + if isinstance(param, ParameterImplBase): + param = Parameter(param) + elif not isinstance(param, Parameter): + raise TypeError( + 'bad param (expected Parameter, got {!r})'.format(param)) + key = '_raw' if israw else '_value' + kwargs = {key: value} + self._bind_attrs( + param=param, + _validated=False, + **kwargs + ) + + def _init_args(self): + yield ('param', self.param) + try: + yield ('value', self._raw) + except AttributeError: + yield ('value', self._value) + yield ('israw', False) + + def __hash__(self): + try: + return hash(self.param.datatype) + except TypeError: + return hash(id(self)) + + def __eq__(self, other): + if type(self) is not type(other): + return False + return self.param == other.param + + def __ne__(self, other): + return not (self == other) + + @property + def raw(self): + """The serialized value.""" + return self.as_data() + + @property + def value(self): + """The de-serialized value.""" + value = self.coerce() + if not self._validated: + self._validate() + return value + + def coerce(self, cached=True): + """Return the deserialized equivalent of the raw value.""" + if not cached: + try: + raw = self._raw + except AttributeError: + # Use the cached value anyway. + return self._value + else: + return self.param.coerce(raw) + + try: + return self._value + except AttributeError: + value = self.param.coerce(self._raw) + self._bind_attrs( + _value=value, + ) + return value + + def validate(self, force=False): + """Ensure that the (deserialized) value is correct.""" + if not self._validated or force: + self.coerce() + self._validate() + + def _validate(self): + self.param.validate(self._value) + self._bind_attrs( + _validated=True, + ) + + def as_data(self, cached=True): + """Return a serialized equivalent of the value.""" + self.validate() + if not cached: + return self.param.as_data(self._value) + + try: + return self._raw + except AttributeError: + raw = self.param.as_data(self._value) + self._bind_attrs( + _raw=raw, + ) + return raw diff --git a/tests/debugger_protocol/arg/test__param.py b/tests/debugger_protocol/arg/test__param.py new file mode 100644 index 00000000..3d39ac67 --- /dev/null +++ b/tests/debugger_protocol/arg/test__param.py @@ -0,0 +1,355 @@ +from types import SimpleNamespace +import unittest + +from debugger_protocol.arg import NOT_SET +from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg + +from tests.helpers.stub import Stub + + +class FakeImpl(ParameterImplBase): + + def __init__(self, stub=None): + super().__init__() + self._bind_attrs( + stub=stub or Stub(), + returns=SimpleNamespace( + match_type=None, + missing=None, + coerce=None, + as_data=None, + ), + ) + + def match_type(self, raw): + self.stub.add_call('match_type', raw) + self.stub.maybe_raise() + return self.returns.match_type + + def missing(self, raw): + self.stub.add_call('missing', raw) + self.stub.maybe_raise() + return self.returns.missing + + def coerce(self, raw): + self.stub.add_call('coerce', raw) + self.stub.maybe_raise() + return self.returns.coerce + + def validate(self, coerced): + self.stub.add_call('validate', coerced) + self.stub.maybe_raise() + + def as_data(self, coerced): + self.stub.add_call('as_data', coerced) + self.stub.maybe_raise() + return self.returns.as_data + + +class ParameterTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.stub = Stub() + self.impl = FakeImpl(self.stub) + + def test_bad_impl(self): + with self.assertRaises(TypeError): + Parameter(None) + with self.assertRaises(TypeError): + Parameter(str) + + def test_bind_matched(self): + self.impl.returns.match_type = self.impl + param = Parameter(self.impl) + arg = param.bind('spam') + + self.assertEqual(arg, Arg(param, 'spam')) + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_bind_no_match(self): + self.impl.returns.match_type = None + param = Parameter(self.impl) + + with self.assertRaises(TypeError): + param.bind('spam') + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_match_type_no_match(self): + self.impl.returns.match_type = None + param = Parameter(self.impl) + matched = param.match_type('spam') + + self.assertIs(matched, None) + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_match_type_param(self): + other = Parameter(ParameterImplBase(str)) + self.impl.returns.match_type = other + param = Parameter(self.impl) + matched = param.match_type('spam') + + self.assertIs(matched, other) + self.assertNotEqual(matched, param) + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_match_type_impl_noop(self): + self.impl.returns.match_type = self.impl + param = Parameter(self.impl) + matched = param.match_type('spam') + + self.assertIs(matched, param) + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_match_type_impl_wrap(self): + other = ParameterImplBase(str) + self.impl.returns.match_type = other + param = Parameter(self.impl) + matched = param.match_type('spam') + + self.assertNotEqual(matched, param) + self.assertIs(matched._impl, other) + self.assertEqual(self.stub.calls, [ + ('match_type', ('spam',), {}), + ]) + + def test_missing(self): + self.impl.returns.missing = False + param = Parameter(self.impl) + missing = param.missing('spam') + + self.assertFalse(missing) + self.assertEqual(self.stub.calls, [ + ('missing', ('spam',), {}), + ]) + + def test_coerce(self): + self.impl.returns.coerce = 'spam' + param = Parameter(self.impl) + coerced = param.coerce('spam') + + self.assertEqual(coerced, 'spam') + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ]) + + def test_validate_use_impl(self): + param = Parameter(self.impl) + param.validate('spam') + + self.assertEqual(self.stub.calls, [ + ('validate', ('spam',), {}), + ]) + + def test_validate_use_coerced(self): + other = FakeImpl() + arg = Arg(Parameter(other), 'spam', israw=False) + param = Parameter(self.impl) + param.validate(arg) + + self.assertEqual(self.stub.calls, []) + self.assertEqual(other.stub.calls, [ + ('validate', ('spam',), {}), + ]) + + def test_as_data_use_impl(self): + self.impl.returns.as_data = 'spam' + param = Parameter(self.impl) + data = param.as_data('spam') + + self.assertEqual(data, 'spam') + self.assertEqual(self.stub.calls, [ + ('as_data', ('spam',), {}), + ]) + + def test_as_data_use_coerced(self): + other = FakeImpl() + arg = Arg(Parameter(other), 'spam', israw=False) + other.returns.as_data = 'spam' + param = Parameter(self.impl) + data = param.as_data(arg) + + self.assertEqual(data, 'spam') + self.assertEqual(self.stub.calls, []) + self.assertEqual(other.stub.calls, [ + ('validate', ('spam',), {}), + ('as_data', ('spam',), {}), + ]) + + +class ParameterImplBaseTests(unittest.TestCase): + + def test_defaults(self): + impl = ParameterImplBase() + + self.assertIs(impl.datatype, NOT_SET) + + def test_match_type(self): + impl = ParameterImplBase() + param = impl.match_type('spam') + + self.assertIs(param, impl) + + def test_missing(self): + impl = ParameterImplBase() + missing = impl.missing('spam') + + self.assertFalse(missing) + + def test_coerce(self): + values = [ + (str, 'spam'), + (int, 10), + (str, 10), + (int, '10'), + ] + for datatype, value in values: + with self.subTest(value): + impl = ParameterImplBase(datatype) + coerced = impl.coerce(value) + + self.assertEqual(coerced, value) + + def test_validate(self): + impl = ParameterImplBase(str) + impl.validate('spam') + + def test_as_data(self): + impl = ParameterImplBase(str) + data = impl.as_data('spam') + + self.assertEqual(data, 'spam') + + +class ArgTests(unittest.TestCase): + + def setUp(self): + super().setUp() + self.stub = Stub() + self.impl = FakeImpl(self.stub) + self.param = Parameter(self.impl) + + def test_raw_valid(self): + self.impl.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam') + raw = arg.raw + + self.assertEqual(raw, 'spam') + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ('validate', ('eggs',), {}), + ]) + + def test_raw_invalid(self): + self.impl.returns.coerce = 'eggs' + self.stub.set_exceptions( + None, + ValueError('oops'), + ) + arg = Arg(self.param, 'spam') + + with self.assertRaises(ValueError): + arg.raw + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ('validate', ('eggs',), {}), + ]) + + def test_raw_generated(self): + self.impl.returns.as_data = 'spam' + arg = Arg(self.param, 'eggs', israw=False) + raw = arg.raw + + self.assertEqual(raw, 'spam') + self.assertEqual(self.stub.calls, [ + ('validate', ('eggs',), {}), + ('as_data', ('eggs',), {}), + ]) + + def test_value_valid(self): + arg = Arg(self.param, 'eggs', israw=False) + value = arg.value + + self.assertEqual(value, 'eggs') + self.assertEqual(self.stub.calls, [ + ('validate', ('eggs',), {}), + ]) + + def test_value_invalid(self): + self.stub.set_exceptions( + ValueError('oops'), + ) + arg = Arg(self.param, 'eggs', israw=False) + + with self.assertRaises(ValueError): + arg.value + self.assertEqual(self.stub.calls, [ + ('validate', ('eggs',), {}), + ]) + + def test_value_generated(self): + self.impl.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam') + value = arg.value + + self.assertEqual(value, 'eggs') + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ('validate', ('eggs',), {}), + ]) + + def test_coerce(self): + self.impl.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam') + value = arg.coerce() + + self.assertEqual(value, 'eggs') + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ]) + + def test_validate_okay(self): + self.impl.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam') + arg.validate() + + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ('validate', ('eggs',), {}), + ]) + + def test_validate_invalid(self): + self.stub.set_exceptions( + None, + ValueError('oops'), + ) + self.impl.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam') + + with self.assertRaises(ValueError): + arg.validate() + self.assertEqual(self.stub.calls, [ + ('coerce', ('spam',), {}), + ('validate', ('eggs',), {}), + ]) + + def test_as_data(self): + self.impl.returns.as_data = 'spam' + arg = Arg(self.param, 'eggs', israw=False) + data = arg.as_data() + + self.assertEqual(data, 'spam') + self.assertEqual(self.stub.calls, [ + ('validate', ('eggs',), {}), + ('as_data', ('eggs',), {}), + ]) From 0fed761c98562d17faaf275fa28b541984ecba05 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 30 Jan 2018 23:18:41 +0000 Subject: [PATCH 09/10] ParameterImplBase -> DatatypeHandler. --- debugger_protocol/arg/_param.py | 277 ++++++++------------- tests/debugger_protocol/arg/test__decl.py | 11 +- tests/debugger_protocol/arg/test__param.py | 272 ++++++-------------- 3 files changed, 192 insertions(+), 368 deletions(-) diff --git a/debugger_protocol/arg/_param.py b/debugger_protocol/arg/_param.py index d3c123da..33e22199 100644 --- a/debugger_protocol/arg/_param.py +++ b/debugger_protocol/arg/_param.py @@ -1,149 +1,129 @@ from debugger_protocol._base import Readonly, WithRepr -from ._common import NOT_SET -#def arg_missing(param, raw): -# """Return True if the value is "missing" relative to the parameter. -# -# The result is based on the result of calling param.missing(), with -# the exception of if the value is NOT_SET (which always means -# "missing"). -# """ -# if raw is NOT_SET: -# return True -# param = param.match_type(raw) -# missing = param.missing(raw) -# if missing: -# return missing -# elif missing is None: -# return True -# else: -# return False +class _ParameterBase(WithRepr): + def __init__(self, datatype): + self._datatype = datatype -class Parameter(object): - """Effectively a serializer for a "class" of values. - - The parameter is backed by one or more data classes to which - raw values are de-serialized (and which serialize to the - corresponding raw values). - """ - - def __init__(self, impl): - if not isinstance(impl, ParameterImplBase): - raise TypeError('bad impl') - self._impl = impl - - def __repr__(self): - return '<{} wrapping {!r}>'.format(type(self).__name__, self._impl) + def _init_args(self): + yield ('datatype', self._datatype) def __hash__(self): - return hash(self._impl) + try: + return hash(self._datatype) + except TypeError: + return hash(id(self)) def __eq__(self, other): if type(self) is not type(other): return False - return self._impl == other._impl + return self._datatype == other._datatype def __ne__(self, other): return not (self == other) + @property + def datatype(self): + return self._datatype + + +class Parameter(_ParameterBase): + """Base class for different parameter types.""" + + def __init__(self, datatype, handler=None): + super(Parameter, self).__init__(datatype) + self._handler = handler + + def _init_args(self): + for item in super(Parameter, self)._init_args(): + yield item + if self._handler is not None: + yield ('handler', self._handler) + def bind(self, raw): """Return an Arg for the given raw value. As with match_type(), if the value is not supported by this parameter return None. """ - param = self.match_type(raw) - return Arg(param, raw) + handler = self.match_type(raw) + if handler is None: + return None + return Arg(self, raw, handler) def match_type(self, raw): - """Return the parameter to use for the given raw value. + """Return the datatype handler to use for the given raw value. If the value does not match then return None. - - Normally this method returns self or None. For some parameters - the method may return other parameters to use. In fact, for - some (e.g. unions) it only returns other parameters (never - returns self). """ - param = self._impl.match_type(raw) - if param is None: - return None - elif param is self._impl: - return self - elif isinstance(param, Parameter): - return param - else: - return self.__class__(param) + return self._handler - def missing(self, raw): - """Return True if the raw value should be treated as NOT_SET. - A True result corresponds to raising ArgMissingError. A result - of None means defer to other parameters, much as NotImplemented - works. If every parameter returns None then the value should - be treated as missing. - - In addition to True/False, for "complex" values missing() may - also return a mapping of names to the portions of the raw value - that are missing. In that case the result corresponds instead - to raising IncompleteArgError. - """ - return self._impl.missing(raw) +class DatatypeHandler(_ParameterBase): + """Base class for datatype handlers.""" def coerce(self, raw): - # XXX - """Return the deserialized equivalent of the given raw value. - - XXX - - If the parameter's underlying data class - """ - return self._impl.coerce(raw) + """Return the deserialized equivalent of the given raw value.""" + # By default this is a noop. + return raw def validate(self, coerced): - """Ensure that the already-deserialized value is correct. - - If the value has a "validate()" method then it gets called. - Otherwise it's up to the parameter. - """ - try: - validate = coerced.validate - except AttributeError: - self._impl.validate(coerced) - else: - validate() + """Ensure that the already-deserialized value is correct.""" + # By default this is a noop. + return def as_data(self, coerced): """Return a serialized equivalent of the given value. This method round-trips with the "coerce()" method. """ - try: - as_data = coerced.as_data - except AttributeError: - return self._impl.as_data(coerced) - else: - return as_data(coerced) + # By default this is a noop. + return coerced -class ParameterImplBase(Readonly): - """The base class for low-level Parameter implementations. +class Arg(Readonly, WithRepr): + """The bridge between a raw value and a deserialized one. - The default methods are essentially noops. - - See corresponding Parameter methods. + This is primarily the product of Parameter.bind(). """ + # The value of this type lies in encapsulating intermediate state + # and in caching data. - def __init__(self, datatype=NOT_SET): - self._bind_attrs(datatype=datatype) + def __init__(self, param, value, handler=None, israw=True): + if not isinstance(param, Parameter): + raise TypeError( + 'bad param (expected Parameter, got {!r})'.format(param)) + if handler is None: + if israw: + handler = param.match_type(value) + else: + raise TypeError('missing handler') + if not isinstance(handler, DatatypeHandler): + msg = 'bad handler (expected DatatypeHandler, got {!r})' + raise TypeError(msg.format(handler)) - def __repr__(self): - if self.datatype is NOT_SET: - return '{}()'.format(type(self).__name__) - else: - return '{}({!r})'.format(type(self).__name__, self.datatype) + key = '_raw' if israw else '_value' + kwargs = {key: value} + self._bind_attrs( + param=param, + _handler=handler, + _validated=False, + **kwargs + ) + + def _init_args(self): + yield ('param', self.param) + israw = True + try: + yield ('value', self._raw) + except AttributeError: + yield ('value', self._value) + israw = False + if self.datatype != self.param.datatype: + yield ('handler', self._handler) + if not israw: + yield ('israw', False) def __hash__(self): try: @@ -154,71 +134,17 @@ class ParameterImplBase(Readonly): def __eq__(self, other): if type(self) is not type(other): return False - return self.datatype == other.datatype - - def __ne__(self, other): - return not (self == other) - - def match_type(self, raw): - return self - - def missing(self, raw): - return False - - def coerce(self, raw): - return raw - - def validate(self, coerced): - return - - def as_data(self, coerced): - return coerced - - -class Arg(Readonly, WithRepr): - """The bridge between a raw value and a deserialized one. - - This is primarily the product of Parameter.bind(). - """ - # The value of this type lies in encapsulating intermediate state - # and caching data. - - def __init__(self, param, value, israw=True): - if isinstance(param, ParameterImplBase): - param = Parameter(param) - elif not isinstance(param, Parameter): - raise TypeError( - 'bad param (expected Parameter, got {!r})'.format(param)) - key = '_raw' if israw else '_value' - kwargs = {key: value} - self._bind_attrs( - param=param, - _validated=False, - **kwargs - ) - - def _init_args(self): - yield ('param', self.param) - try: - yield ('value', self._raw) - except AttributeError: - yield ('value', self._value) - yield ('israw', False) - - def __hash__(self): - try: - return hash(self.param.datatype) - except TypeError: - return hash(id(self)) - - def __eq__(self, other): - if type(self) is not type(other): + if self.param != other.param: return False - return self.param == other.param + return self._as_data() == other._as_data() def __ne__(self, other): return not (self == other) + @property + def datatype(self): + return self._handler.datatype + @property def raw(self): """The serialized value.""" @@ -241,25 +167,34 @@ class Arg(Readonly, WithRepr): # Use the cached value anyway. return self._value else: - return self.param.coerce(raw) + return self._handler.coerce(raw) try: return self._value except AttributeError: - value = self.param.coerce(self._raw) + value = self._handler.coerce(self._raw) self._bind_attrs( _value=value, ) return value def validate(self, force=False): - """Ensure that the (deserialized) value is correct.""" + """Ensure that the (deserialized) value is correct. + + If the value has a "validate()" method then it gets called. + Otherwise it's up to the handler. + """ if not self._validated or force: self.coerce() self._validate() def _validate(self): - self.param.validate(self._value) + try: + validate = self._value.validate + except AttributeError: + self._handler.validate(self._value) + else: + validate() self._bind_attrs( _validated=True, ) @@ -268,12 +203,18 @@ class Arg(Readonly, WithRepr): """Return a serialized equivalent of the value.""" self.validate() if not cached: - return self.param.as_data(self._value) + return self._handler.as_data(self._value) + return self._as_data() + def _as_data(self): try: return self._raw except AttributeError: - raw = self.param.as_data(self._value) + try: + as_data = self._value.as_data + except AttributeError: + as_data = self._handler.as_data + raw = as_data(self._value) self._bind_attrs( _raw=raw, ) diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index c3dfd2ee..079dc77d 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -14,6 +14,7 @@ class ModuleTests(unittest.TestCase): def test_normalize_datatype(self): NOOP = object() + param = SimpleParameter(str) tests = [ # explicitly handled (REF, TYPE_REFERENCE), @@ -33,9 +34,9 @@ class ModuleTests(unittest.TestCase): # others (Field('spam'), NOOP), (Fields(Field('spam')), NOOP), - (Parameter(object()), NOOP), - (ParameterImplBase(str), NOOP), - (Arg(object(), object()), NOOP), + (param, NOOP), + (DatatypeHandler(str), NOOP), + (Arg(param, 'spam'), NOOP), (SimpleParameter(str), NOOP), (UnionParameter(Union(str)), NOOP), (ArrayParameter(Array(str)), NOOP), @@ -70,8 +71,8 @@ class ModuleTests(unittest.TestCase): [str], (str,), Parameter(object()), - ParameterImplBase(str), - Arg(object(), object()), + DatatypeHandler(str), + Arg(SimpleParameter(str), 'spam'), SimpleParameter(str), UnionParameter(Union(str, int)), ArrayParameter(Array(str)), diff --git a/tests/debugger_protocol/arg/test__param.py b/tests/debugger_protocol/arg/test__param.py index 3d39ac67..47236a68 100644 --- a/tests/debugger_protocol/arg/test__param.py +++ b/tests/debugger_protocol/arg/test__param.py @@ -1,36 +1,21 @@ from types import SimpleNamespace import unittest -from debugger_protocol.arg import NOT_SET -from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg +from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from tests.helpers.stub import Stub -class FakeImpl(ParameterImplBase): +class FakeHandler(DatatypeHandler): - def __init__(self, stub=None): - super().__init__() - self._bind_attrs( - stub=stub or Stub(), - returns=SimpleNamespace( - match_type=None, - missing=None, - coerce=None, - as_data=None, - ), + def __init__(self, datatype=str, stub=None): + super().__init__(datatype) + self.stub = stub or Stub() + self.returns = SimpleNamespace( + coerce=None, + as_data=None, ) - def match_type(self, raw): - self.stub.add_call('match_type', raw) - self.stub.maybe_raise() - return self.returns.match_type - - def missing(self, raw): - self.stub.add_call('missing', raw) - self.stub.maybe_raise() - return self.returns.missing - def coerce(self, raw): self.stub.add_call('coerce', raw) self.stub.maybe_raise() @@ -51,182 +36,52 @@ class ParameterTests(unittest.TestCase): def setUp(self): super().setUp() self.stub = Stub() - self.impl = FakeImpl(self.stub) - - def test_bad_impl(self): - with self.assertRaises(TypeError): - Parameter(None) - with self.assertRaises(TypeError): - Parameter(str) + self.handler = FakeHandler(self.stub) def test_bind_matched(self): - self.impl.returns.match_type = self.impl - param = Parameter(self.impl) + param = Parameter(str, self.handler) arg = param.bind('spam') - self.assertEqual(arg, Arg(param, 'spam')) - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) + self.assertEqual(arg, Arg(param, 'spam', self.handler)) + self.assertEqual(self.stub.calls, []) def test_bind_no_match(self): - self.impl.returns.match_type = None - param = Parameter(self.impl) + param = Parameter(str) - with self.assertRaises(TypeError): - param.bind('spam') - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) + arg = param.bind('spam') + self.assertIs(arg, None) + self.assertEqual(self.stub.calls, []) def test_match_type_no_match(self): - self.impl.returns.match_type = None - param = Parameter(self.impl) + param = Parameter(str) matched = param.match_type('spam') self.assertIs(matched, None) - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) + self.assertEqual(self.stub.calls, []) - def test_match_type_param(self): - other = Parameter(ParameterImplBase(str)) - self.impl.returns.match_type = other - param = Parameter(self.impl) + def test_match_type_matched(self): + param = Parameter(str, self.handler) matched = param.match_type('spam') - self.assertIs(matched, other) - self.assertNotEqual(matched, param) - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) + self.assertIs(matched, self.handler) + self.assertEqual(self.stub.calls, []) - def test_match_type_impl_noop(self): - self.impl.returns.match_type = self.impl - param = Parameter(self.impl) - matched = param.match_type('spam') - self.assertIs(matched, param) - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) - - def test_match_type_impl_wrap(self): - other = ParameterImplBase(str) - self.impl.returns.match_type = other - param = Parameter(self.impl) - matched = param.match_type('spam') - - self.assertNotEqual(matched, param) - self.assertIs(matched._impl, other) - self.assertEqual(self.stub.calls, [ - ('match_type', ('spam',), {}), - ]) - - def test_missing(self): - self.impl.returns.missing = False - param = Parameter(self.impl) - missing = param.missing('spam') - - self.assertFalse(missing) - self.assertEqual(self.stub.calls, [ - ('missing', ('spam',), {}), - ]) +class DatatypeHandlerTests(unittest.TestCase): def test_coerce(self): - self.impl.returns.coerce = 'spam' - param = Parameter(self.impl) - coerced = param.coerce('spam') + handler = DatatypeHandler(str) + coerced = handler.coerce('spam') self.assertEqual(coerced, 'spam') - self.assertEqual(self.stub.calls, [ - ('coerce', ('spam',), {}), - ]) - - def test_validate_use_impl(self): - param = Parameter(self.impl) - param.validate('spam') - - self.assertEqual(self.stub.calls, [ - ('validate', ('spam',), {}), - ]) - - def test_validate_use_coerced(self): - other = FakeImpl() - arg = Arg(Parameter(other), 'spam', israw=False) - param = Parameter(self.impl) - param.validate(arg) - - self.assertEqual(self.stub.calls, []) - self.assertEqual(other.stub.calls, [ - ('validate', ('spam',), {}), - ]) - - def test_as_data_use_impl(self): - self.impl.returns.as_data = 'spam' - param = Parameter(self.impl) - data = param.as_data('spam') - - self.assertEqual(data, 'spam') - self.assertEqual(self.stub.calls, [ - ('as_data', ('spam',), {}), - ]) - - def test_as_data_use_coerced(self): - other = FakeImpl() - arg = Arg(Parameter(other), 'spam', israw=False) - other.returns.as_data = 'spam' - param = Parameter(self.impl) - data = param.as_data(arg) - - self.assertEqual(data, 'spam') - self.assertEqual(self.stub.calls, []) - self.assertEqual(other.stub.calls, [ - ('validate', ('spam',), {}), - ('as_data', ('spam',), {}), - ]) - - -class ParameterImplBaseTests(unittest.TestCase): - - def test_defaults(self): - impl = ParameterImplBase() - - self.assertIs(impl.datatype, NOT_SET) - - def test_match_type(self): - impl = ParameterImplBase() - param = impl.match_type('spam') - - self.assertIs(param, impl) - - def test_missing(self): - impl = ParameterImplBase() - missing = impl.missing('spam') - - self.assertFalse(missing) - - def test_coerce(self): - values = [ - (str, 'spam'), - (int, 10), - (str, 10), - (int, '10'), - ] - for datatype, value in values: - with self.subTest(value): - impl = ParameterImplBase(datatype) - coerced = impl.coerce(value) - - self.assertEqual(coerced, value) def test_validate(self): - impl = ParameterImplBase(str) - impl.validate('spam') + handler = DatatypeHandler(str) + handler.validate('spam') def test_as_data(self): - impl = ParameterImplBase(str) - data = impl.as_data('spam') + handler = DatatypeHandler(str) + data = handler.as_data('spam') self.assertEqual(data, 'spam') @@ -236,12 +91,12 @@ class ArgTests(unittest.TestCase): def setUp(self): super().setUp() self.stub = Stub() - self.impl = FakeImpl(self.stub) - self.param = Parameter(self.impl) + self.handler = FakeHandler(str, self.stub) + self.param = Parameter(str, self.handler) def test_raw_valid(self): - self.impl.returns.coerce = 'eggs' - arg = Arg(self.param, 'spam') + self.handler.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam', self.handler) raw = arg.raw self.assertEqual(raw, 'spam') @@ -251,12 +106,12 @@ class ArgTests(unittest.TestCase): ]) def test_raw_invalid(self): - self.impl.returns.coerce = 'eggs' + self.handler.returns.coerce = 'eggs' self.stub.set_exceptions( None, ValueError('oops'), ) - arg = Arg(self.param, 'spam') + arg = Arg(self.param, 'spam', self.handler) with self.assertRaises(ValueError): arg.raw @@ -266,8 +121,8 @@ class ArgTests(unittest.TestCase): ]) def test_raw_generated(self): - self.impl.returns.as_data = 'spam' - arg = Arg(self.param, 'eggs', israw=False) + self.handler.returns.as_data = 'spam' + arg = Arg(self.param, 'eggs', self.handler, israw=False) raw = arg.raw self.assertEqual(raw, 'spam') @@ -277,7 +132,7 @@ class ArgTests(unittest.TestCase): ]) def test_value_valid(self): - arg = Arg(self.param, 'eggs', israw=False) + arg = Arg(self.param, 'eggs', self.handler, israw=False) value = arg.value self.assertEqual(value, 'eggs') @@ -289,7 +144,7 @@ class ArgTests(unittest.TestCase): self.stub.set_exceptions( ValueError('oops'), ) - arg = Arg(self.param, 'eggs', israw=False) + arg = Arg(self.param, 'eggs', self.handler, israw=False) with self.assertRaises(ValueError): arg.value @@ -298,8 +153,8 @@ class ArgTests(unittest.TestCase): ]) def test_value_generated(self): - self.impl.returns.coerce = 'eggs' - arg = Arg(self.param, 'spam') + self.handler.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam', self.handler) value = arg.value self.assertEqual(value, 'eggs') @@ -309,8 +164,8 @@ class ArgTests(unittest.TestCase): ]) def test_coerce(self): - self.impl.returns.coerce = 'eggs' - arg = Arg(self.param, 'spam') + self.handler.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam', self.handler) value = arg.coerce() self.assertEqual(value, 'eggs') @@ -319,8 +174,8 @@ class ArgTests(unittest.TestCase): ]) def test_validate_okay(self): - self.impl.returns.coerce = 'eggs' - arg = Arg(self.param, 'spam') + self.handler.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam', self.handler) arg.validate() self.assertEqual(self.stub.calls, [ @@ -333,8 +188,8 @@ class ArgTests(unittest.TestCase): None, ValueError('oops'), ) - self.impl.returns.coerce = 'eggs' - arg = Arg(self.param, 'spam') + self.handler.returns.coerce = 'eggs' + arg = Arg(self.param, 'spam', self.handler) with self.assertRaises(ValueError): arg.validate() @@ -343,9 +198,21 @@ class ArgTests(unittest.TestCase): ('validate', ('eggs',), {}), ]) - def test_as_data(self): - self.impl.returns.as_data = 'spam' - arg = Arg(self.param, 'eggs', israw=False) + def test_validate_use_coerced(self): + handler = FakeHandler() + other = Arg(Parameter(str, handler), 'spam', handler, israw=False) + arg = Arg(Parameter(str, self.handler), other, self.handler, + israw=False) + arg.validate() + + self.assertEqual(self.stub.calls, []) + self.assertEqual(handler.stub.calls, [ + ('validate', ('spam',), {}), + ]) + + def test_as_data_use_handler(self): + self.handler.returns.as_data = 'spam' + arg = Arg(self.param, 'eggs', self.handler, israw=False) data = arg.as_data() self.assertEqual(data, 'spam') @@ -353,3 +220,18 @@ class ArgTests(unittest.TestCase): ('validate', ('eggs',), {}), ('as_data', ('eggs',), {}), ]) + + def test_as_data_use_coerced(self): + handler = FakeHandler() + other = Arg(Parameter(str, handler), 'spam', handler, israw=False) + handler.returns.as_data = 'spam' + arg = Arg(Parameter(str, self.handler), other, self.handler, + israw=False) + data = arg.as_data(other) + + self.assertEqual(data, 'spam') + self.assertEqual(self.stub.calls, []) + self.assertEqual(handler.stub.calls, [ + ('validate', ('spam',), {}), + ('as_data', ('spam',), {}), + ]) From 57e388bb5fb8f28effd8570f2a81995d49b4d53c Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Thu, 1 Feb 2018 16:43:05 +0000 Subject: [PATCH 10/10] Add parameter types. --- debugger_protocol/arg/__init__.py | 1 + debugger_protocol/arg/_decl.py | 1 - debugger_protocol/arg/_param.py | 2 +- debugger_protocol/arg/_params.py | 380 +++++++++++ tests/debugger_protocol/arg/test__params.py | 698 ++++++++++++++++++++ 5 files changed, 1080 insertions(+), 2 deletions(-) create mode 100644 debugger_protocol/arg/_params.py create mode 100644 tests/debugger_protocol/arg/test__params.py diff --git a/debugger_protocol/arg/__init__.py b/debugger_protocol/arg/__init__.py index f99af01c..f294ec8a 100644 --- a/debugger_protocol/arg/__init__.py +++ b/debugger_protocol/arg/__init__.py @@ -5,3 +5,4 @@ from ._errors import ( # noqa ArgumentError, ArgMissingError, IncompleteArgError, ArgTypeMismatchError, ) +from ._params import param_from_datatype # noqa diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 3b62e91c..f7fea26f 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -301,7 +301,6 @@ class Fields(Readonly, Sequence): def __getitem__(self, index): return self._fields[index] - @property def as_dict(self): return {field.name: field for field in self._fields} diff --git a/debugger_protocol/arg/_param.py b/debugger_protocol/arg/_param.py index 33e22199..bc87c124 100644 --- a/debugger_protocol/arg/_param.py +++ b/debugger_protocol/arg/_param.py @@ -17,7 +17,7 @@ class _ParameterBase(WithRepr): def __eq__(self, other): if type(self) is not type(other): - return False + return NotImplemented return self._datatype == other._datatype def __ne__(self, other): diff --git a/debugger_protocol/arg/_params.py b/debugger_protocol/arg/_params.py new file mode 100644 index 00000000..f988539a --- /dev/null +++ b/debugger_protocol/arg/_params.py @@ -0,0 +1,380 @@ +from ._common import ANY, SIMPLE_TYPES +from ._datatype import FieldsNamespace +from ._decl import Enum, Union, Array, Field, Fields +from ._errors import ArgTypeMismatchError +from ._param import Parameter, DatatypeHandler + + +#def as_parameter(cls): +# """Return a parameter that wraps the given FieldsNamespace subclass.""" +# # XXX inject_params +# cls.normalize(_inject_params) +# param = param_from_datatype(cls) +## cls.PARAM = param +# return param +# +# +#def _inject_params(datatype): +# return param_from_datatype(datatype) + + +def param_from_datatype(datatype, **kwargs): + """Return a parameter for the given datatype.""" + if isinstance(datatype, Parameter): + return datatype + + if isinstance(datatype, DatatypeHandler): + return Parameter(datatype.datatype, datatype, **kwargs) + elif isinstance(datatype, Fields): + return ComplexParameter(datatype, **kwargs) + elif isinstance(datatype, Field): + return param_from_datatype(datatype.datatype, **kwargs) + elif datatype is ANY: + return NoopParameter() + elif datatype is None: + return SingletonParameter(None) + elif datatype in list(SIMPLE_TYPES): + return SimpleParameter(datatype, **kwargs) + elif isinstance(datatype, Enum): + return EnumParameter(datatype.datatype, datatype.choice, **kwargs) + elif isinstance(datatype, Union): + return UnionParameter(datatype, **kwargs) + elif isinstance(datatype, (set, frozenset)): + return UnionParameter(Union(*datatype), **kwargs) + elif isinstance(datatype, Array): + return ArrayParameter(datatype, **kwargs) + elif isinstance(datatype, (list, tuple)): + datatype, = datatype + return ArrayParameter(Array(datatype), **kwargs) + elif not isinstance(datatype, type): + raise NotImplementedError + elif issubclass(datatype, FieldsNamespace): + return ComplexParameter(datatype, **kwargs) + else: + raise NotImplementedError + + +######################## +# param types + +class NoopParameter(Parameter): + """A parameter that treats any value as-is.""" + def __init__(self): + handler = DatatypeHandler(ANY) + super(NoopParameter, self).__init__(ANY, handler) + + +NOOP = NoopParameter() + + +class SingletonParameter(Parameter): + """A parameter that works only for the given value.""" + + class HANDLER(DatatypeHandler): + def validate(self, coerced): + if coerced is not self.datatype: + raise ValueError( + 'expected {!r}, got {!r}'.format(self.datatype, coerced)) + + def __init__(self, obj): + handler = self.HANDLER(obj) + super(SingletonParameter, self).__init__(obj, handler) + + def match_type(self, raw): + # Note we do not check equality for singletons. + if raw is not self.datatype: + return None + return super(SingletonParameter, self).match_type(raw) + + +class SimpleHandler(DatatypeHandler): + """A datatype handler for basic value types.""" + + def __init__(self, cls): + if not isinstance(cls, type): + raise ValueError('expected a class, got {!r}'.format(cls)) + super(SimpleHandler, self).__init__(cls) + + def coerce(self, raw): + if type(raw) is self.datatype: + return raw + return self.datatype(raw) + + def validate(self, coerced): + if type(coerced) is not self.datatype: + raise ValueError( + 'expected {!r}, got {!r}'.format(self.datatype, coerced)) + + +class SimpleParameter(Parameter): + """A parameter for basic value types.""" + + HANDLER = SimpleHandler + + def __init__(self, cls, strict=True): + handler = self.HANDLER(cls) + super(SimpleParameter, self).__init__(cls, handler) + self._strict = strict + + def match_type(self, raw): + if self._strict: + if type(raw) is not self.datatype: + return None + elif not isinstance(raw, self.datatype): + return None + return super(SimpleParameter, self).match_type(raw) + + +class EnumParameter(Parameter): + """A parameter for enums of basic value types.""" + + class HANDLER(SimpleHandler): + + def __init__(self, cls, enum): + if not enum: + raise TypeError('missing enum') + super(EnumParameter.HANDLER, self).__init__(cls) + if not callable(enum): + enum = set(enum) + self.enum = enum + + def validate(self, coerced): + super(EnumParameter.HANDLER, self).validate(coerced) + + if not self._match_enum(coerced): + msg = 'expected one of {!r}, got {!r}' + raise ValueError(msg.format(self.enum, coerced)) + + def _match_enum(self, coerced): + if callable(self.enum): + if not self.enum(coerced): + return False + elif coerced not in self.enum: + return False + return True + + def __init__(self, cls, enum): + handler = self.HANDLER(cls, enum) + super(EnumParameter, self).__init__(cls, handler) + self._match_enum = handler._match_enum + + def match_type(self, raw): + if type(raw) is not self.datatype: + return None + if not self._match_enum(raw): + return None + return super(EnumParameter, self).match_type(raw) + + +class UnionParameter(Parameter): + """A parameter that supports multiple different types.""" + + HANDLER = None # no handler + + @classmethod + def from_datatypes(cls, *datatypes, **kwargs): + datatype = Union(*datatypes) + return cls(datatype, **kwargs) + + def __init__(self, datatype, **kwargs): + if not isinstance(datatype, Union): + raise ValueError('expected Union, got {!r}'.format(datatype)) + super(UnionParameter, self).__init__(datatype) + + choice = [] + for dt in datatype: + param = param_from_datatype(dt) + choice.append(param) + self.choice = choice + + def __eq__(self, other): + if type(self) is not type(other): + return False + return set(self.datatype) == set(other.datatype) + + def match_type(self, raw): + for param in self.choice: + handler = param.match_type(raw) + if handler is not None: + return handler + return None + + +class ArrayParameter(Parameter): + """A parameter that is a list of some fixed type.""" + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None, itemparam=None): + if not isinstance(datatype, Array): + raise ValueError( + 'expected an Array, got {!r}'.format(datatype)) + super(ArrayParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + self.itemparam = itemparam + + def coerce(self, raw): + if self.handlers is None: + if self.itemparam is None: + itemtype = self.datatype.itemtype + self.itemparam = param_from_datatype(itemtype) + handlers = [] + for item in raw: + handler = self.itemparam.match_type(item) + if handler is None: + raise ArgTypeMismatchError(item) + handlers.append(handler) + self.handlers = handlers + + result = [] + for i, item in enumerate(raw): + handler = self.handlers[i] + item = handler.coerce(item) + result.append(item) + return result + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for i, item in enumerate(coerced): + handler = self.handlers[i] + handler.validate(item) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = [] + for i, item in enumerate(coerced): + handler = self.handlers[i] + datum = handler.as_data(item) + data.append(datum) + return data + + @classmethod + def from_itemtype(cls, itemtype, **kwargs): + datatype = Array(itemtype) + return cls(datatype, **kwargs) + + def __init__(self, datatype): + if not isinstance(datatype, Array): + raise ValueError('expected Array, got {!r}'.format(datatype)) + itemparam = param_from_datatype(datatype.itemtype) + handler = self.HANDLER(datatype, None, itemparam) + super(ArrayParameter, self).__init__(datatype, handler) + + self.itemparam = itemparam + + def match_type(self, raw): + if not isinstance(raw, list): + return None + handlers = [] + for item in raw: + handler = self.itemparam.match_type(item) + if handler is None: + return None + handlers.append(handler) + return self.HANDLER(self.datatype, handlers) + + +class ComplexParameter(Parameter): + + class HANDLER(DatatypeHandler): + + def __init__(self, datatype, handlers=None): + if (type(datatype) is not type or + not issubclass(datatype, FieldsNamespace) + ): + msg = 'expected FieldsNamespace, got {!r}' + raise ValueError(msg.format(datatype)) + super(ComplexParameter.HANDLER, self).__init__(datatype) + self.handlers = handlers + + def coerce(self, raw): + if self.handlers is None: + fields = self.datatype.FIELDS.as_dict() + handlers = {} + for name, value in raw.items(): + param = param_from_datatype(fields[name]) + handler = param.match_type(value) + if handler is None: + raise ArgTypeMismatchError((name, value)) + handlers[name] = handler + self.handlers = handlers + + result = {} + for name, value in raw.items(): + handler = self.handlers[name] + value = handler.coerce(value) + result[name] = value + return self.datatype(**result) + + def validate(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + for field in self.datatype.FIELDS: + try: + value = getattr(coerced, field.name) + except AttributeError: + continue + handler = self.handlers[field.name] + handler.validate(value) + + def as_data(self, coerced): + if self.handlers is None: + raise TypeError('coerce first') + data = {} + for field in self.datatype.FIELDS: + try: + value = getattr(coerced, field.name) + except AttributeError: + continue + handler = self.handlers[field.name] + datum = handler.as_data(value) + data[field.name] = datum + return data + + def __init__(self, datatype): + if isinstance(datatype, Fields): + class ArgNamespace(FieldsNamespace): + FIELDS = datatype + + datatype = ArgNamespace + elif (type(datatype) is not type or + not issubclass(datatype, FieldsNamespace)): + msg = 'expected Fields or FieldsNamespace, got {!r}' + raise ValueError(msg.format(datatype)) + datatype.normalize() + # We set handler later in match_type(). + super(ComplexParameter, self).__init__(datatype) + + self.params = {field.name: param_from_datatype(field) + for field in datatype.FIELDS} + + def __eq__(self, other): + if super(ComplexParameter, self).__eq__(other): + return True + try: + fields = self._datatype.FIELDS + other_fields = other._datatype.FIELDS + except AttributeError: + return NotImplemented + else: + return fields == other_fields + + def match_type(self, raw): + if not isinstance(raw, dict): + return None + handlers = {} + for field in self.datatype.FIELDS: + try: + value = raw[field.name] + except KeyError: + if not field.optional: + return None + value = field.default + param = self.params[field.name] + handler = param.match_type(value) + if handler is None: + return None + handlers[field.name] = handler + return self.HANDLER(self.datatype, handlers) diff --git a/tests/debugger_protocol/arg/test__params.py b/tests/debugger_protocol/arg/test__params.py new file mode 100644 index 00000000..2a05d69f --- /dev/null +++ b/tests/debugger_protocol/arg/test__params.py @@ -0,0 +1,698 @@ +import unittest + +from debugger_protocol.arg._common import NOT_SET, ANY +from debugger_protocol.arg._decl import Enum, Union, Array, Field, Fields +from debugger_protocol.arg._param import Parameter, DatatypeHandler +from debugger_protocol.arg._params import ( + param_from_datatype, + NoopParameter, SingletonParameter, + SimpleParameter, EnumParameter, + UnionParameter, ArrayParameter, ComplexParameter) + +from ._common import FIELDS_BASIC, BASIC_FULL, Basic + + +class String(str): + pass + + +class Integer(int): + pass + + +class ParamFromDatatypeTest(unittest.TestCase): + + def test_supported(self): + handler = DatatypeHandler(str) + tests = [ + (Parameter(str), Parameter(str)), + (handler, Parameter(str, handler)), + (Fields(Field('spam')), ComplexParameter(Fields(Field('spam')))), + (Field('spam'), SimpleParameter(str)), + (Field('spam', str, enum={'a'}), EnumParameter(str, {'a'})), + (ANY, NoopParameter()), + (None, SingletonParameter(None)), + (str, SimpleParameter(str)), + (int, SimpleParameter(int)), + (bool, SimpleParameter(bool)), + (Enum(str, {'a'}), EnumParameter(str, {'a'})), + (Union(str, int), UnionParameter(Union(str, int))), + ({str, int}, UnionParameter(Union(str, int))), + (frozenset([str, int]), UnionParameter(Union(str, int))), + (Array(str), ArrayParameter(Array(str))), + ([str], ArrayParameter(Array(str))), + ((str,), ArrayParameter(Array(str))), + (Basic, ComplexParameter(Basic)), + ] + for datatype, expected in tests: + with self.subTest(datatype): + param = param_from_datatype(datatype) + + self.assertEqual(param, expected) + + def test_not_supported(self): + datatypes = [ + String('spam'), + ..., + ] + for datatype in datatypes: + with self.subTest(datatype): + with self.assertRaises(NotImplementedError): + param_from_datatype(datatype) + + +class NoopParameterTests(unittest.TestCase): + + VALUES = [ + object(), + 'spam', + 10, + ['spam'], + {'spam': 42}, + True, + None, + NOT_SET, + ] + + def test_match_type(self): + values = [ + object(), + '', + 'spam', + b'spam', + 0, + 10, + 10.0, + 10+0j, + ('spam',), + (), + ['spam'], + [], + {'spam': 42}, + {}, + {'spam'}, + set(), + object, + type, + NoopParameterTests, + True, + None, + ..., + NotImplemented, + NOT_SET, + ANY, + Union(str, int), + Union(), + Array(str), + Field('spam'), + Fields(Field('spam')), + Fields(), + Basic, + ] + for value in values: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + + self.assertIs(type(handler), DatatypeHandler) + self.assertIs(handler.datatype, ANY) + + def test_coerce(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertIs(coerced, value) + + def test_validate(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + handler.validate(value) + + def test_as_data(self): + for value in self.VALUES: + with self.subTest(value): + param = NoopParameter() + handler = param.match_type(value) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class SingletonParameterTests(unittest.TestCase): + + def test_match_type_matched(self): + param = SingletonParameter(None) + handler = param.match_type(None) + + self.assertIs(handler.datatype, None) + + def test_match_type_no_match(self): + tests = [ + # same type, different value + ('spam', 'eggs'), + (10, 11), + (True, False), + # different type but equivalent + ('spam', b'spam'), + (10, 10.0), + (10, 10+0j), + (10, '10'), + (10, b'\10'), + ] + for singleton, value in tests: + with self.subTest((singleton, value)): + param = SingletonParameter(singleton) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + param = SingletonParameter(None) + handler = param.match_type(None) + value = handler.coerce(None) + + self.assertIs(value, None) + + def test_validate_valid(self): + param = SingletonParameter(None) + handler = param.match_type(None) + handler.validate(None) + + def test_validate_wrong_type(self): + tests = [ + (None, True), + (True, None), + ('spam', 10), + (10, 'spam'), + ] + for singleton, value in tests: + with self.subTest(singleton): + param = SingletonParameter(singleton) + handler = param.match_type(singleton) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_validate_same_type_wrong_value(self): + tests = [ + ('spam', 'eggs'), + (True, False), + (10, 11), + ] + for singleton, value in tests: + with self.subTest(singleton): + param = SingletonParameter(singleton) + handler = param.match_type(singleton) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + param = SingletonParameter(None) + handler = param.match_type(None) + data = handler.as_data(None) + + self.assertIs(data, None) + + +class SimpleParameterTests(unittest.TestCase): + + def test_match_type_match(self): + tests = [ + (str, 'spam'), + (str, String('spam')), + (int, 10), + (bool, True), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + param = SimpleParameter(datatype, strict=False) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_no_match(self): + tests = [ + (int, 'spam'), + # coercible + (str, 10), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + param = SimpleParameter(datatype, strict=False) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_match_type_strict_match(self): + tests = { + str: 'spam', + int: 10, + bool: True, + } + for datatype, value in tests.items(): + with self.subTest(datatype): + param = SimpleParameter(datatype, strict=True) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_strict_no_match(self): + tests = { + str: String('spam'), + int: Integer(10), + } + for datatype, value in tests.items(): + with self.subTest(datatype): + param = SimpleParameter(datatype, strict=True) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + tests = [ + (str, 'spam', 'spam'), + (str, String('spam'), 'spam'), + (int, 10, 10), + (bool, True, True), + # did not match, but still coercible + (str, 10, '10'), + (str, str, ""), + (int, 10.0, 10), + (int, '10', 10), + (bool, 1, True), + ] + for datatype, value, expected in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + coerced = handler.coerce(value) + + self.assertEqual(coerced, expected) + + def test_validate_valid(self): + tests = { + str: 'spam', + int: 10, + bool: True, + } + for datatype, value in tests.items(): + with self.subTest(datatype): + handler = SimpleParameter.HANDLER(datatype) + handler.validate(value) + + def test_validate_invalid(self): + tests = [ + (int, 'spam'), + # coercible + (str, String('spam')), + (str, 10), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + tests = [ + (str, 'spam'), + (int, 10), + (bool, True), + # did not match, but still coercible + (str, String('spam')), + (str, 10), + (str, str), + (int, 10.0), + (int, '10'), + (bool, 1), + # semi-coercible + (str, b'spam'), + (int, 10+0j), + (int, b'\10'), + ] + for datatype, value in tests: + with self.subTest((datatype, value)): + handler = SimpleParameter.HANDLER(datatype) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class EnumParameterTests(unittest.TestCase): + + def test_match_type_match(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True,), True), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum)): + param = EnumParameter(datatype, enum) + handler = param.match_type(value) + + self.assertIs(handler.datatype, datatype) + + def test_match_type_no_match(self): + tests = [ + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + param = EnumParameter(datatype, enum) + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce(self): + tests = [ + (str, 'spam', 'spam'), + (int, 10, 10), + (bool, True, True), + # did not match, but still coercible + (str, String('spam'), 'spam'), + (str, 10, '10'), + (str, str, ""), + (int, 10.0, 10), + (int, '10', 10), + (bool, 1, True), + ] + for datatype, value, expected in tests: + with self.subTest((datatype, value)): + enum = (expected,) + handler = EnumParameter.HANDLER(datatype, enum) + coerced = handler.coerce(value) + + self.assertEqual(coerced, expected) + + def test_coerce_enum_mismatch(self): + enum = ('spam', 'eggs') + handler = EnumParameter.HANDLER(str, enum) + coerced = handler.coerce('ham') + + # It still works. + self.assertEqual(coerced, 'ham') + + def test_validate_valid(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True, False), True), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum)): + handler = EnumParameter.HANDLER(datatype, enum) + handler.validate(value) + + def test_validate_invalid(self): + tests = [ + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + handler = EnumParameter.HANDLER(datatype, enum) + + with self.assertRaises(ValueError): + handler.validate(value) + + def test_as_data(self): + tests = [ + (str, ('spam', 'eggs'), 'spam'), + (str, ('spam',), 'spam'), + (int, (1, 2, 3), 2), + (bool, (True,), True), + # enum mismatch + (str, ('spam', 'eggs'), 'ham'), + (int, (1, 2, 3), 10), + # type mismatch + (int, (1, 2, 3), 'spam'), + # coercible + (str, ('spam', 'eggs'), String('spam')), + (str, ('1', '2', '3'), 2), + (int, (1, 2, 3), 2.0), + (int, (1, 2, 3), '2'), + (bool, (True,), 1), + # semi-coercible + (str, ('spam', 'eggs'), b'spam'), + (int, (1, 2, 3), 2+0j), + (int, (1, 2, 3), b'\02'), + ] + for datatype, enum, value in tests: + with self.subTest((datatype, enum, value)): + handler = EnumParameter.HANDLER(datatype, enum) + data = handler.as_data(value) + + self.assertIs(data, value) + + +class UnionParameterTests(unittest.TestCase): + + def test_match_type_all_simple(self): + tests = [ + 'spam', + 10, + True, + ] + datatype = Union(str, int, bool) + param = UnionParameter(datatype) + for value in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(type(handler), SimpleParameter.HANDLER) + self.assertIs(handler.datatype, type(value)) + + def test_match_type_mixed(self): + datatype = Union( + str, + # XXX add dedicated enums + Enum(int, (1, 2, 3)), + Basic, + Array(str), + Array(int), + Union(int, bool), + ) + param = UnionParameter(datatype) + + tests = [ + ('spam', SimpleParameter.HANDLER(str)), + (2, EnumParameter.HANDLER(int, (1, 2, 3))), + (BASIC_FULL, ComplexParameter(Basic).match_type(BASIC_FULL)), + (['spam'], ArrayParameter.HANDLER(Array(str))), + ([], ArrayParameter.HANDLER(Array(str))), + ([10], ArrayParameter.HANDLER(Array(int))), + (10, SimpleParameter.HANDLER(int)), + (True, SimpleParameter.HANDLER(bool)), + # no match + (Integer(2), None), + ([True], None), + ({}, None), + ] + for value, expected in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_catchall(self): + NOOP = DatatypeHandler(ANY) + param = UnionParameter(Union(int, str, ANY)) + tests = [ + ('spam', SimpleParameter.HANDLER(str)), + (10, SimpleParameter.HANDLER(int)), + # catchall + (BASIC_FULL, NOOP), + (['spam'], NOOP), + (True, NOOP), + (Integer(2), NOOP), + ([10], NOOP), + ({}, NOOP), + ] + for value, expected in tests: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = UnionParameter(Union(int, str)) + values = [ + BASIC_FULL, + ['spam'], + True, + Integer(2), + [10], + {}, + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + +class ArrayParameterTests(unittest.TestCase): + + def test_match_type_match(self): + param = ArrayParameter(Array(str)) + expected = ArrayParameter.HANDLER(Array(str)) + values = [ + ['a', 'b', 'c'], + [], + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertEqual(handler, expected) + + def test_match_type_no_match(self): + param = ArrayParameter(Array(str)) + values = [ + ['a', 1, 'c'], + ('a', 'b', 'c'), + 'spam', + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + + self.assertIs(handler, None) + + def test_coerce_simple(self): + param = ArrayParameter(Array(str)) + values = [ + ['a', 'b', 'c'], + [], + ] + for value in values: + with self.subTest(value): + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, value) + + def test_coerce_complicated(self): + param = ArrayParameter(Array(Union(str, Basic))) + value = [ + 'a', + BASIC_FULL, + 'c', + ] + handler = param.match_type(value) + coerced = handler.coerce(value) + + self.assertEqual(coerced, [ + 'a', + Basic(name='spam', value='eggs'), + 'c', + ]) + + def test_validate(self): + param = ArrayParameter(Array(str)) + handler = param.match_type(['a', 'b', 'c']) + handler.validate(['a', 'b', 'c']) + + def test_as_data_simple(self): + param = ArrayParameter(Array(str)) + handler = param.match_type(['a', 'b', 'c']) + data = handler.as_data(['a', 'b', 'c']) + + self.assertEqual(data, ['a', 'b', 'c']) + + def test_as_data_complicated(self): + param = ArrayParameter(Array(Union(str, Basic))) + value = [ + 'a', + BASIC_FULL, + 'c', + ] + handler = param.match_type(value) + data = handler.as_data([ + 'a', + Basic(name='spam', value='eggs'), + 'c', + ]) + + self.assertEqual(data, value) + + +class ComplexParameterTests(unittest.TestCase): + + def test_match_type(self): + fields = Fields(*FIELDS_BASIC) + param = ComplexParameter(fields) + handler = param.match_type(BASIC_FULL) + + self.assertIs(type(handler), ComplexParameter.HANDLER) + self.assertEqual(handler.datatype.FIELDS, fields) + + def test_coerce(self): + handler = ComplexParameter.HANDLER(Basic) + coerced = handler.coerce(BASIC_FULL) + + self.assertEqual(coerced, Basic(**BASIC_FULL)) + + def test_validate(self): + handler = ComplexParameter.HANDLER(Basic) + handler.coerce(BASIC_FULL) + coerced = Basic(**BASIC_FULL) + handler.validate(coerced) + + def test_as_data(self): + handler = ComplexParameter.HANDLER(Basic) + handler.coerce(BASIC_FULL) + coerced = Basic(**BASIC_FULL) + data = handler.as_data(coerced) + + self.assertEqual(data, BASIC_FULL)