Factor out tests.helpers.protocol.Daemon.

This commit is contained in:
Eric Snow 2018-02-10 04:00:12 +00:00
parent 2c8048953b
commit 3af4b171c7
6 changed files with 282 additions and 340 deletions

169
tests/helpers/protocol.py Normal file
View file

@ -0,0 +1,169 @@
from collections import namedtuple
import threading
from . import socket
class StreamFailure(Exception):
"""Something went wrong while handling messages to/from a stream."""
def __init__(self, direction, msg, exception):
err = 'error while processing stream: {!r}'.format(exception)
super(StreamFailure, self).__init__(self, err)
self.direction = direction
self.msg = msg
self.exception = exception
def __repr__(self):
return '{}(direction={!r}, msg={!r}, exception={!r})'.format(
type(self).__name__,
self.direction,
self.msg,
self.exception,
)
class MessageProtocol(namedtuple('Protocol', 'parse encode iter')):
"""A basic abstraction of a message protocol.
parse(msg) - returns a message for the given data.
encode(msg) - returns the message, serialized to the line-format.
iter(stream, stop) - yield each message from the stream. "stop"
is a function called with no args which returns True if the
iterator should stop.
"""
class Started(object):
"""A simple wrapper around a started message protocol daemon."""
def __init__(self, fake):
self.fake = fake
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def send_message(self, msg):
return self.fake.send_response(msg)
def close(self):
self.fake.close()
class Daemon(object):
"""A testing double for a protocol daemon."""
STARTED = Started
def __init__(self, connect, protocol, handler=None):
self._connect = connect
self._protocol = protocol
self._handler = handler
self._closed = False
self._received = []
self._failures = []
# These are set when we start.
self._host = None
self._port = None
self._sock = None
self._server = None
self._listener = None
@property
def received(self):
"""All the messages received thus far."""
return list(self._received)
@property
def failures(self):
"""All send/recv failures thus far."""
return self._failures
def start(self, host, port):
"""Start the fake daemon.
This calls the earlier provided connect() function.
A listener loop is started in another thread to handle incoming
messages from the socket.
"""
self._host = host or None
self._port = port
self._start()
return self.STARTED(self)
def send_message(self, msg):
"""Serialize msg to the line format and send it to the socket."""
if self._closed:
raise EOFError('closed')
msg = self._protocol.parse(msg)
raw = self._protocol.encode(msg)
try:
self._send(raw)
except Exception as exc:
failure = StreamFailure('send', msg, exc)
self._failures.append(failure)
def close(self):
"""Clean up the daemon's resources (e.g. sockets, files, listener)."""
if self._closed:
return
self._closed = True
self._close()
def assert_received(self, case, expected):
"""Ensure that the received messages match the expected ones."""
received = [self._protocol.parse(msg) for msg in self._received]
expected = [self._protocol.parse(msg) for msg in expected]
case.assertEqual(received, expected)
# internal methods
def _start(self, host=None):
self._sock, self._server = self._connect(
host or self._host,
self._port,
)
# TODO: make it a daemon thread?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
def _listen(self):
with self._sock.makefile('rb') as sockfile:
for msg in self._protocol.iter(sockfile, lambda: self._closed):
if isinstance(msg, StreamFailure):
self._failures.append(msg)
else:
self._add_received(msg)
def _add_received(self, msg):
self._received.append(msg)
if self._handler is not None:
self._handler(msg, self.send_message)
def _send(self, raw):
while raw:
sent = self._sock.send(raw)
raw = raw[sent:]
def _close(self):
if self._sock is not None:
socket.close(self._sock)
self._sock = None
if self._server is not None:
socket.close(self._server)
self._server = None
if self._listener is not None:
self._listener.join(timeout=1)
# TODO: the listener isn't stopping!
#if self._listener.is_alive():
# raise RuntimeError('timed out')
self._listener = None

View file

