diff --git a/tests/helpers/protocol.py b/tests/helpers/protocol.py new file mode 100644 index 00000000..50ee8f5a --- /dev/null +++ b/tests/helpers/protocol.py @@ -0,0 +1,169 @@ +from collections import namedtuple +import threading + +from . import socket + + +class StreamFailure(Exception): + """Something went wrong while handling messages to/from a stream.""" + + def __init__(self, direction, msg, exception): + err = 'error while processing stream: {!r}'.format(exception) + super(StreamFailure, self).__init__(self, err) + self.direction = direction + self.msg = msg + self.exception = exception + + def __repr__(self): + return '{}(direction={!r}, msg={!r}, exception={!r})'.format( + type(self).__name__, + self.direction, + self.msg, + self.exception, + ) + + +class MessageProtocol(namedtuple('Protocol', 'parse encode iter')): + """A basic abstraction of a message protocol. + + parse(msg) - returns a message for the given data. + encode(msg) - returns the message, serialized to the line-format. + iter(stream, stop) - yield each message from the stream. "stop" + is a function called with no args which returns True if the + iterator should stop. + """ + + +class Started(object): + """A simple wrapper around a started message protocol daemon.""" + + def __init__(self, fake): + self.fake = fake + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def send_message(self, msg): + return self.fake.send_response(msg) + + def close(self): + self.fake.close() + + +class Daemon(object): + """A testing double for a protocol daemon.""" + + STARTED = Started + + def __init__(self, connect, protocol, handler=None): + self._connect = connect + self._protocol = protocol + self._handler = handler + + self._closed = False + self._received = [] + self._failures = [] + + # These are set when we start. + self._host = None + self._port = None + self._sock = None + self._server = None + self._listener = None + + @property + def received(self): + """All the messages received thus far.""" + return list(self._received) + + @property + def failures(self): + """All send/recv failures thus far.""" + return self._failures + + def start(self, host, port): + """Start the fake daemon. + + This calls the earlier provided connect() function. + + A listener loop is started in another thread to handle incoming + messages from the socket. + """ + self._host = host or None + self._port = port + self._start() + return self.STARTED(self) + + def send_message(self, msg): + """Serialize msg to the line format and send it to the socket.""" + if self._closed: + raise EOFError('closed') + msg = self._protocol.parse(msg) + raw = self._protocol.encode(msg) + try: + self._send(raw) + except Exception as exc: + failure = StreamFailure('send', msg, exc) + self._failures.append(failure) + + def close(self): + """Clean up the daemon's resources (e.g. sockets, files, listener).""" + if self._closed: + return + + self._closed = True + self._close() + + def assert_received(self, case, expected): + """Ensure that the received messages match the expected ones.""" + received = [self._protocol.parse(msg) for msg in self._received] + expected = [self._protocol.parse(msg) for msg in expected] + case.assertEqual(received, expected) + + # internal methods + + def _start(self, host=None): + self._sock, self._server = self._connect( + host or self._host, + self._port, + ) + + # TODO: make it a daemon thread? + self._listener = threading.Thread(target=self._listen) + self._listener.start() + + def _listen(self): + with self._sock.makefile('rb') as sockfile: + for msg in self._protocol.iter(sockfile, lambda: self._closed): + if isinstance(msg, StreamFailure): + self._failures.append(msg) + else: + self._add_received(msg) + + def _add_received(self, msg): + self._received.append(msg) + + if self._handler is not None: + self._handler(msg, self.send_message) + + def _send(self, raw): + while raw: + sent = self._sock.send(raw) + raw = raw[sent:] + + def _close(self): + if self._sock is not None: + socket.close(self._sock) + self._sock = None + if self._server is not None: + socket.close(self._server) + self._server = None + if self._listener is not None: + self._listener.join(timeout=1) + # TODO: the listener isn't stopping! + #if self._listener.is_alive(): + # raise RuntimeError('timed out') + self._listener = None diff --git a/tests/helpers/pydevd/_fake.py b/tests/helpers/pydevd/_fake.py index 6e38b5d3..100e3e8f 100644 --- a/tests/helpers/pydevd/_fake.py +++ b/tests/helpers/pydevd/_fake.py @@ -1,33 +1,23 @@ -import socket -import threading - from ptvsd.wrapper import start_server, start_client - -from ._pydevd import parse_message, iter_messages, StreamFailure +from ._pydevd import parse_message, encode_message, iter_messages +from tests.helpers import protocol -def socket_close(sock): - sock.shutdown(socket.SHUT_RDWR) - sock.close() +PROTOCOL = protocol.MessageProtocol( + parse=parse_message, + encode=encode_message, + iter=iter_messages, +) def _connect(host, port): if host is None: - return start_server(port) + return start_server(port), None else: - return start_client(host, port) + return start_client(host, port), None -class _Started(object): - - def __init__(self, fake): - self.fake = fake - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() +class Started(protocol.Started): def send_response(self, msg): return self.fake.send_response(msg) @@ -35,11 +25,8 @@ class _Started(object): def send_event(self, msg): return self.fake.send_event(msg) - def close(self): - self.fake.close() - -class FakePyDevd(object): +class FakePyDevd(protocol.Daemon): """A testing double for PyDevd. Note that you have the option to provide a handler function. This @@ -66,118 +53,15 @@ class FakePyDevd(object): https://github.com/fabioz/PyDev.Debugger/blob/master/_pydevd_bundle/pydevd_comm.py """ # noqa - CONNECT = staticmethod(_connect) + STARTED = Started - def __init__(self, handler=None, connect=None): - if connect is None: - connect = self.CONNECT - - self._handler = handler - self._connect = connect - - self._closed = False - self._received = [] - self._failures = [] - - # These are set when we start. - self._host = None - self._port = None - self._sock = None - self._listener = None - - @property - def received(self): - """All the messages received thus far.""" - return list(self._received) - - @property - def failures(self): - """All send/recv failures thus far.""" - return self._failures - - def start(self, host, port): - """Start the fake pydevd daemon. - - This calls the earlier provided connect() function. By default - this calls either start_server() or start_client() (depending on - the host) from ptvsd.wrapper. Thus the ptvsd message processor - is started and a PydevdSocket is used as the connection. - - A listener loop is started in another thread to handle incoming - messages from the socket (i.e. from ptvsd). - """ - self._host = host or None - self._port = port - self._sock = self._connect(self._host, self._port) - - # TODO: make daemon? - self._listener = threading.Thread(target=self._listen) - self._listener.start() - - return _Started(self) + def __init__(self, handler=None): + super(FakePyDevd, self).__init__(_connect, PROTOCOL, handler) def send_response(self, msg): """Send a response message to the adapter (ptvsd).""" - return self._send_message(msg) + return self.send_message(msg) def send_event(self, msg): """Send an event message to the adapter (ptvsd).""" - return self._send_message(msg) - - def close(self): - """If started, close the socket and wait for the listener to finish.""" - if self._closed: - return - - self._closed = True - if self._sock is not None: - socket_close(self._sock) - self._sock = None - if self._listener is not None: - self._listener.join(timeout=1) - # TODO: the listener isn't stopping! - #if self._listener.is_alive(): - # raise RuntimeError('timed out') - self._listener = None - - def assert_received(self, case, expected): - """Ensure that the received messages match the expected ones.""" - received = [parse_message(msg) for msg in self._received] - expected = [parse_message(msg) for msg in expected] - case.assertEqual(received, expected) - - # internal methods - - def _listen(self): - with self._sock.makefile('rb') as sockfile: - for msg in iter_messages(sockfile, lambda: self._closed): - if isinstance(msg, StreamFailure): - self._failures.append(msg) - else: - self._add_received(msg) - - def _add_received(self, msg): - self._received.append(msg) - - if self._handler is not None: - self._handler(msg, self._send_message) - - def _send_message(self, msg): - """Serialize the message to the line format and send it to ptvsd. - - If the message is bytes or a string then it is send as-is. - """ - msg = parse_message(msg) - raw = msg.as_bytes() - if not raw.endswith(b'\n'): - raw += b'\n' - try: - self._send(raw) - except Exception as exc: - failure = StreamFailure('send', msg, exc) - self._failures.append(failure) - - def _send(self, raw): - while raw: - sent = self._sock.send(raw) - raw = raw[sent:] + return self.send_message(msg) diff --git a/tests/helpers/pydevd/_pydevd.py b/tests/helpers/pydevd/_pydevd.py index 691c6ada..c29a6bba 100644 --- a/tests/helpers/pydevd/_pydevd.py +++ b/tests/helpers/pydevd/_pydevd.py @@ -1,25 +1,28 @@ from collections import namedtuple +from tests.helpers.protocol import StreamFailure + # TODO: Everything here belongs in a proper pydevd package. -class StreamFailure(Exception): - """Something went wrong while handling messages to/from a stream.""" +def parse_message(msg): + """Return a message object for the given "msg" data.""" + if type(msg) is bytes: + return RawMessage.from_bytes(msg) + elif isinstance(msg, str): + return RawMessage.from_bytes(msg) + elif type(msg) is RawMessage: + return msg + else: + raise NotImplementedError - def __init__(self, direction, msg, exception): - err = 'error while processing stream: {!r}'.format(exception) - super(StreamFailure, self).__init__(self, err) - self.direction = direction - self.msg = msg - self.exception = exception - def __repr__(self): - return '{}(direction={!r}, msg={!r}, exception={!r})'.format( - type(self).__name__, - self.direction, - self.msg, - self.exception, - ) +def encode_message(msg): + """Return the message, serialized to the line-format.""" + raw = msg.as_bytes() + if not raw.endswith(b'\n'): + raw += b'\n' + return raw def iter_messages(stream, stop=lambda: False): @@ -36,18 +39,6 @@ def iter_messages(stream, stop=lambda: False): yield StreamFailure('recv', None, exc) -def parse_message(msg): - """Return a message object for the given "msg" data.""" - if type(msg) is bytes: - return RawMessage.from_bytes(msg) - elif isinstance(msg, str): - return RawMessage.from_bytes(msg) - elif type(msg) is RawMessage: - return msg - else: - raise NotImplementedError - - class RawMessage(namedtuple('RawMessage', 'bytes')): """A pydevd message class that leaves the raw bytes unprocessed.""" diff --git a/tests/helpers/socket.py b/tests/helpers/socket.py new file mode 100644 index 00000000..29832154 --- /dev/null +++ b/tests/helpers/socket.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import + +import socket + + +def connect(host, port): + """Return (client, server) after connecting. + + If host is None then it's a server, so it will wait for a connection + on localhost. Otherwise it will connect to the remote host. + """ + sock = socket.socket( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + ) + sock.setsockopt( + socket.SOL_SOCKET, + socket.SO_REUSEADDR, + 1, + ) + if host is None: + addr = ('127.0.0.1', port) + server = sock + server.bind(addr) + server.listen(1) + sock, _ = server.accept() + else: + addr = (host, port) + sock.connect(addr) + server = None + return sock, server + + +def close(sock): + """Shutdown and close the socket.""" + sock.shutdown(socket.SHUT_RDWR) + sock.close() diff --git a/tests/helpers/vsc/_fake.py b/tests/helpers/vsc/_fake.py index 3b9d5c60..3fc86b9e 100644 --- a/tests/helpers/vsc/_fake.py +++ b/tests/helpers/vsc/_fake.py @@ -1,34 +1,23 @@ -import socket import threading -from ._vsc import StreamFailure, encode_message, iter_messages, parse_message -from ._vsc import RawMessage # noqa +from tests.helpers import protocol, socket +from ._vsc import encode_message, iter_messages, parse_message -def socket_close(sock): - sock.shutdown(socket.SHUT_RDWR) - sock.close() +PROTOCOL = protocol.MessageProtocol( + parse=parse_message, + encode=encode_message, + iter=iter_messages, +) -class _Started(object): - - def __init__(self, fake): - self.fake = fake - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() +class Started(protocol.Started): def send_request(self, msg): return self.fake.send_request(msg) - def close(self): - self.fake.close() - -class FakeVSC(object): +class FakeVSC(protocol.Daemon): """A testing double for a VSC debugger protocol client. This class facilitates sending VSC debugger protocol messages over @@ -58,52 +47,33 @@ class FakeVSC(object): """ # noqa def __init__(self, start_adapter, handler=None): - def start_adapter(host, port, start_adapter=start_adapter): - self._adapter = start_adapter(host, port) + super(FakeVSC, self).__init__(socket.connect, PROTOCOL, handler) + + def start_adapter(host, port, _start_adapter=start_adapter): + self._adapter = _start_adapter(host, port) self._start_adapter = start_adapter - self._handler = handler - - self._closed = False - self._received = [] - self._failures = [] - - # These are set when we start. - self._host = None - self._port = None self._adapter = None - self._sock = None - self._server = None - self._listener = None - - @property - def addr(self): - host, port = self._host, self._port - if host is None: - host = '127.0.0.1' - return (host, port) - - @property - def received(self): - """All the messages received thus far.""" - return list(self._received) - - @property - def failures(self): - """All send/recv failures thus far.""" - return self._failures def start(self, host, port): """Start the fake and the adapter.""" - if self._closed or self._adapter is not None: + if self._adapter is not None: raise RuntimeError('already started') + return super(FakeVSC, self).start(host, port) - if not host: + def send_request(self, req): + """Send the given Request object.""" + return self.send_message(req) + + # internal methods + + def _start(self, host=None): + start_adapter = (lambda: self._start_adapter(self._host, self._port)) + if not self._host: # The adapter is the server so start it first. - t = threading.Thread( - target=lambda: self._start_adapter(host, port)) + t = threading.Thread(target=start_adapter) t.start() - self._start('127.0.0.1', port) + super(FakeVSC, self)._start('127.0.0.1') t.join(timeout=1) if t.is_alive(): raise RuntimeError('timed out') @@ -111,107 +81,15 @@ class FakeVSC(object): # The adapter is the client so start it last. # TODO: For now don't use this. raise NotImplementedError - t = threading.Thread( - target=lambda: self._start(host, port)) + t = threading.Thread(target=super(FakeVSC, self)._start) t.start() - self._start_adapter(host, port) + start_adapter() t.join(timeout=1) if t.is_alive(): raise RuntimeError('timed out') - return _Started(self) - - def send_request(self, req): - """Send the given Request object.""" - if self._closed: - raise EOFError('closed') - req = parse_message(req) - raw = encode_message(req) - try: - self._send(raw) - except Exception as exc: - failure = ('send', req, exc) - self._failures.append(failure) - - def close(self): - """Close the fake's resources (e.g. socket, adapter).""" - if self._closed: - return - - self._closed = True - self._close() - - def assert_received(self, case, expected): - """Ensure that the received messages match the expected ones.""" - received = [parse_message(msg) for msg in self._received] - expected = [parse_message(msg) for msg in expected] - case.assertEqual(received, expected) - - # internal methods - - def _start(self, host, port): - self._host = host - self._port = port - self._connect() - - # TODO: make daemon? - self._listener = threading.Thread(target=self._listen) - self._listener.start() - - def _connect(self): - sock = socket.socket( - socket.AF_INET, - socket.SOCK_STREAM, - socket.IPPROTO_TCP, - ) - sock.setsockopt( - socket.SOL_SOCKET, - socket.SO_REUSEADDR, - 1, - ) - if self._host is None: - server = sock - server.bind(self.addr) - server.listen(1) - sock, _ = server.accept() - else: - sock.connect(self.addr) - server = None - self._server = server - self._sock = sock - - def _listen(self): - with self._sock.makefile('rb') as sockfile: - for msg in iter_messages(sockfile, lambda: self._closed): - if isinstance(msg, StreamFailure): - self._failures.append(msg) - else: - self._add_received(msg) - - def _add_received(self, msg): - self._received.append(msg) - - if self._handler is not None: - self._handler(msg, self.send_request) - - def _send(self, raw): - while raw: - sent = self._sock.send(raw) - raw = raw[sent:] - def _close(self): if self._adapter is not None: self._adapter.close() self._adapter = None - if self._sock is not None: - socket_close(self._sock) - self._sock = None - if self._server is not None: - socket_close(self._server) - self._server = None - if self._listener is not None: - self._listener.join(timeout=1) - # TODO: the listener isn't stopping! - #if self._listener.is_alive(): - # raise RuntimeError('timed out') - self._listener = None + super(FakeVSC, self)._close() diff --git a/tests/helpers/vsc/_vsc.py b/tests/helpers/vsc/_vsc.py index 5db4080b..9b214a74 100644 --- a/tests/helpers/vsc/_vsc.py +++ b/tests/helpers/vsc/_vsc.py @@ -2,27 +2,22 @@ from collections import namedtuple import json from debugger_protocol.messages import wireformat +from tests.helpers.protocol import StreamFailure # TODO: Use more of the code from debugger_protocol. -class StreamFailure(Exception): - """Something went wrong while handling messages to/from a stream.""" - - def __init__(self, direction, msg, exception): - err = 'error while processing stream: {!r}'.format(exception) - super(StreamFailure, self).__init__(self, err) - self.direction = direction - self.msg = msg - self.exception = exception - - def __repr__(self): - return '{}(direction={!r}, msg={!r}, exception={!r})'.format( - type(self).__name__, - self.direction, - self.msg, - self.exception, - ) +def parse_message(msg): + """Return a message object for the given "msg" data.""" + if type(msg) is str: + data = json.loads(msg) + elif isinstance(msg, bytes): + data = json.loads(msg.decode('utf-8')) + elif type(msg) is RawMessage: + return msg + else: + data = msg + return RawMessage.from_data(**data) def encode_message(msg): @@ -42,19 +37,6 @@ def iter_messages(stream, stop=lambda: False): yield StreamFailure('recv', None, exc) -def parse_message(msg): - """Return a message object for the given "msg" data.""" - if type(msg) is str: - data = json.loads(msg) - elif isinstance(msg, bytes): - data = json.loads(msg.decode('utf-8')) - elif type(msg) is RawMessage: - return msg - else: - data = msg - return RawMessage.from_data(**data) - - class RawMessage(namedtuple('RawMessage', 'data')): """A wrapper around a line-formatted debugger protocol message."""