Merge pull request #29 from ericsnowcurrently/declarative-framework

Add helpers for defining message args.
This commit is contained in:
Eric Snow 2018-02-01 12:39:11 -07:00 committed by GitHub
commit d0cf2f882c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 3055 additions and 0 deletions

View file

@ -0,0 +1,28 @@
class Readonly(object):
"""For read-only instances."""
def __setattr__(self, name, value):
raise AttributeError(
'{} objects are read-only'.format(type(self).__name__))
def __delattr__(self, name):
raise AttributeError(
'{} objects are read-only'.format(type(self).__name__))
def _bind_attrs(self, **attrs):
for name, value in attrs.items():
object.__setattr__(self, name, value)
class WithRepr(object):
def _init_args(self):
# XXX Extract from __init__()...
return ()
def __repr__(self):
args = ', '.join('{}={!r}'.format(arg, value)
for arg, value in self._init_args())
return '{}({})'.format(type(self).__name__, args)

View file

@ -0,0 +1,8 @@
from ._common import NOT_SET, ANY # noqa
from ._datatype import FieldsNamespace # noqa
from ._decl import Union, Array, Field # noqa
from ._errors import ( # noqa
ArgumentError,
ArgMissingError, IncompleteArgError, ArgTypeMismatchError,
)
from ._params import param_from_datatype # noqa

View file

@ -0,0 +1,17 @@
def sentinel(name):
"""Return a named value to use as a sentinel."""
class Sentinel(object):
def __repr__(self):
return name
return Sentinel()
# NOT_SET indicates that an arg was not provided.
NOT_SET = sentinel('NOT_SET')
# ANY is a datatype surrogate indicating that any value is okay.
ANY = sentinel('ANY')
SIMPLE_TYPES = {None, bool, int, str}

View file

@ -0,0 +1,286 @@
from debugger_protocol._base import Readonly, WithRepr
from ._common import NOT_SET, ANY, SIMPLE_TYPES
from ._decl import (
_transform_datatype, _replace_ref,
Enum, Union, Array, Field, Fields)
from ._errors import ArgTypeMismatchError, ArgMissingError, IncompleteArgError
def _coerce(datatype, value, call=True):
if datatype is ANY:
return value
elif type(value) is datatype:
return value
elif value is datatype:
return value
elif datatype is None:
pass # fail below
elif datatype in SIMPLE_TYPES:
# We already checked for exact type match above.
pass # fail below
# decl types
elif isinstance(datatype, Enum):
value = _coerce(datatype.datatype, value, call=False)
if value in datatype.choices:
return value
elif isinstance(datatype, Union):
for dt in datatype:
try:
return _coerce(dt, value, call=False)
except ArgTypeMismatchError:
continue
else:
raise ArgTypeMismatchError(value)
elif isinstance(datatype, Array):
try:
values = iter(value)
except TypeError:
raise ArgTypeMismatchError(value)
return [_coerce(datatype.itemtype, v, call=False)
for v in values]
elif isinstance(datatype, Field):
return _coerce(datatype.datatype, value)
elif isinstance(datatype, Fields):
class ArgNamespace(FieldsNamespace):
FIELDS = datatype
return _coerce(ArgNamespace, value)
elif issubclass(datatype, FieldsNamespace):
arg = datatype.bind(value)
try:
arg_coerce = arg.coerce
except AttributeError:
return arg
else:
return arg_coerce()
# fallbacks
elif callable(datatype) and call:
try:
return datatype(value)
except ArgTypeMismatchError:
raise
except (TypeError, ValueError):
raise ArgTypeMismatchError(value)
elif value == datatype:
return value
raise ArgTypeMismatchError(value)
########################
# fields
class FieldsNamespace(Readonly, WithRepr):
"""A namespace of field values exposed via attributes."""
FIELDS = None
PARAM_TYPE = None
PARAM = None
@classmethod
def traverse(cls, op, **kwargs):
"""Apply op to each field in cls.FIELDS."""
fields = cls._normalize(cls.FIELDS)
fields = fields.traverse(op)
cls.FIELDS = cls._normalize(fields)
return cls
@classmethod
def normalize(cls, *transforms):
"""Normalize FIELDS and apply the given ops."""
fields = cls._normalize(cls.FIELDS)
if not isinstance(fields, Fields):
fields = Fields(*fields)
for transform in transforms:
fields = _transform_datatype(fields, transform)
fields = cls._normalize(fields)
cls.FIELDS = fields
@classmethod
def _normalize(cls, fields):
if fields is None:
raise TypeError('missing FIELDS')
if isinstance(fields, Fields):
try:
normalized = cls._normalized
except AttributeError:
normalized = cls._normalized = False
else:
fields = Fields(*fields)
normalized = cls._normalized = False
if not normalized:
fields = _transform_datatype(fields,
lambda dt: _replace_ref(dt, cls))
return fields
@classmethod
def bind(cls, ns, **kwargs):
param = cls.PARAM
if param is None:
if cls.PARAM_TYPE is None:
return cls(**ns)
param = cls.PARAM_TYPE(cls.FIELDS, cls)
return param.bind(ns, **kwargs)
@classmethod
def _bind(cls, kwargs):
cls.FIELDS = cls._normalize(cls.FIELDS)
bound, missing = _fields_bind(cls.FIELDS, kwargs)
if missing:
raise IncompleteArgError(cls.FIELDS, missing)
values = {}
validators = []
serializers = {}
for field, arg in bound.items():
if arg is NOT_SET:
continue
try:
coerce = arg.coerce
except AttributeError:
value = arg
else:
value = coerce(arg)
values[field.name] = value
try:
validate = arg.validate
validate = value.validate
except AttributeError:
pass
else:
validators.append(validate)
try:
as_data = arg.as_data
as_data = value.as_data
except AttributeError:
pass
else:
serializers[field.name] = as_data
values['_validators'] = validators
values['_serializers'] = serializers
return values
def __init__(self, **kwargs):
super(FieldsNamespace, self).__init__()
validate = kwargs.pop('_validate', True)
kwargs = self._bind(kwargs)
self._bind_attrs(**kwargs)
if validate:
self.validate()
def _init_args(self):
if self.FIELDS is not None:
for field in self.FIELDS:
try:
value = getattr(self, field.name)
except AttributeError:
continue
yield (field.name, value)
else:
for item in sorted(vars(self).items()):
yield item
def __eq__(self, other):
try:
other_as_data = other.as_data
except AttributeError:
other_data = other
else:
other_data = other_as_data()
return self.as_data() == other_data
def __ne__(self, other):
return not (self == other)
def validate(self):
"""Ensure that the field values are valid."""
for validate in self._validators:
validate()
def as_data(self):
"""Return serializable data for the instance."""
data = {name: as_data()
for name, as_data in self._serializers.items()}
for field in self.FIELDS:
if field.name in data:
continue
try:
data[field.name] = getattr(self, field.name)
except AttributeError:
pass
return data
def _field_missing(field, value):
if value is NOT_SET:
return True
try:
missing = field.datatype.missing
except AttributeError:
return None
else:
return missing(value)
def _field_bind(field, value, applydefaults=True):
missing = _field_missing(field, value)
if missing:
if field.optional:
if applydefaults:
return field.default
return NOT_SET
raise ArgMissingError(field, missing)
try:
bind = field.datatype.bind
except AttributeError:
bind = (lambda v: _coerce(field.datatype, v))
return bind(value)
def _fields_iter_values(fields, remainder):
for field in fields or ():
value = remainder.pop(field.name, NOT_SET)
yield field, value
def _fields_iter_bound(fields, remainder, applydefaults=True):
for field, value in _fields_iter_values(fields, remainder):
try:
arg = _field_bind(field, value, applydefaults=applydefaults)
except ArgMissingError as exc:
yield field, value, exc, False
# except ArgTypeMismatchError as exc:
# yield field, value, None, exc
else:
yield field, arg, False, False
def _fields_bind(fields, kwargs, applydefaults=True):
bound = {}
missing = {}
mismatched = {}
remainder = dict(kwargs)
bound_iter = _fields_iter_bound(fields, remainder,
applydefaults=applydefaults)
for field, arg, missed, mismatch in bound_iter:
if missed:
missing[field.name] = missed
elif mismatch:
mismatched[field.name] = arg
else:
bound[field] = arg
if remainder:
remainder = ', '.join(sorted(remainder))
raise TypeError('got extra fields: {}'.format(remainder))
if mismatched:
raise ArgTypeMismatchError(mismatched)
return bound, missing

View file