@ -1,33 +1,23 @@
import socket
import threading
from ptvsd.wrapper import start_server, start_client
from ._pydevd import parse_message, iter_messages, StreamFailure
from ._pydevd import parse_message, encode_message, iter_messages
from tests.helpers import protocol
def socket_close(sock):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
PROTOCOL = protocol.MessageProtocol(
parse=parse_message,
encode=encode_message,
iter=iter_messages,
)
def _connect(host, port):
if host is None:
return start_server(port)
return start_server(port), None
else:
return start_client(host, port)
return start_client(host, port), None
class _Started(object):
def __init__(self, fake):
self.fake = fake
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class Started(protocol.Started):
def send_response(self, msg):
return self.fake.send_response(msg)
@ -35,11 +25,8 @@ class _Started(object):
def send_event(self, msg):
return self.fake.send_event(msg)
def close(self):
self.fake.close()
class FakePyDevd(object):
class FakePyDevd(protocol.Daemon):
"""A testing double for PyDevd.
Note that you have the option to provide a handler function. This
@ -66,118 +53,15 @@ class FakePyDevd(object):
https://github.com/fabioz/PyDev.Debugger/blob/master/_pydevd_bundle/pydevd_comm.py
""" # noqa
CONNECT = staticmethod(_connect)
STARTED = Started
def __init__(self, handler=None, connect=None):
if connect is None:
connect = self.CONNECT
self._handler = handler
self._connect = connect
self._closed = False
self._received = []
self._failures = []
# These are set when we start.
self._host = None
self._port = None
self._sock = None
self._listener = None
@property
def received(self):
"""All the messages received thus far."""
return list(self._received)
@property
def failures(self):
"""All send/recv failures thus far."""
return self._failures
def start(self, host, port):
"""Start the fake pydevd daemon.
This calls the earlier provided connect() function. By default
this calls either start_server() or start_client() (depending on
the host) from ptvsd.wrapper. Thus the ptvsd message processor
is started and a PydevdSocket is used as the connection.
A listener loop is started in another thread to handle incoming
messages from the socket (i.e. from ptvsd).
"""
self._host = host or None
self._port = port
self._sock = self._connect(self._host, self._port)
# TODO: make daemon?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
return _Started(self)
def __init__(self, handler=None):
super(FakePyDevd, self).__init__(_connect, PROTOCOL, handler)
def send_response(self, msg):
"""Send a response message to the adapter (ptvsd)."""
return self._send_message(msg)
return self.send_message(msg)
def send_event(self, msg):
"""Send an event message to the adapter (ptvsd)."""
return self._send_message(msg)
def close(self):
"""If started, close the socket and wait for the listener to finish."""
if self._closed:
return
self._closed = True
if self._sock is not None:
socket_close(self._sock)
self._sock = None
if self._listener is not None:
self._listener.join(timeout=1)
# TODO: the listener isn't stopping!
#if self._listener.is_alive():
# raise RuntimeError('timed out')
self._listener = None
def assert_received(self, case, expected):
"""Ensure that the received messages match the expected ones."""
received = [parse_message(msg) for msg in self._received]
expected = [parse_message(msg) for msg in expected]
case.assertEqual(received, expected)
# internal methods
def _listen(self):
with self._sock.makefile('rb') as sockfile:
for msg in iter_messages(sockfile, lambda: self._closed):
if isinstance(msg, StreamFailure):
self._failures.append(msg)
else:
self._add_received(msg)
def _add_received(self, msg):
self._received.append(msg)
if self._handler is not None:
self._handler(msg, self._send_message)
def _send_message(self, msg):
"""Serialize the message to the line format and send it to ptvsd.
If the message is bytes or a string then it is send as-is.
"""
msg = parse_message(msg)
raw = msg.as_bytes()
if not raw.endswith(b'\n'):
raw += b'\n'
try:
self._send(raw)
except Exception as exc:
failure = StreamFailure('send', msg, exc)
self._failures.append(failure)
def _send(self, raw):
while raw:
sent = self._sock.send(raw)
raw = raw[sent:]
return self.send_message(msg)

View file

