From 6b4e02ef45ef929d44f4caf2f57d89bb7f79d295 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Wed, 31 Jan 2018 20:59:19 +0000 Subject: [PATCH] Add _decl.Enum. --- debugger_protocol/arg/_datatype.py | 13 +++++- debugger_protocol/arg/_decl.py | 53 ++++++++++++++++++++++- tests/debugger_protocol/arg/test__decl.py | 39 ++++++++++++++++- 3 files changed, 99 insertions(+), 6 deletions(-) diff --git a/debugger_protocol/arg/_datatype.py b/debugger_protocol/arg/_datatype.py index ee5b76ff..575eb8d9 100644 --- a/debugger_protocol/arg/_datatype.py +++ b/debugger_protocol/arg/_datatype.py @@ -1,8 +1,8 @@ from debugger_protocol._base import Readonly, WithRepr -from ._common import NOT_SET, ANY +from ._common import NOT_SET, ANY, SIMPLE_TYPES from ._decl import ( _transform_datatype, _replace_ref, - Union, Array, Field, Fields) + Enum, Union, Array, Field, Fields) from ._errors import ArgTypeMismatchError, ArgMissingError, IncompleteArgError @@ -13,8 +13,17 @@ def _coerce(datatype, value, call=True): return value elif value is datatype: return value + elif datatype is None: + pass # fail below + elif datatype in SIMPLE_TYPES: + # We already checked for exact type match above. + pass # fail below # decl types + elif isinstance(datatype, Enum): + value = _coerce(datatype.datatype, value, call=False) + if value in datatype.choices: + return value elif isinstance(datatype, Union): for dt in datatype: try: diff --git a/debugger_protocol/arg/_decl.py b/debugger_protocol/arg/_decl.py index 20743c8a..4f0c2ff9 100644 --- a/debugger_protocol/arg/_decl.py +++ b/debugger_protocol/arg/_decl.py @@ -17,6 +17,8 @@ def _normalize_datatype(datatype): return ANY elif datatype in list(SIMPLE_TYPES): return datatype + elif isinstance(datatype, Enum): + return datatype elif isinstance(datatype, Union): return datatype elif isinstance(datatype, Array): @@ -49,6 +51,49 @@ def _replace_ref(datatype, target): return datatype +class Enum(namedtuple('Enum', 'datatype choices')): + """A simple type with a limited set of allowed values.""" + + @classmethod + def _check_choices(cls, datatype, choices, strict=True): + if callable(choices): + return choices + + if isinstance(choices, str): + msg = 'bad choices (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choices)) + + choices = frozenset(choices) + if not choices: + raise TypeError('missing choices') + if not strict: + return choices + + for value in choices: + if type(value) is not datatype: + msg = 'bad choices (expected {!r} values, got {!r})' + raise ValueError(msg.format(datatype, choices)) + return choices + + def __new__(cls, datatype, choices, **kwargs): + strict = kwargs.pop('strict', True) + normalize = kwargs.pop('_normalize', True) + (lambda: None)(**kwargs) # Make sure there aren't any other kwargs. + + if not isinstance(datatype, type): + raise ValueError('expected a class, got {!r}'.format(datatype)) + if datatype not in list(SIMPLE_TYPES): + msg = 'only simple datatypes are supported, got {!r}' + raise ValueError(msg.format(datatype)) + if normalize: + # There's no need to normalize datatype (it's a simple type). + pass + choices = cls._check_choices(datatype, choices, strict=strict) + + self = super(Enum, cls).__new__(cls, datatype, choices) + return self + + class Union(frozenset): """Declare a union of different types. @@ -137,8 +182,12 @@ class Field(namedtuple('Field', 'name datatype default optional')): START_OPTIONAL = sentinel('START_OPTIONAL') - def __new__(cls, name, datatype=str, default=NOT_SET, optional=False, - _normalize=True, **kwargs): + def __new__(cls, name, datatype=str, enum=None, default=NOT_SET, + optional=False, _normalize=True, **kwargs): + if enum is not None and not isinstance(enum, Enum): + datatype = Enum(datatype, enum) + enum = None + if _normalize: datatype = _normalize_datatype(datatype) self = super(Field, cls).__new__( diff --git a/tests/debugger_protocol/arg/test__decl.py b/tests/debugger_protocol/arg/test__decl.py index 59ee2744..ec296402 100644 --- a/tests/debugger_protocol/arg/test__decl.py +++ b/tests/debugger_protocol/arg/test__decl.py @@ -4,8 +4,8 @@ from debugger_protocol.arg import NOT_SET, ANY from debugger_protocol.arg._datatype import FieldsNamespace from debugger_protocol.arg._decl import ( REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype, - Union, Array, Field, Fields) -from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg + Enum, Union, Array, Field, Fields) +from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg from debugger_protocol.arg._params import ( SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter) @@ -23,6 +23,7 @@ class ModuleTests(unittest.TestCase): (int, NOOP), (str, NOOP), (bool, NOOP), + (Enum(str, ('spam',)), NOOP), (Union(str, int), NOOP), ({str, int}, Union(str, int)), (frozenset([str, int]), Union(str, int)), @@ -169,6 +170,35 @@ class ModuleTests(unittest.TestCase): ]) +class EnumTests(unittest.TestCase): + + def test_attrs(self): + enum = Enum(str, ('spam', 'eggs')) + datatype, choices = enum + + self.assertIs(datatype, str) + self.assertEqual(choices, frozenset(['spam', 'eggs'])) + + def test_bad_datatype(self): + with self.assertRaises(ValueError): + Enum('spam', ('spam', 'eggs')) + with self.assertRaises(ValueError): + Enum(dict, ('spam', 'eggs')) + + def test_bad_choices(self): + class String(str): + pass + + with self.assertRaises(ValueError): + Enum(str, 'spam') + with self.assertRaises(TypeError): + Enum(str, ()) + with self.assertRaises(ValueError): + Enum(str, ('spam', 10)) + with self.assertRaises(ValueError): + Enum(str, ('spam', String)) + + class UnionTests(unittest.TestCase): def test_normalized(self): @@ -271,6 +301,11 @@ class FieldTests(unittest.TestCase): self.assertIs(field.default, NOT_SET) self.assertFalse(field.optional) + def test_enum(self): + field = Field('spam', str, enum=('a', 'b', 'c')) + + self.assertEqual(field.datatype, Enum(str, ('a', 'b', 'c'))) + def test_normalized(self): tests = [ (REF, TYPE_REFERENCE),