@ -0,0 +1,319 @@
from collections import namedtuple
from collections.abc import Sequence
from debugger_protocol._base import Readonly
from ._common import sentinel, NOT_SET, ANY, SIMPLE_TYPES
REF = '<ref>'
TYPE_REFERENCE = sentinel('TYPE_REFERENCE')
def _is_simple(datatype):
if datatype is ANY:
return True
elif datatype in list(SIMPLE_TYPES):
return True
elif isinstance(datatype, Enum):
return True
else:
return False
def _normalize_datatype(datatype):
cls = type(datatype)
if datatype == REF or datatype is TYPE_REFERENCE:
return TYPE_REFERENCE
elif datatype is ANY:
return ANY
elif datatype in list(SIMPLE_TYPES):
return datatype
elif isinstance(datatype, Enum):
return datatype
elif isinstance(datatype, Union):
return datatype
elif isinstance(datatype, Array):
return datatype
elif cls is set or cls is frozenset:
return Union(*datatype)
elif cls is list or cls is tuple:
datatype, = datatype
return Array(datatype)
elif cls is dict:
raise NotImplementedError
else:
return datatype
def _transform_datatype(datatype, op):
try:
dt_traverse = datatype.traverse
except AttributeError:
pass
else:
datatype = dt_traverse(lambda dt: _transform_datatype(dt, op))
return op(datatype)
def _replace_ref(datatype, target):
if datatype is TYPE_REFERENCE:
return target
else:
return datatype
class Enum(namedtuple('Enum', 'datatype choice')):
"""A simple type with a limited set of allowed values."""
@classmethod
def _check_choice(cls, datatype, choice, strict=True):
if callable(choice):
return choice
if isinstance(choice, str):
msg = 'bad choice (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choice))
choice = frozenset(choice)
if not choice:
raise TypeError('missing choice')
if not strict:
return choice
for value in choice:
if type(value) is not datatype:
msg = 'bad choice (expected {!r} values, got {!r})'
raise ValueError(msg.format(datatype, choice))
return choice
def __new__(cls, datatype, choice, **kwargs):
strict = kwargs.pop('strict', True)
normalize = kwargs.pop('_normalize', True)
(lambda: None)(**kwargs) # Make sure there aren't any other kwargs.
if not isinstance(datatype, type):
raise ValueError('expected a class, got {!r}'.format(datatype))
if datatype not in list(SIMPLE_TYPES):
msg = 'only simple datatypes are supported, got {!r}'
raise ValueError(msg.format(datatype))
if normalize:
# There's no need to normalize datatype (it's a simple type).
pass
choice = cls._check_choice(datatype, choice, strict=strict)
self = super(Enum, cls).__new__(cls, datatype, choice)
return self
class Union(tuple):
"""Declare a union of different types.
The declared order is preserved and respected.
Sets and frozensets are treated Unions in declarations.
"""
@classmethod
def _traverse(cls, datatypes, op):
changed = False
result = []
for datatype in datatypes:
transformed = op(datatype)
if transformed is not datatype:
changed = True
result.append(transformed)
return result, changed
def __new__(cls, *datatypes, **kwargs):
normalize = kwargs.pop('_normalize', True)
(lambda: None)(**kwargs) # Make sure there aren't any other kwargs.
datatypes = list(datatypes)
if normalize:
datatypes, _ = cls._traverse(
datatypes,
lambda dt: _transform_datatype(dt, _normalize_datatype),
)
self = super(Union, cls).__new__(cls, datatypes)
self._simple = all(_is_simple(dt) for dt in datatypes)
return self
def __repr__(self):
return '{}{}'.format(type(self).__name__, tuple(self))
def __hash__(self):
return super(Union, self).__hash__()
def __eq__(self, other): # honors order
if not isinstance(other, Union):
return NotImplemented
if super(Union, self).__eq__(other):
return True
if set(self) != set(other):
return False
if self._simple and other._simple:
return True
return NotImplemented
def __ne__(self, other):
return not (self == other)
@property
def datatypes(self):
return set(self)
def traverse(self, op, **kwargs):
"""Return a copy with op applied to each contained datatype."""
datatypes, changed = self._traverse(self, op)
if not changed and not kwargs:
return self
return self.__class__(*datatypes, **kwargs)
class Array(Readonly):
"""Declare an array (of a single type).
Lists and tuples (single-item) are treated equivalently
in declarations.
"""
def __init__(self, itemtype, _normalize=True):
if _normalize:
itemtype = _normalize_datatype(itemtype)
self._bind_attrs(
itemtype=itemtype,
)
def __repr__(self):
return '{}(datatype={!r})'.format(type(self).__name__, self.itemtype)
def __hash__(self):
return hash(self.itemtype)
def __eq__(self, other):
try:
other_itemtype = other.itemtype
except AttributeError:
return False
return self.itemtype == other_itemtype
def __ne__(self, other):
return not (self == other)
def traverse(self, op, **kwargs):
"""Return a copy with op applied to the item datatype."""
datatype = op(self.itemtype)
if datatype is self.itemtype and not kwargs:
return self
return self.__class__(datatype, **kwargs)
class Field(namedtuple('Field', 'name datatype default optional')):
"""Declare a field in a data map param."""
START_OPTIONAL = sentinel('START_OPTIONAL')
def __new__(cls, name, datatype=str, enum=None, default=NOT_SET,
optional=False, _normalize=True, **kwargs):
if enum is not None and not isinstance(enum, Enum):
datatype = Enum(datatype, enum)
enum = None
if _normalize:
datatype = _normalize_datatype(datatype)
self = super(Field, cls).__new__(
cls,
name=str(name) if name else None,
datatype=datatype,
default=default,
optional=bool(optional),
)
self._kwargs = kwargs.items()
return self
@property
def kwargs(self):
return dict(self._kwargs)
def traverse(self, op, **kwargs):
"""Return a copy with op applied to the datatype."""
datatype = op(self.datatype)
if datatype is self.datatype and not kwargs:
return self
kwargs.setdefault('default', self.default)
kwargs.setdefault('optional', self.optional)
return self.__class__(self.name, datatype, **kwargs)
class Fields(Readonly, Sequence):
"""Declare a set of fields."""
@classmethod
def _iter_fixed(cls, fields, _normalize=True):
optional = None
for field in fields or ():
if field is Field.START_OPTIONAL:
if optional is not None:
raise RuntimeError('START_OPTIONAL used more than once')
optional = True
continue
if not isinstance(field, Field):
raise TypeError('got non-field {!r}'.format(field))
if _normalize:
field = _transform_datatype(field, _normalize_datatype)
if optional is not None and field.optional is not optional:
field = field._replace(optional=optional)
yield field
def __init__(self, *fields, **kwargs):
fields = list(self._iter_fixed(fields, **kwargs))
self._bind_attrs(
_fields=fields,
)
def __repr__(self):
return '{}(*{})'.format(type(self).__name__, self._fields)
def __hash__(self):
return hash(tuple(self))
def __eq__(self, other):
try:
other_len = len(other)
other_iter = iter(other)
except TypeError:
return False
if len(self) != other_len:
return False
for i, item in enumerate(other_iter):
if self[i] != item:
return False
return True
def __ne__(self, other):
return not (self == other)
def __len__(self):
return len(self._fields)
def __getitem__(self, index):
return self._fields[index]
def as_dict(self):
return {field.name: field for field in self._fields}
def traverse(self, op, **kwargs):
"""Return a copy with op applied to each field."""
changed = False
updated = []
for field in self._fields:
transformed = op(field)
if transformed is not field:
changed = True
updated.append(transformed)
if not changed and not kwargs:
return self
return self.__class__(*updated, **kwargs)

View file

@ -0,0 +1,32 @@
class ArgumentError(TypeError):
"""The base class for argument-related exceptions."""
class ArgMissingError(ArgumentError):
"""Indicates that the argument for the field is missing."""
def __init__(self, field):
super(ArgMissingError, self).__init__(
'missing arg {!r}'.format(field.name))
self.field = field
class IncompleteArgError(ArgumentError):
"""Indicates that the "complex" arg has missing fields."""
def __init__(self, fields, missing):
msg = 'incomplete arg (missing or incomplete fields: {})'
super(IncompleteArgError, self).__init__(
msg.format(', '.join(sorted(missing))))
self.fields = fields
self.missing = missing
class ArgTypeMismatchError(ArgumentError):
"""Indicates that the arg did not have the expected type."""
def __init__(self, value):
super(ArgTypeMismatchError, self).__init__(
'bad value {!r} (unsupported type)'.format(value))
self.value = value

View file