@ -1,25 +1,28 @@
from collections import namedtuple
from tests.helpers.protocol import StreamFailure
# TODO: Everything here belongs in a proper pydevd package.
class StreamFailure(Exception):
"""Something went wrong while handling messages to/from a stream."""
def parse_message(msg):
"""Return a message object for the given "msg" data."""
if type(msg) is bytes:
return RawMessage.from_bytes(msg)
elif isinstance(msg, str):
return RawMessage.from_bytes(msg)
elif type(msg) is RawMessage:
return msg
else:
raise NotImplementedError
def __init__(self, direction, msg, exception):
err = 'error while processing stream: {!r}'.format(exception)
super(StreamFailure, self).__init__(self, err)
self.direction = direction
self.msg = msg
self.exception = exception
def __repr__(self):
return '{}(direction={!r}, msg={!r}, exception={!r})'.format(
type(self).__name__,
self.direction,
self.msg,
self.exception,
)
def encode_message(msg):
"""Return the message, serialized to the line-format."""
raw = msg.as_bytes()
if not raw.endswith(b'\n'):
raw += b'\n'
return raw
def iter_messages(stream, stop=lambda: False):
@ -36,18 +39,6 @@ def iter_messages(stream, stop=lambda: False):
yield StreamFailure('recv', None, exc)
def parse_message(msg):
"""Return a message object for the given "msg" data."""
if type(msg) is bytes:
return RawMessage.from_bytes(msg)
elif isinstance(msg, str):
return RawMessage.from_bytes(msg)
elif type(msg) is RawMessage:
return msg
else:
raise NotImplementedError
class RawMessage(namedtuple('RawMessage', 'bytes')):
"""A pydevd message class that leaves the raw bytes unprocessed."""

38
tests/helpers/socket.py Normal file
View file

@ -0,0 +1,38 @@
from __future__ import absolute_import
import socket
def connect(host, port):
"""Return (client, server) after connecting.
If host is None then it's a server, so it will wait for a connection
on localhost. Otherwise it will connect to the remote host.
"""
sock = socket.socket(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
)
sock.setsockopt(
socket.SOL_SOCKET,
socket.SO_REUSEADDR,
1,
)
if host is None:
addr = ('127.0.0.1', port)
server = sock
server.bind(addr)
server.listen(1)
sock, _ = server.accept()
else:
addr = (host, port)
sock.connect(addr)
server = None
return sock, server
def close(sock):
"""Shutdown and close the socket."""
sock.shutdown(socket.SHUT_RDWR)
sock.close()

View file

