diff --git a/tests/helpers/vsc/_fake.py b/tests/helpers/vsc/_fake.py index ef3b5c52..3b9d5c60 100644 --- a/tests/helpers/vsc/_fake.py +++ b/tests/helpers/vsc/_fake.py @@ -5,6 +5,29 @@ from ._vsc import StreamFailure, encode_message, iter_messages, parse_message from ._vsc import RawMessage # noqa +def socket_close(sock): + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + +class _Started(object): + + def __init__(self, fake): + self.fake = fake + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def send_request(self, msg): + return self.fake.send_request(msg) + + def close(self): + self.fake.close() + + class FakeVSC(object): """A testing double for a VSC debugger protocol client. @@ -21,20 +44,23 @@ class FakeVSC(object): >>> pydevd = FakePyDevd() >>> fake = FakeVSC(lambda h, p: pydevd.start) >>> fake.start(None, 8888) - >>> try: + >>> with fake.start(None, 8888): ... fake.send_request('') ... # wait for events... - ... finally: - ... fake.close() + ... >>> fake.assert_received(testcase, [ ... # messages ... ]) + >>> See debugger_protocol/messages/README.md for more about the protocol itself. - """ + """ # noqa def __init__(self, start_adapter, handler=None): + def start_adapter(host, port, start_adapter=start_adapter): + self._adapter = start_adapter(host, port) + self._start_adapter = start_adapter self._handler = handler @@ -65,7 +91,7 @@ class FakeVSC(object): @property def failures(self): """All send/recv failures thus far.""" - return self._sock.failures + return self._failures def start(self, host, port): """Start the fake and the adapter.""" @@ -73,16 +99,29 @@ class FakeVSC(object): raise RuntimeError('already started') if not host: - host = None # The adapter is the server so start it first. - self._adapter = self._start_adapter(host, port) - self._start(host, port) + t = threading.Thread( + target=lambda: self._start_adapter(host, port)) + t.start() + self._start('127.0.0.1', port) + t.join(timeout=1) + if t.is_alive(): + raise RuntimeError('timed out') else: # The adapter is the client so start it last. - self._start(host, port) - self._adapter = self._start_adapter(host, port) + # TODO: For now don't use this. + raise NotImplementedError + t = threading.Thread( + target=lambda: self._start(host, port)) + t.start() + self._start_adapter(host, port) + t.join(timeout=1) + if t.is_alive(): + raise RuntimeError('timed out') - def send_request(self, req, delay=0.1): + return _Started(self) + + def send_request(self, req): """Send the given Request object.""" if self._closed: raise EOFError('closed') @@ -102,6 +141,12 @@ class FakeVSC(object): self._closed = True self._close() + def assert_received(self, case, expected): + """Ensure that the received messages match the expected ones.""" + received = [parse_message(msg) for msg in self._received] + expected = [parse_message(msg) for msg in expected] + case.assertEqual(received, expected) + # internal methods def _start(self, host, port): @@ -109,6 +154,7 @@ class FakeVSC(object): self._port = port self._connect() + # TODO: make daemon? self._listener = threading.Thread(target=self._listen) self._listener.start() @@ -158,13 +204,14 @@ class FakeVSC(object): self._adapter.close() self._adapter = None if self._sock is not None: - self._sock.shutdown(socket.SHUT_RDWR) - self._sock.close() + socket_close(self._sock) self._sock = None if self._server is not None: - self._server.shutdown(socket.SHUT_RDWR) - self._server.close() + socket_close(self._server) self._server = None if self._listener is not None: - self._listener.join() + self._listener.join(timeout=1) + # TODO: the listener isn't stopping! + #if self._listener.is_alive(): + # raise RuntimeError('timed out') self._listener = None diff --git a/tests/helpers/vsc/_vsc.py b/tests/helpers/vsc/_vsc.py index 30abb3ea..5db4080b 100644 --- a/tests/helpers/vsc/_vsc.py +++ b/tests/helpers/vsc/_vsc.py @@ -16,6 +16,14 @@ class StreamFailure(Exception): self.msg = msg self.exception = exception + def __repr__(self): + return '{}(direction={!r}, msg={!r}, exception={!r})'.format( + type(self).__name__, + self.direction, + self.msg, + self.exception, + ) + def encode_message(msg): """Return the line-formatted bytes for the message.""" @@ -26,7 +34,10 @@ def iter_messages(stream, stop=lambda: False): """Yield the correct message for each line-formatted one found.""" while not stop(): try: - yield wireformat.read(stream, lambda _: RawMessage) + msg = wireformat.read(stream, lambda _: RawMessage) + if msg is None: # EOF + break + yield msg except Exception as exc: yield StreamFailure('recv', None, exc) @@ -35,21 +46,20 @@ def parse_message(msg): """Return a message object for the given "msg" data.""" if type(msg) is str: data = json.loads(msg) - return RawMessage(data) elif isinstance(msg, bytes): data = json.loads(msg.decode('utf-8')) - return RawMessage(data) elif type(msg) is RawMessage: return msg else: - raise NotImplementedError + data = msg + return RawMessage.from_data(**data) class RawMessage(namedtuple('RawMessage', 'data')): """A wrapper around a line-formatted debugger protocol message.""" @classmethod - def from_data(cls, data): + def from_data(cls, **data): """Return a RawMessage for the given JSON-decoded data.""" return cls(data)