@ -0,0 +1,221 @@
from debugger_protocol._base import Readonly, WithRepr
class _ParameterBase(WithRepr):
def __init__(self, datatype):
self._datatype = datatype
def _init_args(self):
yield ('datatype', self._datatype)
def __hash__(self):
try:
return hash(self._datatype)
except TypeError:
return hash(id(self))
def __eq__(self, other):
if type(self) is not type(other):
return NotImplemented
return self._datatype == other._datatype
def __ne__(self, other):
return not (self == other)
@property
def datatype(self):
return self._datatype
class Parameter(_ParameterBase):
"""Base class for different parameter types."""
def __init__(self, datatype, handler=None):
super(Parameter, self).__init__(datatype)
self._handler = handler
def _init_args(self):
for item in super(Parameter, self)._init_args():
yield item
if self._handler is not None:
yield ('handler', self._handler)
def bind(self, raw):
"""Return an Arg for the given raw value.
As with match_type(), if the value is not supported by this
parameter return None.
"""
handler = self.match_type(raw)
if handler is None:
return None
return Arg(self, raw, handler)
def match_type(self, raw):
"""Return the datatype handler to use for the given raw value.
If the value does not match then return None.
"""
return self._handler
class DatatypeHandler(_ParameterBase):
"""Base class for datatype handlers."""
def coerce(self, raw):
"""Return the deserialized equivalent of the given raw value."""
# By default this is a noop.
return raw
def validate(self, coerced):
"""Ensure that the already-deserialized value is correct."""
# By default this is a noop.
return
def as_data(self, coerced):
"""Return a serialized equivalent of the given value.
This method round-trips with the "coerce()" method.
"""
# By default this is a noop.
return coerced
class Arg(Readonly, WithRepr):
"""The bridge between a raw value and a deserialized one.
This is primarily the product of Parameter.bind().
"""
# The value of this type lies in encapsulating intermediate state
# and in caching data.
def __init__(self, param, value, handler=None, israw=True):
if not isinstance(param, Parameter):
raise TypeError(
'bad param (expected Parameter, got {!r})'.format(param))
if handler is None:
if israw:
handler = param.match_type(value)
else:
raise TypeError('missing handler')
if not isinstance(handler, DatatypeHandler):
msg = 'bad handler (expected DatatypeHandler, got {!r})'
raise TypeError(msg.format(handler))
key = '_raw' if israw else '_value'
kwargs = {key: value}
self._bind_attrs(
param=param,
_handler=handler,
_validated=False,
**kwargs
)
def _init_args(self):
yield ('param', self.param)
israw = True
try:
yield ('value', self._raw)
except AttributeError:
yield ('value', self._value)
israw = False
if self.datatype != self.param.datatype:
yield ('handler', self._handler)
if not israw:
yield ('israw', False)
def __hash__(self):
try:
return hash(self.datatype)
except TypeError:
return hash(id(self))
def __eq__(self, other):
if type(self) is not type(other):
return False
if self.param != other.param:
return False
return self._as_data() == other._as_data()
def __ne__(self, other):
return not (self == other)
@property
def datatype(self):
return self._handler.datatype
@property
def raw(self):
"""The serialized value."""
return self.as_data()
@property
def value(self):
"""The de-serialized value."""
value = self.coerce()
if not self._validated:
self._validate()
return value
def coerce(self, cached=True):
"""Return the deserialized equivalent of the raw value."""
if not cached:
try:
raw = self._raw
except AttributeError:
# Use the cached value anyway.
return self._value
else:
return self._handler.coerce(raw)
try:
return self._value
except AttributeError:
value = self._handler.coerce(self._raw)
self._bind_attrs(
_value=value,
)
return value
def validate(self, force=False):
"""Ensure that the (deserialized) value is correct.
If the value has a "validate()" method then it gets called.
Otherwise it's up to the handler.
"""
if not self._validated or force:
self.coerce()
self._validate()
def _validate(self):
try:
validate = self._value.validate
except AttributeError:
self._handler.validate(self._value)
else:
validate()
self._bind_attrs(
_validated=True,
)
def as_data(self, cached=True):
"""Return a serialized equivalent of the value."""
self.validate()
if not cached:
return self._handler.as_data(self._value)
return self._as_data()
def _as_data(self):
try:
return self._raw
except AttributeError:
try:
as_data = self._value.as_data
except AttributeError:
as_data = self._handler.as_data
raw = as_data(self._value)
self._bind_attrs(
_raw=raw,
)
return raw

View file

@ -0,0 +1,380 @@
from ._common import ANY, SIMPLE_TYPES
from ._datatype import FieldsNamespace
from ._decl import Enum, Union, Array, Field, Fields
from ._errors import ArgTypeMismatchError
from ._param import Parameter, DatatypeHandler
#def as_parameter(cls):
# """Return a parameter that wraps the given FieldsNamespace subclass."""
# # XXX inject_params
# cls.normalize(_inject_params)
# param = param_from_datatype(cls)
## cls.PARAM = param
# return param
#
#
#def _inject_params(datatype):
# return param_from_datatype(datatype)
def param_from_datatype(datatype, **kwargs):
"""Return a parameter for the given datatype."""
if isinstance(datatype, Parameter):
return datatype
if isinstance(datatype, DatatypeHandler):
return Parameter(datatype.datatype, datatype, **kwargs)
elif isinstance(datatype, Fields):
return ComplexParameter(datatype, **kwargs)
elif isinstance(datatype, Field):
return param_from_datatype(datatype.datatype, **kwargs)
elif datatype is ANY:
return NoopParameter()
elif datatype is None:
return SingletonParameter(None)
elif datatype in list(SIMPLE_TYPES):
return SimpleParameter(datatype, **kwargs)
elif isinstance(datatype, Enum):
return EnumParameter(datatype.datatype, datatype.choice, **kwargs)
elif isinstance(datatype, Union):
return UnionParameter(datatype, **kwargs)
elif isinstance(datatype, (set, frozenset)):
return UnionParameter(Union(*datatype), **kwargs)
elif isinstance(datatype, Array):
return ArrayParameter(datatype, **kwargs)
elif isinstance(datatype, (list, tuple)):
datatype, = datatype
return ArrayParameter(Array(datatype), **kwargs)
elif not isinstance(datatype, type):
raise NotImplementedError
elif issubclass(datatype, FieldsNamespace):
return ComplexParameter(datatype, **kwargs)
else:
raise NotImplementedError
########################
# param types
class NoopParameter(Parameter):
"""A parameter that treats any value as-is."""
def __init__(self):
handler = DatatypeHandler(ANY)
super(NoopParameter, self).__init__(ANY, handler)
NOOP = NoopParameter()
class SingletonParameter(Parameter):
"""A parameter that works only for the given value."""
class HANDLER(DatatypeHandler):
def validate(self, coerced):
if coerced is not self.datatype:
raise ValueError(
'expected {!r}, got {!r}'.format(self.datatype, coerced))
def __init__(self, obj):
handler = self.HANDLER(obj)
super(SingletonParameter, self).__init__(obj, handler)
def match_type(self, raw):
# Note we do not check equality for singletons.
if raw is not self.datatype:
return None
return super(SingletonParameter, self).match_type(raw)
class SimpleHandler(DatatypeHandler):
"""A datatype handler for basic value types."""
def __init__(self, cls):
if not isinstance(cls, type):
raise ValueError('expected a class, got {!r}'.format(cls))
super(SimpleHandler, self).__init__(cls)
def coerce(self, raw):
if type(raw) is self.datatype:
return raw
return self.datatype(raw)
def validate(self, coerced):
if type(coerced) is not self.datatype:
raise ValueError(
'expected {!r}, got {!r}'.format(self.datatype, coerced))
class SimpleParameter(Parameter):
"""A parameter for basic value types."""
HANDLER = SimpleHandler
def __init__(self, cls, strict=True):
handler = self.HANDLER(cls)
super(SimpleParameter, self).__init__(cls, handler)
self._strict = strict
def match_type(self, raw):
if self._strict:
if type(raw) is not self.datatype:
return None
elif not isinstance(raw, self.datatype):
return None
return super(SimpleParameter, self).match_type(raw)
class EnumParameter(Parameter):
"""A parameter for enums of basic value types."""
class HANDLER(SimpleHandler):
def __init__(self, cls, enum):
if not enum:
raise TypeError('missing enum')
super(EnumParameter.HANDLER, self).__init__(cls)
if not callable(enum):
enum = set(enum)
self.enum = enum
def validate(self, coerced):
super(EnumParameter.HANDLER, self).validate(coerced)
if not self._match_enum(coerced):
msg = 'expected one of {!r}, got {!r}'
raise ValueError(msg.format(self.enum, coerced))
def _match_enum(self, coerced):
if callable(self.enum):
if not self.enum(coerced):
return False
elif coerced not in self.enum:
return False
return True
def __init__(self, cls, enum):
handler = self.HANDLER(cls, enum)
super(EnumParameter, self).__init__(cls, handler)
self._match_enum = handler._match_enum
def match_type(self, raw):
if type(raw) is not self.datatype:
return None
if not self._match_enum(raw):
return None
return super(EnumParameter, self).match_type(raw)
class UnionParameter(Parameter):
"""A parameter that supports multiple different types."""
HANDLER = None # no handler
@classmethod
def from_datatypes(cls, *datatypes, **kwargs):
datatype = Union(*datatypes)
return cls(datatype, **kwargs)
def __init__(self, datatype, **kwargs):
if not isinstance(datatype, Union):
raise ValueError('expected Union, got {!r}'.format(datatype))
super(UnionParameter, self).__init__(datatype)
choice = []
for dt in datatype:
param = param_from_datatype(dt)
choice.append(param)
self.choice = choice
def __eq__(self, other):
if type(self) is not type(other):
return False
return set(self.datatype) == set(other.datatype)
def match_type(self, raw):
for param in self.choice:
handler = param.match_type(raw)
if handler is not None:
return handler
return None
class ArrayParameter(Parameter):
"""A parameter that is a list of some fixed type."""
class HANDLER(DatatypeHandler):
def __init__(self, datatype, handlers=None, itemparam=None):
if not isinstance(datatype, Array):
raise ValueError(
'expected an Array, got {!r}'.format(datatype))
super(ArrayParameter.HANDLER, self).__init__(datatype)
self.handlers = handlers
self.itemparam = itemparam
def coerce(self, raw):
if self.handlers is None:
if self.itemparam is None:
itemtype = self.datatype.itemtype
self.itemparam = param_from_datatype(itemtype)
handlers = []
for item in raw:
handler = self.itemparam.match_type(item)
if handler is None:
raise ArgTypeMismatchError(item)
handlers.append(handler)
self.handlers = handlers
result = []
for i, item in enumerate(raw):
handler = self.handlers[i]
item = handler.coerce(item)
result.append(item)
return result
def validate(self, coerced):
if self.handlers is None:
raise TypeError('coerce first')
for i, item in enumerate(coerced):
handler = self.handlers[i]
handler.validate(item)
def as_data(self, coerced):
if self.handlers is None:
raise TypeError('coerce first')
data = []
for i, item in enumerate(coerced):
handler = self.handlers[i]
datum = handler.as_data(item)
data.append(datum)
return data
@classmethod
def from_itemtype(cls, itemtype, **kwargs):
datatype = Array(itemtype)
return cls(datatype, **kwargs)
def __init__(self, datatype):
if not isinstance(datatype, Array):
raise ValueError('expected Array, got {!r}'.format(datatype))
itemparam = param_from_datatype(datatype.itemtype)
handler = self.HANDLER(datatype, None, itemparam)
super(ArrayParameter, self).__init__(datatype, handler)
self.itemparam = itemparam
def match_type(self, raw):
if not isinstance(raw, list):
return None
handlers = []
for item in raw:
handler = self.itemparam.match_type(item)
if handler is None:
return None
handlers.append(handler)
return self.HANDLER(self.datatype, handlers)
class ComplexParameter(Parameter):
class HANDLER(DatatypeHandler):
def __init__(self, datatype, handlers=None):
if (type(datatype) is not type or
not issubclass(datatype, FieldsNamespace)
):
msg = 'expected FieldsNamespace, got {!r}'
raise ValueError(msg.format(datatype))
super(ComplexParameter.HANDLER, self).__init__(datatype)
self.handlers = handlers
def coerce(self, raw):
if self.handlers is None:
fields = self.datatype.FIELDS.as_dict()
handlers = {}
for name, value in raw.items():
param = param_from_datatype(fields[name])
handler = param.match_type(value)
if handler is None:
raise ArgTypeMismatchError((name, value))
handlers[name] = handler
self.handlers = handlers
result = {}
for name, value in raw.items():
handler = self.handlers[name]
value = handler.coerce(value)
result[name] = value
return self.datatype(**result)
def validate(self, coerced):
if self.handlers is None:
raise TypeError('coerce first')
for field in self.datatype.FIELDS:
try:
value = getattr(coerced, field.name)
except AttributeError:
continue
handler = self.handlers[field.name]
handler.validate(value)
def as_data(self, coerced):
if self.handlers is None:
raise TypeError('coerce first')
data = {}
for field in self.datatype.FIELDS:
try:
value = getattr(coerced, field.name)
except AttributeError:
continue
handler = self.handlers[field.name]
datum = handler.as_data(value)
data[field.name] = datum
return data
def __init__(self, datatype):
if isinstance(datatype, Fields):
class ArgNamespace(FieldsNamespace):
FIELDS = datatype
datatype = ArgNamespace
elif (type(datatype) is not type or
not issubclass(datatype, FieldsNamespace)):
msg = 'expected Fields or FieldsNamespace, got {!r}'
raise ValueError(msg.format(datatype))
datatype.normalize()
# We set handler later in match_type().
super(ComplexParameter, self).__init__(datatype)
self.params = {field.name: param_from_datatype(field)
for field in datatype.FIELDS}
def __eq__(self, other):
if super(ComplexParameter, self).__eq__(other):
return True
try:
fields = self._datatype.FIELDS
other_fields = other._datatype.FIELDS
except AttributeError:
return NotImplemented
else:
return fields == other_fields
def match_type(self, raw):
if not isinstance(raw, dict):
return None
handlers = {}
for field in self.datatype.FIELDS:
try:
value = raw[field.name]
except KeyError:
if not field.optional:
return None
value = field.default
param = self.params[field.name]
handler = param.match_type(value)
if handler is None:
return None
handlers[field.name] = handler
return self.HANDLER(self.datatype, handlers)

