Add _decl.Enum.

This commit is contained in:
Eric Snow 2018-01-31 20:59:19 +00:00
parent 6c097a4f54
commit 6b4e02ef45
3 changed files with 99 additions and 6 deletions

View file

@ -1,8 +1,8 @@
from debugger_protocol._base import Readonly, WithRepr
from ._common import NOT_SET, ANY
from ._common import NOT_SET, ANY, SIMPLE_TYPES
from ._decl import (
_transform_datatype, _replace_ref,
Union, Array, Field, Fields)
Enum, Union, Array, Field, Fields)
from ._errors import ArgTypeMismatchError, ArgMissingError, IncompleteArgError
@ -13,8 +13,17 @@ def _coerce(datatype, value, call=True):
return value
elif value is datatype:
return value
elif datatype is None:
pass # fail below
elif datatype in SIMPLE_TYPES:
# We already checked for exact type match above.
pass # fail below
# decl types
elif isinstance(datatype, Enum):
value = _coerce(datatype.datatype, value, call=False)
if value in datatype.choices:
return value
elif isinstance(datatype, Union):
for dt in datatype:
try:

View file

@ -17,6 +17,8 @@ def _normalize_datatype(datatype):
return ANY
elif datatype in list(SIMPLE_TYPES):
return datatype
elif isinstance(datatype, Enum):
return datatype
elif isinstance(datatype, Union):
return datatype
elif isinstance(datatype, Array):
@ -49,6 +51,49 @@ def _replace_ref(datatype, target):
return datatype
class Enum(namedtuple('Enum', 'datatype choices')):
"""A simple type with a limited set of allowed values."""
@classmethod
def _check_choices(cls, datatype, choices, strict=True):
if callable(choices):
return choices
if isinstance(choices, str):
msg = 'bad choices (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choices))
choices = frozenset(choices)
if not choices:
raise TypeError('missing choices')
if not strict:
return choices
for value in choices:
if type(value) is not datatype:
msg = 'bad choices (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choices))
return choices
def __new__(cls, datatype, choices, **kwargs):
strict = kwargs.pop('strict', True)
normalize = kwargs.pop('_normalize', True)
(lambda: None)(**kwargs) # Make sure there aren't any other kwargs.
if not isinstance(datatype, type):
raise ValueError('expected a class, got {!r}'.format(datatype))
if datatype not in list(SIMPLE_TYPES):
msg = 'only simple datatypes are supported, got {!r}'
raise ValueError(msg.format(datatype))
if normalize:
# There's no need to normalize datatype (it's a simple type).
pass
choices = cls._check_choices(datatype, choices, strict=strict)
self = super(Enum, cls).__new__(cls, datatype, choices)
return self
class Union(frozenset):
"""Declare a union of different types.
@ -137,8 +182,12 @@ class Field(namedtuple('Field', 'name datatype default optional')):
START_OPTIONAL = sentinel('START_OPTIONAL')
def __new__(cls, name, datatype=str, default=NOT_SET, optional=False,
_normalize=True, **kwargs):
def __new__(cls, name, datatype=str, enum=None, default=NOT_SET,
optional=False, _normalize=True, **kwargs):
if enum is not None and not isinstance(enum, Enum):
datatype = Enum(datatype, enum)
enum = None
if _normalize:
datatype = _normalize_datatype(datatype)
self = super(Field, cls).__new__(

View file

@ -4,8 +4,8 @@ from debugger_protocol.arg import NOT_SET, ANY
from debugger_protocol.arg._datatype import FieldsNamespace
from debugger_protocol.arg._decl import (
REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype,
Union, Array, Field, Fields)
from debugger_protocol.arg._param import Parameter, ParameterImplBase, Arg
Enum, Union, Array, Field, Fields)
from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg
from debugger_protocol.arg._params import (
SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter)
@ -23,6 +23,7 @@ class ModuleTests(unittest.TestCase):
(int, NOOP),
(str, NOOP),
(bool, NOOP),
(Enum(str, ('spam',)), NOOP),
(Union(str, int), NOOP),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
@ -169,6 +170,35 @@ class ModuleTests(unittest.TestCase):
])
class EnumTests(unittest.TestCase):
def test_attrs(self):
enum = Enum(str, ('spam', 'eggs'))
datatype, choices = enum
self.assertIs(datatype, str)
self.assertEqual(choices, frozenset(['spam', 'eggs']))
def test_bad_datatype(self):
with self.assertRaises(ValueError):
Enum('spam', ('spam', 'eggs'))
with self.assertRaises(ValueError):
Enum(dict, ('spam', 'eggs'))
def test_bad_choices(self):
class String(str):
pass
with self.assertRaises(ValueError):
Enum(str, 'spam')
with self.assertRaises(TypeError):
Enum(str, ())
with self.assertRaises(ValueError):
Enum(str, ('spam', 10))
with self.assertRaises(ValueError):
Enum(str, ('spam', String))
class UnionTests(unittest.TestCase):
def test_normalized(self):
@ -271,6 +301,11 @@ class FieldTests(unittest.TestCase):
self.assertIs(field.default, NOT_SET)
self.assertFalse(field.optional)
def test_enum(self):
field = Field('spam', str, enum=('a', 'b', 'c'))
self.assertEqual(field.datatype, Enum(str, ('a', 'b', 'c')))
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),