Change patterns to use "==" rather than "in".

This commit is contained in:
Pavel Minaev 2018-10-05 00:30:36 -07:00 committed by Karthik Nadig
parent 7d14600c24
commit b2ceebac3b
2 changed files with 70 additions and 67 deletions

View file

@ -14,7 +14,7 @@ class BasePattern(object):
def __repr__(self):
raise NotImplementedError()
def __contains__(self, data):
def __eq__(self, value):
raise NotImplementedError()
def such_that(self, condition):
@ -25,14 +25,14 @@ class Pattern(BasePattern):
"""Represents a pattern of a data structure, that can be matched against the
actual data by using operator "in".
For lists and tuples, (data in pattern) is true if both are sequences of the
same length, and for all valid I, (data[I] in Pattern(pattern[I])).
For lists and tuples, (data == Pattern(pattern)) is true if both are sequences of the
same length, and for all valid I, (data[I] == Pattern(pattern[I])).
For dicts, (data in pattern) is true if, for all K in data.keys() + pattern.keys(),
data.has_key(K) and (data[K] in Pattern(pattern[K])). ANY.dict_with() can be used
For dicts, (data == Pattern(pattern)) is true if, for all K in data.keys() + pattern.keys(),
data.has_key(K) and (data[K] == Pattern(pattern[K])). ANY.dict_with() can be used
to perform partial matches.
For any other type, (data in pattern) is true if pattern is ANY or data == pattern.
For any other type, (data == Pattern(pattern)) is true if pattern is ANY or data == pattern.
If the match has failed, but data has a member called __data__, then it is invoked
without arguments, and the same match is performed against the returned value.
@ -49,14 +49,11 @@ class Pattern(BasePattern):
return repr(self.pattern)
def _matches(self, data):
import pytests.helpers.timeline as timeline
pattern = self.pattern
if isinstance(pattern, BasePattern):
return data in pattern
elif isinstance(data, tuple) and isinstance(pattern, tuple):
return len(data) == len(pattern) and all(d in Pattern(p) for (p, d) in zip(pattern, data))
if isinstance(data, tuple) and isinstance(pattern, tuple):
return len(data) == len(pattern) and all(d == Pattern(p) for (p, d) in zip(pattern, data))
elif isinstance(data, list) and isinstance(pattern, list):
return tuple(data) in Pattern(tuple(pattern))
return tuple(data) == Pattern(tuple(pattern))
elif isinstance(data, dict) and isinstance(pattern, dict):
keys = set(tuple(data.keys()) + tuple(pattern.keys()))
def pairs_match(key):
@ -65,14 +62,12 @@ class Pattern(BasePattern):
p = pattern[key]
except KeyError:
return False
return d in Pattern(p)
return d == Pattern(p)
return all(pairs_match(key) for key in keys)
elif isinstance(data, timeline.Occurrence) and isinstance(pattern, timeline.Expectation):
return pattern.has_occurred_by(data)
else:
return data == pattern
def __contains__(self, value):
def __eq__(self, value):
if self._matches(value):
return True
try:
@ -80,7 +75,10 @@ class Pattern(BasePattern):
except AttributeError:
return False
else:
return value.__data__() in self
return value.__data__() == self
def __ne__(self, value):
return not self == value
class Any(BasePattern):
@ -94,33 +92,24 @@ class Any(BasePattern):
def __repr__(self):
return 'ANY'
def __contains__(self, other):
def __eq__(self, other):
return True
@staticmethod
def dict_with(items):
return AnyDictWith(items)
"""A pattern that matches any dict that contains the specified key-value pairs.
d1 = {'a': 1, 'b': 2, 'c': 3}
d2 = {'a': 1, 'b': 2}
class AnyDictWith(BasePattern):
"""A pattern that matches any dict that contains the specified key-value pairs.
d1 = {'a': 1, 'b': 2, 'c': 3}
d2 = {'a': 1, 'b': 2}
d1 in Pattern(d2) # False (need exact match)
d1 in ANY.dict_with(d2) # True (subset matches)
"""
def __init__(self, items=None):
self.items = dict(items)
self.pattern = Pattern(defaultdict(lambda: ANY, items))
def __repr__(self):
return repr(self.items)[:-1] + ', ...}'
def __contains__(self, value):
return value in self.pattern
d1 == Pattern(d2) # False (need exact match)
d1 == ANY.dict_with(d2) # True (subset matches)
"""
class AnyDictWith(defaultdict):
def __repr__(self):
return repr(items)[:-1] + ', ...}'
items = AnyDictWith(lambda: ANY, items)
return items
class Maybe(BasePattern):
@ -134,8 +123,8 @@ class Maybe(BasePattern):
def __repr__(self):
return 'Maybe(%r)' % self.pattern
def __contains__(self, value):
return self.condition(value) and value in self.pattern
def __eq__(self, value):
return self.condition(value) and value == self.pattern
class Success(BasePattern):
@ -148,10 +137,24 @@ class Success(BasePattern):
def __repr__(self):
return 'SUCCESS' if self.success else 'FAILURE'
def __contains__(self, response_body):
def __eq__(self, response_body):
return self.success != isinstance(response_body, Exception)
class Is(BasePattern):
"""A pattern that matches a specific object only (i.e. uses operator 'is' rather than '==').
"""
def __init__(self, obj):
self.obj = obj
def __repr__(self):
return 'Is(%r)' % self.obj
def __eq__(self, value):
return self.obj is value
SUCCESS = Success(True)
FAILURE = Success(False)

View file

@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root
# Licensed under the MIT License. See LICENSE == the project root
# for license information.
from __future__ import print_function, with_statement, absolute_import
@ -25,33 +25,33 @@ VALUES = [
@pytest.mark.parametrize('x', VALUES)
def test_eq(x):
assert x in Pattern(x)
assert x == Pattern(x)
@pytest.mark.parametrize('xy', zip(VALUES[1:], VALUES[:-1]))
def test_ne(xy):
x, y = xy
if x != y:
assert x not in Pattern(y)
assert x != Pattern(y)
@pytest.mark.parametrize('x', VALUES)
def test_any(x):
assert x in ANY
assert x == ANY
def test_lists():
assert [1, 2, 3] not in Pattern([1, 2, 3, 4])
assert [1, 2, 3, 4] not in Pattern([1, 2, 3])
assert [2, 3, 1] not in Pattern([1, 2, 3])
assert [1, 2, 3] in Pattern([1, ANY, 3])
assert [1, 2, 3, 4] not in Pattern([1, ANY, 4])
assert [1, 2, 3] != Pattern([1, 2, 3, 4])
assert [1, 2, 3, 4] != Pattern([1, 2, 3])
assert [2, 3, 1] != Pattern([1, 2, 3])
assert [1, 2, 3] == Pattern([1, ANY, 3])
assert [1, 2, 3, 4] != Pattern([1, ANY, 4])
def test_dicts():
assert {'a': 1, 'b': 2} not in Pattern({'a': 1, 'b': 2, 'c': 3})
assert {'a': 1, 'b': 2, 'c': 3} not in Pattern({'a': 1, 'b': 2})
assert {'a': 1, 'b': 2} in Pattern({'a': ANY, 'b': 2})
assert {'a': 1, 'b': 2} in Pattern(ANY.dict_with({'a': 1}))
assert {'a': 1, 'b': 2} != Pattern({'a': 1, 'b': 2, 'c': 3})
assert {'a': 1, 'b': 2, 'c': 3} != Pattern({'a': 1, 'b': 2})
assert {'a': 1, 'b': 2} == Pattern({'a': ANY, 'b': 2})
assert {'a': 1, 'b': 2} == Pattern(ANY.dict_with({'a': 1}))
def test_maybe():
@ -59,25 +59,25 @@ def test_maybe():
return x != 0
pattern = Pattern(1).such_that(nonzero)
assert 0 not in pattern
assert 1 in pattern
assert 2 not in pattern
assert 0 != pattern
assert 1 == pattern
assert 2 != pattern
pattern = ANY.such_that(nonzero)
assert 0 not in pattern
assert 1 in pattern
assert 2 in pattern
assert 0 != pattern
assert 1 == pattern
assert 2 == pattern
def test_success():
assert {} in SUCCESS
assert {} not in FAILURE
assert {} == SUCCESS
assert {} != FAILURE
def test_failure():
error = RequestFailure('error!')
assert error not in SUCCESS
assert error in FAILURE
assert error != SUCCESS
assert error == FAILURE
class DataObject(object):
@ -90,8 +90,8 @@ class DataObject(object):
def test_data():
something = DataObject(('Something', {'a': 1}))
assert something in Pattern(something.data)
assert something not in Pattern(('Another', {'b': 2}))
assert something == Pattern(something.data)
assert something != Pattern(('Another', {'b': 2}))
def test_recursive():
@ -107,7 +107,7 @@ def test_recursive():
'bd': True,
'be': [],
}
] in Pattern([
] == Pattern([
ANY,
True,
('Something', {'a': 1}),