View file

@ -0,0 +1,22 @@
from debugger_protocol._base import Readonly, WithRepr
class Base(Readonly, WithRepr):
"""Base class for message-related types."""
_INIT_ARGS = None
@classmethod
def from_data(cls, **kwargs):
"""Return an instance based on the given raw data."""
return cls(**kwargs)
def __init__(self):
self._validate()
def _validate(self):
pass
def as_data(self):
"""Return serializable data for the instance."""
return {}

View file

View file

@ -0,0 +1,50 @@
from debugger_protocol.arg import ANY, FieldsNamespace, Field
FIELDS_BASIC = [
Field('name'),
Field.START_OPTIONAL,
Field('value'),
]
BASIC_FULL = {
'name': 'spam',
'value': 'eggs',
}
BASIC_MIN = {
'name': 'spam',
}
class Basic(FieldsNamespace):
FIELDS = FIELDS_BASIC
FIELDS_EXTENDED = [
Field('name', datatype=str, optional=False),
Field('valid', datatype=bool, optional=True),
Field('id', datatype=int, optional=False),
Field('value', datatype=ANY, optional=True),
Field('x', datatype=Basic, optional=True),
Field('y', datatype={int, str}, optional=True),
Field('z', datatype=[Basic], optional=True),
]
EXTENDED_FULL = {
'name': 'spam',
'valid': True,
'id': 10,
'value': None,
'x': BASIC_FULL,
'y': 11,
'z': [
BASIC_FULL,
BASIC_MIN,
],
}
EXTENDED_MIN = {
'name': 'spam',
'id': 10,
}

View file

@ -0,0 +1,255 @@
import itertools
import unittest
from debugger_protocol.arg._common import ANY
from debugger_protocol.arg._datatype import FieldsNamespace
from debugger_protocol.arg._decl import Array, Field, Fields
from ._common import (
BASIC_FULL, BASIC_MIN, Basic,
FIELDS_EXTENDED, EXTENDED_FULL, EXTENDED_MIN)
class FieldsNamespaceTests(unittest.TestCase):
def test_traverse_noop(self):
fields = [
Field('spam'),
Field('ham'),
Field('eggs'),
]
class Spam(FieldsNamespace):
FIELDS = Fields(*fields)
calls = []
op = (lambda dt: calls.append(dt) or dt)
transformed = Spam.traverse(op)
self.assertIs(transformed, Spam)
self.assertIs(transformed.FIELDS, Spam.FIELDS)
for i, field in enumerate(Spam.FIELDS):
self.assertIs(field, fields[i])
self.assertCountEqual(calls, [
# Note that it did not recurse into the fields.
Field('spam'),
Field('ham'),
Field('eggs'),
])
def test_traverse_unnormalized(self):
fields = [
Field('spam'),
Field('ham'),
Field('eggs'),
]
class Spam(FieldsNamespace):
FIELDS = fields
calls = []
op = (lambda dt: calls.append(dt) or dt)
transformed = Spam.traverse(op)
self.assertIs(transformed, Spam)
self.assertIsInstance(transformed.FIELDS, Fields)
for i, field in enumerate(Spam.FIELDS):
self.assertIs(field, fields[i])
self.assertCountEqual(calls, [
Field('spam'),
Field('ham'),
Field('eggs'),
])
def test_traverse_changed(self):
class Spam(FieldsNamespace):
FIELDS = Fields(
Field('spam', ANY),
Field('eggs', None),
)
calls = []
op = (lambda dt: calls.append(dt) or Field(dt.name, str))
transformed = Spam.traverse(op)
self.assertIs(transformed, Spam)
self.assertEqual(transformed.FIELDS, Fields(
Field('spam', str),
Field('eggs', str),
))
self.assertEqual(calls, [
Field('spam', ANY),
Field('eggs', None),
])
def test_normalize_without_ops(self):
fieldlist = [
Field('spam'),
Field('ham'),
Field('eggs'),
]
fields = Fields(*fieldlist)
class Spam(FieldsNamespace):
FIELDS = fields
Spam.normalize()
self.assertIs(Spam.FIELDS, fields)
for i, field in enumerate(Spam.FIELDS):
self.assertIs(field, fieldlist[i])
def test_normalize_unnormalized(self):
fieldlist = [
Field('spam'),
Field('ham'),
Field('eggs'),
]
class Spam(FieldsNamespace):
FIELDS = fieldlist
Spam.normalize()
self.assertIsInstance(Spam.FIELDS, Fields)
for i, field in enumerate(Spam.FIELDS):
self.assertIs(field, fieldlist[i])
def test_normalize_with_ops_noop(self):
fieldlist = [
Field('spam'),
Field('ham', int),
Field('eggs', Array(ANY)),
]
fields = Fields(*fieldlist)
class Spam(FieldsNamespace):
FIELDS = fields
calls = []
op1 = (lambda dt: calls.append((op1, dt)) or dt)
op2 = (lambda dt: calls.append((op2, dt)) or dt)
Spam.normalize(op1, op2)
self.assertIs(Spam.FIELDS, fields)
for i, field in enumerate(Spam.FIELDS):
self.assertIs(field, fieldlist[i])
self.assertEqual(calls, [
(op1, str),
(op1, Field('spam')),
(op1, int),
(op1, Field('ham', int)),
(op1, ANY),
(op1, Array(ANY)),
(op1, Field('eggs', Array(ANY))),
(op1, fields),
(op2, str),
(op2, Field('spam')),
(op2, int),
(op2, Field('ham', int)),
(op2, ANY),
(op2, Array(ANY)),
(op2, Field('eggs', Array(ANY))),
(op2, fields),
])
def test_normalize_with_op_changed(self):
class Spam(FieldsNamespace):
FIELDS = Fields(
Field('spam', Array(ANY)),
)
op = (lambda dt: int if dt is ANY else dt)
Spam.normalize(op)
self.assertEqual(Spam.FIELDS, Fields(
Field('spam', Array(int)),
))
def test_normalize_missing(self):
with self.assertRaises(TypeError):
FieldsNamespace.normalize()
#######
def test_fields_full(self):
class Spam(FieldsNamespace):
FIELDS = FIELDS_EXTENDED
spam = Spam(**EXTENDED_FULL)
ns = vars(spam)
del ns['_validators']
del ns['_serializers']
self.assertEqual(ns, {
'name': 'spam',
'valid': True,
'id': 10,
'value': None,
'x': Basic(**BASIC_FULL),
'y': 11,
'z': [
Basic(**BASIC_FULL),
Basic(**BASIC_MIN),
],
})
def test_fields_min(self):
class Spam(FieldsNamespace):
FIELDS = FIELDS_EXTENDED
spam = Spam(**EXTENDED_MIN)
ns = vars(spam)
del ns['_validators']
del ns['_serializers']
self.assertEqual(ns, {
'name': 'spam',
'id': 10,
})
def test_no_fields(self):
with self.assertRaises(TypeError):
FieldsNamespace(
x='spam',
y=42,
z=None,
)
def test_attrs(self):
ns = Basic(name='<name>', value='<value>')
self.assertEqual(ns.name, '<name>')
self.assertEqual(ns.value, '<value>')
def test_equality(self):
ns1 = Basic(name='<name>', value='<value>')
ns2 = Basic(name='<name>', value='<value>')
self.assertTrue(ns1 == ns1)
self.assertTrue(ns1 == ns2)
def test_inequality(self):
p = [Basic(name=n, value=v)
for n in ['<>', '<name>']
for v in ['<>', '<value>']]
for basic1, basic2 in itertools.combinations(p, 2):
with self.subTest((basic1, basic2)):
self.assertTrue(basic1 != basic2)
@unittest.skip('not ready')
def test_validate(self):
# TODO: finish
raise NotImplementedError
def test_as_data(self):
class Spam(FieldsNamespace):
FIELDS = FIELDS_EXTENDED
spam = Spam(**EXTENDED_FULL)
sdata = spam.as_data()
basic = Basic(**BASIC_FULL)
bdata = basic.as_data()
self.assertEqual(sdata, EXTENDED_FULL)
self.assertEqual(bdata, BASIC_FULL)