@ -1,34 +1,23 @@
import socket
import threading
from ._vsc import StreamFailure, encode_message, iter_messages, parse_message
from ._vsc import RawMessage # noqa
from tests.helpers import protocol, socket
from ._vsc import encode_message, iter_messages, parse_message
def socket_close(sock):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
PROTOCOL = protocol.MessageProtocol(
parse=parse_message,
encode=encode_message,
iter=iter_messages,
)
class _Started(object):
def __init__(self, fake):
self.fake = fake
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class Started(protocol.Started):
def send_request(self, msg):
return self.fake.send_request(msg)
def close(self):
self.fake.close()
class FakeVSC(object):
class FakeVSC(protocol.Daemon):
"""A testing double for a VSC debugger protocol client.
This class facilitates sending VSC debugger protocol messages over
@ -58,52 +47,33 @@ class FakeVSC(object):
""" # noqa
def __init__(self, start_adapter, handler=None):
def start_adapter(host, port, start_adapter=start_adapter):
self._adapter = start_adapter(host, port)
super(FakeVSC, self).__init__(socket.connect, PROTOCOL, handler)
def start_adapter(host, port, _start_adapter=start_adapter):
self._adapter = _start_adapter(host, port)
self._start_adapter = start_adapter
self._handler = handler
self._closed = False
self._received = []
self._failures = []
# These are set when we start.
self._host = None
self._port = None
self._adapter = None
self._sock = None
self._server = None
self._listener = None
@property
def addr(self):
host, port = self._host, self._port
if host is None:
host = '127.0.0.1'
return (host, port)
@property
def received(self):
"""All the messages received thus far."""
return list(self._received)
@property
def failures(self):
"""All send/recv failures thus far."""
return self._failures
def start(self, host, port):
"""Start the fake and the adapter."""
if self._closed or self._adapter is not None:
if self._adapter is not None:
raise RuntimeError('already started')
return super(FakeVSC, self).start(host, port)
if not host:
def send_request(self, req):
"""Send the given Request object."""
return self.send_message(req)
# internal methods
def _start(self, host=None):
start_adapter = (lambda: self._start_adapter(self._host, self._port))
if not self._host:
# The adapter is the server so start it first.
t = threading.Thread(
target=lambda: self._start_adapter(host, port))
t = threading.Thread(target=start_adapter)
t.start()
self._start('127.0.0.1', port)
super(FakeVSC, self)._start('127.0.0.1')
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
@ -111,107 +81,15 @@ class FakeVSC(object):
# The adapter is the client so start it last.
# TODO: For now don't use this.
raise NotImplementedError
t = threading.Thread(
target=lambda: self._start(host, port))
t = threading.Thread(target=super(FakeVSC, self)._start)
t.start()
self._start_adapter(host, port)
start_adapter()
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
return _Started(self)
def send_request(self, req):
"""Send the given Request object."""
if self._closed:
raise EOFError('closed')
req = parse_message(req)
raw = encode_message(req)
try:
self._send(raw)
except Exception as exc:
failure = ('send', req, exc)
self._failures.append(failure)
def close(self):
"""Close the fake's resources (e.g. socket, adapter)."""
if self._closed:
return
self._closed = True
self._close()
def assert_received(self, case, expected):
"""Ensure that the received messages match the expected ones."""
received = [parse_message(msg) for msg in self._received]
expected = [parse_message(msg) for msg in expected]
case.assertEqual(received, expected)
# internal methods
def _start(self, host, port):
self._host = host
self._port = port
self._connect()
# TODO: make daemon?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
def _connect(self):
sock = socket.socket(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
)
sock.setsockopt(
socket.SOL_SOCKET,
socket.SO_REUSEADDR,
1,
)
if self._host is None:
server = sock
server.bind(self.addr)
server.listen(1)
sock, _ = server.accept()
else:
sock.connect(self.addr)
server = None
self._server = server
self._sock = sock
def _listen(self):
with self._sock.makefile('rb') as sockfile:
for msg in iter_messages(sockfile, lambda: self._closed):
if isinstance(msg, StreamFailure):
self._failures.append(msg)
else:
self._add_received(msg)
def _add_received(self, msg):
self._received.append(msg)
if self._handler is not None:
self._handler(msg, self.send_request)
def _send(self, raw):
while raw:
sent = self._sock.send(raw)
raw = raw[sent:]
def _close(self):
if self._adapter is not None:
self._adapter.close()
self._adapter = None
if self._sock is not None:
socket_close(self._sock)
self._sock = None
if self._server is not None:
socket_close(self._server)
self._server = None
if self._listener is not None:
self._listener.join(timeout=1)
# TODO: the listener isn't stopping!
#if self._listener.is_alive():
# raise RuntimeError('timed out')
self._listener = None
super(FakeVSC, self)._close()

View file

@ -2,27 +2,22 @@ from collections import namedtuple
import json
from debugger_protocol.messages import wireformat
from tests.helpers.protocol import StreamFailure
# TODO: Use more of the code from debugger_protocol.
class StreamFailure(Exception):
"""Something went wrong while handling messages to/from a stream."""
def __init__(self, direction, msg, exception):
err = 'error while processing stream: {!r}'.format(exception)
super(StreamFailure, self).__init__(self, err)
self.direction = direction
self.msg = msg
self.exception = exception
def __repr__(self):
return '{}(direction={!r}, msg={!r}, exception={!r})'.format(
type(self).__name__,
self.direction,
self.msg,
self.exception,
)
def parse_message(msg):
"""Return a message object for the given "msg" data."""
if type(msg) is str:
data = json.loads(msg)
elif isinstance(msg, bytes):
data = json.loads(msg.decode('utf-8'))
elif type(msg) is RawMessage:
return msg
else:
data = msg
return RawMessage.from_data(**data)
def encode_message(msg):
@ -42,19 +37,6 @@ def iter_messages(stream, stop=lambda: False):
yield StreamFailure('recv', None, exc)
def parse_message(msg):
"""Return a message object for the given "msg" data."""
if type(msg) is str:
data = json.loads(msg)
elif isinstance(msg, bytes):
data = json.loads(msg.decode('utf-8'))
elif type(msg) is RawMessage:
return msg
else:
data = msg
return RawMessage.from_data(**data)
class RawMessage(namedtuple('RawMessage', 'data')):
"""A wrapper around a line-formatted debugger protocol message."""