View file

@ -0,0 +1,477 @@
import unittest
from debugger_protocol.arg import NOT_SET, ANY
from debugger_protocol.arg._datatype import FieldsNamespace
from debugger_protocol.arg._decl import (
REF, TYPE_REFERENCE, _normalize_datatype, _transform_datatype,
Enum, Union, Array, Field, Fields)
from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg
from debugger_protocol.arg._params import (
SimpleParameter, UnionParameter, ArrayParameter, ComplexParameter)
class ModuleTests(unittest.TestCase):
def test_normalize_datatype(self):
NOOP = object()
param = SimpleParameter(str)
tests = [
# explicitly handled
(REF, TYPE_REFERENCE),
(TYPE_REFERENCE, NOOP),
(ANY, NOOP),
(None, NOOP),
(int, NOOP),
(str, NOOP),
(bool, NOOP),
(Enum(str, ('spam',)), NOOP),
(Union(str, int), NOOP),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
(Array(str), NOOP),
([str], Array(str)),
((str,), Array(str)),
# others
(Field('spam'), NOOP),
(Fields(Field('spam')), NOOP),
(param, NOOP),
(DatatypeHandler(str), NOOP),
(Arg(param, 'spam'), NOOP),
(SimpleParameter(str), NOOP),
(UnionParameter(Union(str)), NOOP),
(ArrayParameter(Array(str)), NOOP),
(ComplexParameter(Fields()), NOOP),
(NOT_SET, NOOP),
(object(), NOOP),
(object, NOOP),
(type, NOOP),
]
for datatype, expected in tests:
if expected is NOOP:
expected = datatype
with self.subTest(datatype):
datatype = _normalize_datatype(datatype)
self.assertEqual(datatype, expected)
with self.assertRaises(NotImplementedError):
_normalize_datatype({1: 2})
def test_transform_datatype_simple(self):
datatypes = [
REF,
TYPE_REFERENCE,
ANY,
None,
int,
str,
bool,
{str, int},
frozenset([str, int]),
[str],
(str,),
Parameter(object()),
DatatypeHandler(str),
Arg(SimpleParameter(str), 'spam'),
SimpleParameter(str),
UnionParameter(Union(str, int)),
ArrayParameter(Array(str)),
ComplexParameter(Fields()),
NOT_SET,
object(),
object,
type,
]
for expected in datatypes:
transformed = []
op = (lambda dt: transformed.append(dt) or dt)
with self.subTest(expected):
datatype = _transform_datatype(expected, op)
self.assertIs(datatype, expected)
self.assertEqual(transformed, [expected])
def test_transform_datatype_container(self):
class Spam(FieldsNamespace):
FIELDS = [
Field('a'),
]
Spam.normalize()
fields = Fields(Field('...'))
field_spam = Field('spam', ANY)
field_ham = Field('ham', Union(
Array(Spam),
))
field_eggs = Field('eggs', Array(TYPE_REFERENCE))
nested = Fields(
Field('???', fields),
field_spam,
field_ham,
field_eggs,
)
tests = {
Array(str): [
str,
Array(str),
],
Field('...'): [
str,
Field('...'),
],
fields: [
str,
Field('...'),
fields,
],
nested: [
str,
Field('...'),
fields,
Field('???', fields),
# ...
ANY,
Field('spam', ANY),
# ...
str,
Field('a'),
Fields(Field('a')),
Spam,
Array(Spam),
Union(Array(Spam)),
field_ham,
# ...
TYPE_REFERENCE,
Array(TYPE_REFERENCE),
field_eggs,
# ...
nested,
],
}
for datatype, expected in tests.items():
calls = []
op = (lambda dt: calls.append(dt) or dt)
with self.subTest(datatype):
transformed = _transform_datatype(datatype, op)
self.assertIs(transformed, datatype)
self.assertEqual(calls, expected)
# Check Union separately due to set iteration order.
calls = []
op = (lambda dt: calls.append(dt) or dt)
datatype = Union(str, int)
transformed = _transform_datatype(datatype, op)
self.assertIs(transformed, datatype)
self.assertEqual(set(calls[:2]), {str, int})
self.assertEqual(calls[2:], [
Union(str, int),
])
class EnumTests(unittest.TestCase):
def test_attrs(self):
enum = Enum(str, ('spam', 'eggs'))
datatype, choices = enum
self.assertIs(datatype, str)
self.assertEqual(choices, frozenset(['spam', 'eggs']))
def test_bad_datatype(self):
with self.assertRaises(ValueError):
Enum('spam', ('spam', 'eggs'))
with self.assertRaises(ValueError):
Enum(dict, ('spam', 'eggs'))
def test_bad_choices(self):
class String(str):
pass
with self.assertRaises(ValueError):
Enum(str, 'spam')
with self.assertRaises(TypeError):
Enum(str, ())
with self.assertRaises(ValueError):
Enum(str, ('spam', 10))
with self.assertRaises(ValueError):
Enum(str, ('spam', String))
class UnionTests(unittest.TestCase):
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(*{str, int})),
(frozenset([str, int]), Union(*frozenset([str, int]))),
([str], Array(str)),
((str,), Array(str)),
(None, None),
]
for datatype, expected in tests:
with self.subTest(datatype):
union = Union(int, datatype, str)
self.assertEqual(union, Union(int, expected, str))
with self.assertRaises(NotImplementedError):
Union({1: 2})
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
union = Union(str, Array(int), int)
transformed = union.traverse(op)
self.assertIs(transformed, union)
self.assertCountEqual(calls, [
str,
# Note that it did not recurse into Array(int).
Array(int),
int,
])
def test_traverse_changed(self):
calls = []
op = (lambda dt: calls.append(dt) or str)
union = Union(ANY)
transformed = union.traverse(op)
self.assertIsNot(transformed, union)
self.assertEqual(transformed, Union(str))
self.assertEqual(calls, [
ANY,
])
class ArrayTests(unittest.TestCase):
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
(None, None),
]
for datatype, expected in tests:
with self.subTest(datatype):
array = Array(datatype)
self.assertEqual(array, Array(expected))
with self.assertRaises(NotImplementedError):
Array({1: 2})
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
array = Array(Union(str, int))
transformed = array.traverse(op)
self.assertIs(transformed, array)
self.assertCountEqual(calls, [
# Note that it did not recurse into Union(str, int).
Union(str, int),
])
def test_traverse_changed(self):
calls = []
op = (lambda dt: calls.append(dt) or str)
array = Array(ANY)
transformed = array.traverse(op)
self.assertIsNot(transformed, array)
self.assertEqual(transformed, Array(str))
self.assertEqual(calls, [
ANY,
])
class FieldTests(unittest.TestCase):
def test_defaults(self):
field = Field('spam')
self.assertEqual(field.name, 'spam')
self.assertIs(field.datatype, str)
self.assertIs(field.default, NOT_SET)
self.assertFalse(field.optional)
def test_enum(self):
field = Field('spam', str, enum=('a', 'b', 'c'))
self.assertEqual(field.datatype, Enum(str, ('a', 'b', 'c')))
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
(None, None),
]
for datatype, expected in tests:
with self.subTest(datatype):
field = Field('spam', datatype)
self.assertEqual(field, Field('spam', expected))
with self.assertRaises(NotImplementedError):
Field('spam', {1: 2})
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
field = Field('spam', Union(str, int))
transformed = field.traverse(op)
self.assertIs(transformed, field)
self.assertCountEqual(calls, [
# Note that it did not recurse into Union(str, int).
Union(str, int),
])
def test_traverse_changed(self):
calls = []
op = (lambda dt: calls.append(dt) or str)
field = Field('spam', ANY)
transformed = field.traverse(op)
self.assertIsNot(transformed, field)
self.assertEqual(transformed, Field('spam', str))
self.assertEqual(calls, [
ANY,
])
class FieldsTests(unittest.TestCase):
def test_single(self):
fields = Fields(
Field('spam'),
)
self.assertEqual(fields, [
Field('spam'),
])
def test_multiple(self):
fields = Fields(
Field('spam'),
Field('ham'),
Field('eggs'),
)
self.assertEqual(fields, [
Field('spam'),
Field('ham'),
Field('eggs'),
])
def test_empty(self):
fields = Fields()
self.assertCountEqual(fields, [])
def test_normalized(self):
tests = [
(REF, TYPE_REFERENCE),
({str, int}, Union(str, int)),
(frozenset([str, int]), Union(str, int)),
([str], Array(str)),
((str,), Array(str)),
(None, None),
]
for datatype, expected in tests:
with self.subTest(datatype):
fields = Fields(
Field('spam', datatype),
)
self.assertEqual(fields, [
Field('spam', expected),
])
with self.assertRaises(NotImplementedError):
Fields(
Field('spam', {1: 2}),
)
def test_with_START_OPTIONAL(self):
fields = Fields(
Field('spam'),
Field('ham', optional=True),
Field('eggs'),
Field.START_OPTIONAL,
Field('a'),
Field('b', optional=False),
)
self.assertEqual(fields, [
Field('spam'),
Field('ham', optional=True),
Field('eggs'),
Field('a', optional=True),
Field('b', optional=True),
])
def test_non_field(self):
with self.assertRaises(TypeError):
Fields(str)
def test_as_dict(self):
fields = Fields(
Field('spam', int),
Field('ham'),
Field('eggs', Array(str)),
)
result = fields.as_dict
self.assertEqual(result, {
'spam': fields[0],
'ham': fields[1],
'eggs': fields[2],
})
def test_traverse_noop(self):
calls = []
op = (lambda dt: calls.append(dt) or dt)
fields = Fields(
Field('spam'),
Field('ham'),
Field('eggs'),
)
transformed = fields.traverse(op)
self.assertIs(transformed, fields)
self.assertCountEqual(calls, [
# Note that it did not recurse into the fields.
Field('spam'),
Field('ham'),
Field('eggs'),
])
def test_traverse_changed(self):
calls = []
op = (lambda dt: calls.append(dt) or Field(dt.name, str))
fields = Fields(
Field('spam', ANY),
Field('eggs', None),
)
transformed = fields.traverse(op)
self.assertIsNot(transformed, fields)
self.assertEqual(transformed, Fields(
Field('spam', str),
Field('eggs', str),
))
self.assertEqual(calls, [
Field('spam', ANY),
Field('eggs', None),
])

View file

@ -0,0 +1,237 @@
from types import SimpleNamespace
import unittest
from debugger_protocol.arg._param import Parameter, DatatypeHandler, Arg
from tests.helpers.stub import Stub
class FakeHandler(DatatypeHandler):
def __init__(self, datatype=str, stub=None):
super().__init__(datatype)
self.stub = stub or Stub()
self.returns = SimpleNamespace(
coerce=None,
as_data=None,
)
def coerce(self, raw):
self.stub.add_call('coerce', raw)
self.stub.maybe_raise()
return self.returns.coerce
def validate(self, coerced):
self.stub.add_call('validate', coerced)
self.stub.maybe_raise()
def as_data(self, coerced):
self.stub.add_call('as_data', coerced)
self.stub.maybe_raise()
return self.returns.as_data
class ParameterTests(unittest.TestCase):
def setUp(self):
super().setUp()
self.stub = Stub()
self.handler = FakeHandler(self.stub)
def test_bind_matched(self):
param = Parameter(str, self.handler)
arg = param.bind('spam')
self.assertEqual(arg, Arg(param, 'spam', self.handler))
self.assertEqual(self.stub.calls, [])
def test_bind_no_match(self):
param = Parameter(str)
arg = param.bind('spam')
self.assertIs(arg, None)
self.assertEqual(self.stub.calls, [])
def test_match_type_no_match(self):
param = Parameter(str)
matched = param.match_type('spam')
self.assertIs(matched, None)
self.assertEqual(self.stub.calls, [])
def test_match_type_matched(self):
param = Parameter(str, self.handler)
matched = param.match_type('spam')
self.assertIs(matched, self.handler)
self.assertEqual(self.stub.calls, [])
class DatatypeHandlerTests(unittest.TestCase):
def test_coerce(self):
handler = DatatypeHandler(str)
coerced = handler.coerce('spam')
self.assertEqual(coerced, 'spam')
def test_validate(self):
handler = DatatypeHandler(str)
handler.validate('spam')
def test_as_data(self):
handler = DatatypeHandler(str)
data = handler.as_data('spam')
self.assertEqual(data, 'spam')
class ArgTests(unittest.TestCase):
def setUp(self):
super().setUp()
self.stub = Stub()
self.handler = FakeHandler(str, self.stub)
self.param = Parameter(str, self.handler)
def test_raw_valid(self):
self.handler.returns.coerce = 'eggs'
arg = Arg(self.param, 'spam', self.handler)
raw = arg.raw
self.assertEqual(raw, 'spam')
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
('validate', ('eggs',), {}),
])
def test_raw_invalid(self):
self.handler.returns.coerce = 'eggs'
self.stub.set_exceptions(
None,
ValueError('oops'),
)
arg = Arg(self.param, 'spam', self.handler)
with self.assertRaises(ValueError):
arg.raw
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
('validate', ('eggs',), {}),
])
def test_raw_generated(self):
self.handler.returns.as_data = 'spam'
arg = Arg(self.param, 'eggs', self.handler, israw=False)
raw = arg.raw
self.assertEqual(raw, 'spam')
self.assertEqual(self.stub.calls, [
('validate', ('eggs',), {}),
('as_data', ('eggs',), {}),
])
def test_value_valid(self):
arg = Arg(self.param, 'eggs', self.handler, israw=False)
value = arg.value
self.assertEqual(value, 'eggs')
self.assertEqual(self.stub.calls, [
('validate', ('eggs',), {}),
])
def test_value_invalid(self):
self.stub.set_exceptions(
ValueError('oops'),
)
arg = Arg(self.param, 'eggs', self.handler, israw=False)
with self.assertRaises(ValueError):
arg.value
self.assertEqual(self.stub.calls, [
('validate', ('eggs',), {}),
])
def test_value_generated(self):
self.handler.returns.coerce = 'eggs'
arg = Arg(self.param, 'spam', self.handler)
value = arg.value
self.assertEqual(value, 'eggs')
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
('validate', ('eggs',), {}),
])
def test_coerce(self):
self.handler.returns.coerce = 'eggs'
arg = Arg(self.param, 'spam', self.handler)
value = arg.coerce()
self.assertEqual(value, 'eggs')
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
])
def test_validate_okay(self):
self.handler.returns.coerce = 'eggs'
arg = Arg(self.param, 'spam', self.handler)
arg.validate()
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
('validate', ('eggs',), {}),
])
def test_validate_invalid(self):
self.stub.set_exceptions(
None,
ValueError('oops'),
)
self.handler.returns.coerce = 'eggs'
arg = Arg(self.param, 'spam', self.handler)
with self.assertRaises(ValueError):
arg.validate()
self.assertEqual(self.stub.calls, [
('coerce', ('spam',), {}),
('validate', ('eggs',), {}),
])
def test_validate_use_coerced(self):
handler = FakeHandler()
other = Arg(Parameter(str, handler), 'spam', handler, israw=False)
arg = Arg(Parameter(str, self.handler), other, self.handler,
israw=False)
arg.validate()
self.assertEqual(self.stub.calls, [])
self.assertEqual(handler.stub.calls, [
('validate', ('spam',), {}),
])
def test_as_data_use_handler(self):
self.handler.returns.as_data = 'spam'
arg = Arg(self.param, 'eggs', self.handler, israw=False)
data = arg.as_data()
self.assertEqual(data, 'spam')
self.assertEqual(self.stub.calls, [
('validate', ('eggs',), {}),
('as_data', ('eggs',), {}),
])
def test_as_data_use_coerced(self):
handler = FakeHandler()
other = Arg(Parameter(str, handler), 'spam', handler, israw=False)
handler.returns.as_data = 'spam'
arg = Arg(Parameter(str, self.handler), other, self.handler,
israw=False)
data = arg.as_data(other)
self.assertEqual(data, 'spam')
self.assertEqual(self.stub.calls, [])
self.assertEqual(handler.stub.calls, [
('validate', ('spam',), {}),
('as_data', ('spam',), {}),
])

View file

@ -0,0 +1,698 @@
import unittest
from debugger_protocol.arg._common import NOT_SET, ANY
from debugger_protocol.arg._decl import Enum, Union, Array, Field, Fields
from debugger_protocol.arg._param import Parameter, DatatypeHandler
from debugger_protocol.arg._params import (
param_from_datatype,
NoopParameter, SingletonParameter,
SimpleParameter, EnumParameter,
UnionParameter, ArrayParameter, ComplexParameter)
from ._common import FIELDS_BASIC, BASIC_FULL, Basic
class String(str):
pass
class Integer(int):
pass
class ParamFromDatatypeTest(unittest.TestCase):
def test_supported(self):
handler = DatatypeHandler(str)
tests = [
(Parameter(str), Parameter(str)),
(handler, Parameter(str, handler)),
(Fields(Field('spam')), ComplexParameter(Fields(Field('spam')))),
(Field('spam'), SimpleParameter(str)),
(Field('spam', str, enum={'a'}), EnumParameter(str, {'a'})),
(ANY, NoopParameter()),
(None, SingletonParameter(None)),
(str, SimpleParameter(str)),
(int, SimpleParameter(int)),
(bool, SimpleParameter(bool)),
(Enum(str, {'a'}), EnumParameter(str, {'a'})),
(Union(str, int), UnionParameter(Union(str, int))),
({str, int}, UnionParameter(Union(str, int))),
(frozenset([str, int]), UnionParameter(Union(str, int))),
(Array(str), ArrayParameter(Array(str))),
([str], ArrayParameter(Array(str))),
((str,), ArrayParameter(Array(str))),
(Basic, ComplexParameter(Basic)),
]
for datatype, expected in tests:
with self.subTest(datatype):
param = param_from_datatype(datatype)
self.assertEqual(param, expected)
def test_not_supported(self):
datatypes = [
String('spam'),
...,
]
for datatype in datatypes:
with self.subTest(datatype):
with self.assertRaises(NotImplementedError):
param_from_datatype(datatype)
class NoopParameterTests(unittest.TestCase):
VALUES = [
object(),
'spam',
10,
['spam'],
{'spam': 42},
True,
None,
NOT_SET,
]
def test_match_type(self):
values = [
object(),
'',
'spam',
b'spam',
0,
10,
10.0,
10+0j,
('spam',),
(),
['spam'],
[],
{'spam': 42},
{},
{'spam'},
set(),
object,
type,
NoopParameterTests,
True,
None,
...,
NotImplemented,
NOT_SET,
ANY,
Union(str, int),
Union(),
Array(str),
Field('spam'),
Fields(Field('spam')),
Fields(),
Basic,
]
for value in values:
with self.subTest(value):
param = NoopParameter()
handler = param.match_type(value)
self.assertIs(type(handler), DatatypeHandler)
self.assertIs(handler.datatype, ANY)
def test_coerce(self):
for value in self.VALUES:
with self.subTest(value):
param = NoopParameter()
handler = param.match_type(value)
coerced = handler.coerce(value)
self.assertIs(coerced, value)
def test_validate(self):
for value in self.VALUES:
with self.subTest(value):
param = NoopParameter()
handler = param.match_type(value)
handler.validate(value)
def test_as_data(self):
for value in self.VALUES:
with self.subTest(value):
param = NoopParameter()
handler = param.match_type(value)
data = handler.as_data(value)
self.assertIs(data, value)
class SingletonParameterTests(unittest.TestCase):
def test_match_type_matched(self):
param = SingletonParameter(None)
handler = param.match_type(None)
self.assertIs(handler.datatype, None)
def test_match_type_no_match(self):
tests = [
# same type, different value
('spam', 'eggs'),
(10, 11),
(True, False),
# different type but equivalent
('spam', b'spam'),
(10, 10.0),
(10, 10+0j),
(10, '10'),
(10, b'\10'),
]
for singleton, value in tests:
with self.subTest((singleton, value)):
param = SingletonParameter(singleton)
handler = param.match_type(value)
self.assertIs(handler, None)
def test_coerce(self):
param = SingletonParameter(None)
handler = param.match_type(None)
value = handler.coerce(None)
self.assertIs(value, None)
def test_validate_valid(self):
param = SingletonParameter(None)
handler = param.match_type(None)
handler.validate(None)
def test_validate_wrong_type(self):
tests = [
(None, True),
(True, None),
('spam', 10),
(10, 'spam'),
]
for singleton, value in tests:
with self.subTest(singleton):
param = SingletonParameter(singleton)
handler = param.match_type(singleton)
with self.assertRaises(ValueError):
handler.validate(value)
def test_validate_same_type_wrong_value(self):
tests = [
('spam', 'eggs'),
(True, False),
(10, 11),
]
for singleton, value in tests:
with self.subTest(singleton):
param = SingletonParameter(singleton)
handler = param.match_type(singleton)
with self.assertRaises(ValueError):
handler.validate(value)
def test_as_data(self):
param = SingletonParameter(None)
handler = param.match_type(None)
data = handler.as_data(None)
self.assertIs(data, None)
class SimpleParameterTests(unittest.TestCase):
def test_match_type_match(self):
tests = [
(str, 'spam'),
(str, String('spam')),
(int, 10),
(bool, True),
]
for datatype, value in tests:
with self.subTest((datatype, value)):
param = SimpleParameter(datatype, strict=False)
handler = param.match_type(value)
self.assertIs(handler.datatype, datatype)
def test_match_type_no_match(self):
tests = [
(int, 'spam'),
# coercible
(str, 10),
(int, 10.0),
(int, '10'),
(bool, 1),
# semi-coercible
(str, b'spam'),
(int, 10+0j),
(int, b'\10'),
]
for datatype, value in tests:
with self.subTest((datatype, value)):
param = SimpleParameter(datatype, strict=False)
handler = param.match_type(value)
self.assertIs(handler, None)
def test_match_type_strict_match(self):
tests = {
str: 'spam',
int: 10,
bool: True,
}
for datatype, value in tests.items():
with self.subTest(datatype):
param = SimpleParameter(datatype, strict=True)
handler = param.match_type(value)
self.assertIs(handler.datatype, datatype)
def test_match_type_strict_no_match(self):
tests = {
str: String('spam'),
int: Integer(10),
}
for datatype, value in tests.items():
with self.subTest(datatype):
param = SimpleParameter(datatype, strict=True)
handler = param.match_type(value)
self.assertIs(handler, None)
def test_coerce(self):
tests = [
(str, 'spam', 'spam'),
(str, String('spam'), 'spam'),
(int, 10, 10),
(bool, True, True),
# did not match, but still coercible
(str, 10, '10'),
(str, str, "<class 'str'>"),
(int, 10.0, 10),
(int, '10', 10),
(bool, 1, True),
]
for datatype, value, expected in tests:
with self.subTest((datatype, value)):
handler = SimpleParameter.HANDLER(datatype)
coerced = handler.coerce(value)
self.assertEqual(coerced, expected)
def test_validate_valid(self):
tests = {
str: 'spam',
int: 10,
bool: True,
}
for datatype, value in tests.items():
with self.subTest(datatype):
handler = SimpleParameter.HANDLER(datatype)
handler.validate(value)
def test_validate_invalid(self):
tests = [
(int, 'spam'),
# coercible
(str, String('spam')),
(str, 10),
(int, 10.0),
(int, '10'),
(bool, 1),
# semi-coercible
(str, b'spam'),
(int, 10+0j),
(int, b'\10'),
]
for datatype, value in tests:
with self.subTest((datatype, value)):
handler = SimpleParameter.HANDLER(datatype)
with self.assertRaises(ValueError):
handler.validate(value)
def test_as_data(self):
tests = [
(str, 'spam'),
(int, 10),
(bool, True),
# did not match, but still coercible
(str, String('spam')),
(str, 10),
(str, str),
(int, 10.0),
(int, '10'),
(bool, 1),
# semi-coercible
(str, b'spam'),
(int, 10+0j),
(int, b'\10'),
]
for datatype, value in tests:
with self.subTest((datatype, value)):
handler = SimpleParameter.HANDLER(datatype)
data = handler.as_data(value)
self.assertIs(data, value)
class EnumParameterTests(unittest.TestCase):
def test_match_type_match(self):
tests = [
(str, ('spam', 'eggs'), 'spam'),
(str, ('spam',), 'spam'),
(int, (1, 2, 3), 2),
(bool, (True,), True),
]
for datatype, enum, value in tests:
with self.subTest((datatype, enum)):
param = EnumParameter(datatype, enum)
handler = param.match_type(value)
self.assertIs(handler.datatype, datatype)
def test_match_type_no_match(self):
tests = [
# enum mismatch
(str, ('spam', 'eggs'), 'ham'),
(int, (1, 2, 3), 10),
# type mismatch
(int, (1, 2, 3), 'spam'),
# coercible
(str, ('spam', 'eggs'), String('spam')),
(str, ('1', '2', '3'), 2),
(int, (1, 2, 3), 2.0),
(int, (1, 2, 3), '2'),
(bool, (True,), 1),
# semi-coercible
(str, ('spam', 'eggs'), b'spam'),
(int, (1, 2, 3), 2+0j),
(int, (1, 2, 3), b'\02'),
]
for datatype, enum, value in tests:
with self.subTest((datatype, enum, value)):
param = EnumParameter(datatype, enum)
handler = param.match_type(value)
self.assertIs(handler, None)
def test_coerce(self):
tests = [
(str, 'spam', 'spam'),
(int, 10, 10),
(bool, True, True),
# did not match, but still coercible
(str, String('spam'), 'spam'),
(str, 10, '10'),
(str, str, "<class 'str'>"),
(int, 10.0, 10),
(int, '10', 10),
(bool, 1, True),
]
for datatype, value, expected in tests:
with self.subTest((datatype, value)):
enum = (expected,)
handler = EnumParameter.HANDLER(datatype, enum)
coerced = handler.coerce(value)
self.assertEqual(coerced, expected)
def test_coerce_enum_mismatch(self):
enum = ('spam', 'eggs')
handler = EnumParameter.HANDLER(str, enum)
coerced = handler.coerce('ham')
# It still works.
self.assertEqual(coerced, 'ham')
def test_validate_valid(self):
tests = [
(str, ('spam', 'eggs'), 'spam'),
(str, ('spam',), 'spam'),
(int, (1, 2, 3), 2),
(bool, (True, False), True),
]
for datatype, enum, value in tests:
with self.subTest((datatype, enum)):
handler = EnumParameter.HANDLER(datatype, enum)
handler.validate(value)
def test_validate_invalid(self):
tests = [
# enum mismatch
(str, ('spam', 'eggs'), 'ham'),
(int, (1, 2, 3), 10),
# type mismatch
(int, (1, 2, 3), 'spam'),
# coercible
(str, ('spam', 'eggs'), String('spam')),
(str, ('1', '2', '3'), 2),
(int, (1, 2, 3), 2.0),
(int, (1, 2, 3), '2'),
(bool, (True,), 1),
# semi-coercible
(str, ('spam', 'eggs'), b'spam'),
(int, (1, 2, 3), 2+0j),
(int, (1, 2, 3), b'\02'),
]
for datatype, enum, value in tests:
with self.subTest((datatype, enum, value)):
handler = EnumParameter.HANDLER(datatype, enum)
with self.assertRaises(ValueError):
handler.validate(value)
def test_as_data(self):
tests = [
(str, ('spam', 'eggs'), 'spam'),
(str, ('spam',), 'spam'),
(int, (1, 2, 3), 2),
(bool, (True,), True),
# enum mismatch
(str, ('spam', 'eggs'), 'ham'),
(int, (1, 2, 3), 10),
# type mismatch
(int, (1, 2, 3), 'spam'),
# coercible
(str, ('spam', 'eggs'), String('spam')),
(str, ('1', '2', '3'), 2),
(int, (1, 2, 3), 2.0),
(int, (1, 2, 3), '2'),
(bool, (True,), 1),
# semi-coercible
(str, ('spam', 'eggs'), b'spam'),
(int, (1, 2, 3), 2+0j),
(int, (1, 2, 3), b'\02'),
]
for datatype, enum, value in tests:
with self.subTest((datatype, enum, value)):
handler = EnumParameter.HANDLER(datatype, enum)
data = handler.as_data(value)
self.assertIs(data, value)
class UnionParameterTests(unittest.TestCase):
def test_match_type_all_simple(self):
tests = [
'spam',
10,
True,
]
datatype = Union(str, int, bool)
param = UnionParameter(datatype)
for value in tests:
with self.subTest(value):
handler = param.match_type(value)
self.assertIs(type(handler), SimpleParameter.HANDLER)
self.assertIs(handler.datatype, type(value))
def test_match_type_mixed(self):
datatype = Union(
str,
# XXX add dedicated enums
Enum(int, (1, 2, 3)),
Basic,
Array(str),
Array(int),
Union(int, bool),
)
param = UnionParameter(datatype)
tests = [
('spam', SimpleParameter.HANDLER(str)),
(2, EnumParameter.HANDLER(int, (1, 2, 3))),
(BASIC_FULL, ComplexParameter(Basic).match_type(BASIC_FULL)),
(['spam'], ArrayParameter.HANDLER(Array(str))),
([], ArrayParameter.HANDLER(Array(str))),
([10], ArrayParameter.HANDLER(Array(int))),
(10, SimpleParameter.HANDLER(int)),
(True, SimpleParameter.HANDLER(bool)),
# no match
(Integer(2), None),
([True], None),
({}, None),
]
for value, expected in tests:
with self.subTest(value):
handler = param.match_type(value)
self.assertEqual(handler, expected)
def test_match_type_catchall(self):
NOOP = DatatypeHandler(ANY)
param = UnionParameter(Union(int, str, ANY))
tests = [
('spam', SimpleParameter.HANDLER(str)),
(10, SimpleParameter.HANDLER(int)),
# catchall
(BASIC_FULL, NOOP),
(['spam'], NOOP),
(True, NOOP),
(Integer(2), NOOP),
([10], NOOP),
({}, NOOP),
]
for value, expected in tests:
with self.subTest(value):
handler = param.match_type(value)
self.assertEqual(handler, expected)
def test_match_type_no_match(self):
param = UnionParameter(Union(int, str))
values = [
BASIC_FULL,
['spam'],
True,
Integer(2),
[10],
{},
]
for value in values:
with self.subTest(value):
handler = param.match_type(value)
self.assertIs(handler, None)
class ArrayParameterTests(unittest.TestCase):
def test_match_type_match(self):
param = ArrayParameter(Array(str))
expected = ArrayParameter.HANDLER(Array(str))
values = [
['a', 'b', 'c'],
[],
]
for value in values:
with self.subTest(value):
handler = param.match_type(value)
self.assertEqual(handler, expected)
def test_match_type_no_match(self):
param = ArrayParameter(Array(str))
values = [
['a', 1, 'c'],
('a', 'b', 'c'),
'spam',
]
for value in values:
with self.subTest(value):
handler = param.match_type(value)
self.assertIs(handler, None)
def test_coerce_simple(self):
param = ArrayParameter(Array(str))
values = [
['a', 'b', 'c'],
[],
]
for value in values:
with self.subTest(value):
handler = param.match_type(value)
coerced = handler.coerce(value)
self.assertEqual(coerced, value)
def test_coerce_complicated(self):
param = ArrayParameter(Array(Union(str, Basic)))
value = [
'a',
BASIC_FULL,
'c',
]
handler = param.match_type(value)
coerced = handler.coerce(value)
self.assertEqual(coerced, [
'a',
Basic(name='spam', value='eggs'),
'c',
])
def test_validate(self):
param = ArrayParameter(Array(str))
handler = param.match_type(['a', 'b', 'c'])
handler.validate(['a', 'b', 'c'])
def test_as_data_simple(self):
param = ArrayParameter(Array(str))
handler = param.match_type(['a', 'b', 'c'])
data = handler.as_data(['a', 'b', 'c'])
self.assertEqual(data, ['a', 'b', 'c'])
def test_as_data_complicated(self):
param = ArrayParameter(Array(Union(str, Basic)))
value = [
'a',
BASIC_FULL,
'c',
]
handler = param.match_type(value)
data = handler.as_data([
'a',
Basic(name='spam', value='eggs'),
'c',
])
self.assertEqual(data, value)
class ComplexParameterTests(unittest.TestCase):
def test_match_type(self):
fields = Fields(*FIELDS_BASIC)
param = ComplexParameter(fields)
handler = param.match_type(BASIC_FULL)
self.assertIs(type(handler), ComplexParameter.HANDLER)
self.assertEqual(handler.datatype.FIELDS, fields)
def test_coerce(self):
handler = ComplexParameter.HANDLER(Basic)
coerced = handler.coerce(BASIC_FULL)
self.assertEqual(coerced, Basic(**BASIC_FULL))
def test_validate(self):
handler = ComplexParameter.HANDLER(Basic)
handler.coerce(BASIC_FULL)
coerced = Basic(**BASIC_FULL)
handler.validate(coerced)
def test_as_data(self):
handler = ComplexParameter.HANDLER(Basic)
handler.coerce(BASIC_FULL)
coerced = Basic(**BASIC_FULL)
data = handler.as_data(coerced)
self.assertEqual(data, BASIC_FULL)

25
tests/helpers/stub.py Normal file
View file

@ -0,0 +1,25 @@
class Stub(object):
"""A testing double that tracks calls."""
def __init__(self):
self.calls = []
self._exceptions = []
def set_exceptions(self, *exceptions):
self._exceptions = list(exceptions)
def add_call(self, name, *args, **kwargs):
self.add_call_exact(name, args, kwargs)
def add_call_exact(self, name, args, kwargs):
self.calls.append((name, args, kwargs))
def maybe_raise(self):
if not self._exceptions:
return
exc = self._exceptions.pop(0)
if exc is None:
return
raise exc