mirror of
https://github.com/microsoft/debugpy.git
synced 2025-12-23 08:48:12 +00:00
325 lines
9.1 KiB
Python
325 lines
9.1 KiB
Python
from collections import namedtuple
|
|
import errno
|
|
import threading
|
|
|
|
from . import socket
|
|
from .counter import Counter
|
|
|
|
|
|
# Older interpreters (Python 2) have no builtin BrokenPipeError; define a
# stand-in there so the "except BrokenPipeError" clauses below still work.
try:
    BrokenPipeError
except NameError:
    class BrokenPipeError(Exception):
        """Stand-in used when the builtin BrokenPipeError is missing."""
|
|
|
|
|
|
class StreamFailure(Exception):
    """Something went wrong while handling messages to/from a stream.

    Attributes:
        direction: label for which way the message was travelling
            (e.g. 'send').
        msg: the message that was being handled when the failure hit.
        exception: the underlying exception that caused the failure.
    """

    def __init__(self, direction, msg, exception):
        err = 'error while processing stream: {!r}'.format(exception)
        # Only the text goes to Exception.__init__().  Passing "self" as
        # well would put the instance into its own args tuple, so
        # str(failure) would render a tuple rather than the message.
        super(StreamFailure, self).__init__(err)
        self.direction = direction
        self.msg = msg
        self.exception = exception

    def __repr__(self):
        return '{}(direction={!r}, msg={!r}, exception={!r})'.format(
            type(self).__name__,
            self.direction,
            self.msg,
            self.exception,
        )
|
|
|
|
|
|
class MessageProtocol(namedtuple('Protocol', 'parse encode iter')):
    """The three callables that together describe a message protocol.

      parse(msg) - returns a message for the given data.
      encode(msg) - returns the message, serialized to the line-format.
      iter(stream, stop) - yield each message from the stream.  "stop"
         is a function called with no args which returns True if the
         iterator should stop.
    """

    def parse_each(self, messages):
        """Lazily run parse() over the messages, yielding each result."""
        parse = self.parse
        for raw in messages:
            yield parse(raw)
|
|
|
|
|
|
class MessageCounters(namedtuple('MessageCounters',
                                 'request response event')):
    """Supply the next "seq" value for each protocol message type.

    Each field holds a counter that is advanced with next().  Passing
    None for "response" makes responses share the request counter;
    passing None for "event" (the default) makes events share the
    response counter.
    """

    REQUEST_INC = 1
    RESPONSE_INC = 1
    EVENT_INC = 1

    def __new__(cls, request=0, response=0, event=None):
        req_counter = Counter(request, cls.REQUEST_INC)

        # A missing start value means "reuse the previous counter".
        resp_counter = (
            req_counter
            if response is None
            else Counter(response, cls.RESPONSE_INC)
        )
        event_counter = (
            resp_counter
            if event is None
            else Counter(event, cls.EVENT_INC)
        )

        return super(MessageCounters, cls).__new__(
            cls,
            req_counter,
            resp_counter,
            event_counter,
        )

    def next_request(self):
        """Advance the request counter and return its value."""
        return next(self.request)

    def next_response(self):
        """Advance the response counter and return its value."""
        return next(self.response)

    def next_event(self):
        """Advance the event counter and return its value."""
        return next(self.event)
|
|
|
|
|
|
class DaemonStarted(object):
    """Handle on a protocol daemon that has been started.

    Works as a context manager: entering waits for the connection and
    leaving closes the daemon.
    """

    def __init__(self, daemon, address, starting=None):
        self.daemon = daemon
        self.address = address
        # Object that is join()ed to wait for startup (None if nothing
        # is pending).
        self._starting = starting

    def __enter__(self):
        self.wait_until_connected()
        return self

    def __exit__(self, *args):
        self.close()

    def wait_until_connected(self, timeout=None):
        """Block until the pending startup (if any) has finished.

        Raises RuntimeError('timed out') when the startup is still
        alive once join() returns.
        """
        pending = self._starting
        if pending is not None:
            pending.join(timeout=timeout)
            if pending.is_alive():
                raise RuntimeError('timed out')
            # Done; later calls return immediately.
            self._starting = None

    def close(self):
        """Wait for the connection and then close the daemon."""
        self.wait_until_connected()
        self.daemon.close()
|
|
|
|
|
|
class Daemon(object):
    """Base fake daemon: bind, connect in a background thread, close."""

    STARTED = DaemonStarted

    def __init__(self, bind):
        self._bind = bind
        self._closed = False

        # Filled in by start().
        self._address = None
        self._sock = None

    def start(self, address):
        """Start the fake daemon.

        This calls the earlier provided bind() function.

        The connection is made in another thread; subclasses hook into
        _handle_connected() to act once the socket is available.

        Returns a STARTED wrapper for this daemon.
        """
        self._address = address
        bound_addr, connecting = self._start(address)
        return self.STARTED(self, bound_addr, connecting)

    def close(self):
        """Clean up the daemon's resources (e.g. sockets, files, listener)."""
        # Only the first call does any work.
        if not self._closed:
            self._closed = True
            self._close()

    # internal methods

    def _start(self, address):
        """Bind, then connect in a new thread; return (addr, thread)."""
        connect, addr = self._bind(address)

        def run():
            self._sock = connect()
            self._handle_connected()

        worker = threading.Thread(target=run)
        worker.start()
        return addr, worker

    def _handle_connected(self):
        """Hook called (in the connecting thread) once the socket is set."""

    def _close(self):
        """Close the socket, if there is one."""
        sock = self._sock
        if sock is None:
            return
        socket.close(sock)
        self._sock = None
|
|
|
|
|
|
class MessageDaemonStarted(DaemonStarted):
    """Handle on a started message-protocol daemon."""

    def send_message(self, msg):
        """Wait for the connection, then send msg through the daemon."""
        self.wait_until_connected()
        result = self.daemon.send_message(msg)
        return result
|
|
|
|
|
|
class MessageDaemon(Daemon):
    """A testing double for a protocol daemon.

    Once connected, a listener thread reads messages from the socket
    using the protocol's iter(), records them, and offers each one to
    the registered handlers (falling back to the default handler).
    """

    STARTED = MessageDaemonStarted

    @classmethod
    def validate_message(cls, msg):
        """Ensure the message is legitimate."""
        # By default check nothing.

    def __init__(self, bind, protocol, handler):
        super(MessageDaemon, self).__init__(bind)

        self._protocol = protocol

        self._received = []
        self._failures = []

        # Each entry is (handler, name, remaining uses or None).
        self._handlers = []
        self._default_handler = handler

        # These are set when we start.
        self._listener = None

    @property
    def protocol(self):
        return self._protocol

    @property
    def received(self):
        """All the messages received thus far."""
        # A copy of the messages exactly as recorded (not re-parsed).
        return list(self._received)

    @property
    def failures(self):
        """All send/recv failures thus far."""
        return list(self._failures)

    def send_message(self, msg):
        """Serialize msg to the line format and send it to the socket.

        Raises EOFError('closed') if the daemon has been closed.
        """
        if self._closed:
            raise EOFError('closed')
        self._validate_message(msg)
        self._send_message(msg)

    def add_handler(self, handler, handlername=None, oneoff=True):
        """Add the given handler to the list of possible handlers.

        The handler is called as handler(msg, send).  With "oneoff" it
        is dropped after the first message it handles; otherwise it
        stays registered.  The handler itself is returned.
        """
        entry = (
            handler,
            handlername or repr(handler),
            1 if oneoff else None,
        )
        self._handlers.append(entry)
        return handler

    def reset(self, *initial, **kwargs):
        """Clear the recorded messages."""
        self._reset(initial, **kwargs)

    # internal methods

    def _handle_connected(self):
        # TODO: make it a daemon thread?
        self._listener = threading.Thread(target=self._listen)
        self._listener.start()

    def _listen(self):
        """Read messages from the socket until the stream ends.

        StreamFailure items yielded by the protocol are recorded as
        failures; everything else is recorded and handled.
        """
        try:
            with self._sock.makefile('rb') as sockfile:
                for msg in self._protocol.iter(sockfile, lambda: self._closed):
                    if isinstance(msg, StreamFailure):
                        self._failures.append(msg)
                    else:
                        self._add_received(msg)
        except BrokenPipeError:
            if self._closed:
                return
            # TODO: try reconnecting?
            raise
        except OSError as exc:
            # Broken pipe / shut down / socket already closed: the
            # connection is gone, so just stop listening.
            if exc.errno in (errno.EPIPE, errno.ESHUTDOWN, errno.EBADF):
                return
            # TODO: try reconnecting?
            raise

    def _add_received(self, msg):
        self._received.append(msg)
        self._handle_message(msg)

    def _handle_message(self, msg):
        """Offer msg to the handlers; return the handling result.

        A handler has handled the message when it returns a true value
        or None.  Only the first such handler is used; its remaining-use
        count (if any) is decremented and the entry is removed when it
        runs out.  If no handler takes the message, the default handler
        (if any) is used; otherwise False is returned.
        """
        # Iterate over a copy since the list may be modified below.
        for i, entry in enumerate(list(self._handlers)):
            handle_message, name, remaining = entry
            handled = handle_message(msg, self._send_message)
            if not (handled or handled is None):
                continue
            if remaining == 1:
                self._handlers.pop(i)
            elif remaining is not None:
                self._handlers[i] = (handle_message, name, remaining - 1)
            return handled
        if self._default_handler is not None:
            return self._default_handler(msg, self._send_message)
        return False

    def _validate_message(self, msg):
        return

    def _send_message(self, msg):
        """Parse and encode msg, then write it to the socket.

        A failure while sending is recorded in the failures list and
        then re-raised to the caller.
        """
        msg = self._protocol.parse(msg)
        raw = self._protocol.encode(msg)
        try:
            self._send(raw)
        except Exception as exc:
            # Record before propagating; the bookkeeping used to sit
            # after a bare "raise" and so could never run.
            self._failures.append(StreamFailure('send', msg, exc))
            raise

    def _send(self, raw):
        # send() may write only part of the data, so keep going until
        # everything has been written.
        while raw:
            sent = self._sock.send(raw)
            raw = raw[sent:]

    def _close(self):
        super(MessageDaemon, self)._close()
        if self._listener is not None:
            self._listener.join(timeout=1)
            # TODO: the listener isn't stopping!  (Hence no check for a
            # timed-out join here.)
            self._listener = None

    def _reset(self, initial, force=False):
        """Replace the recorded messages with the parsed "initial" ones.

        Raises RuntimeError if failures were recorded, or if handlers
        are still pending and "force" is false.  With "force" the
        pending handlers are discarded.
        """
        if self._failures:
            raise RuntimeError('have failures ({!r})'.format(self._failures))
        if self._handlers:
            if force:
                self._handlers = []
            else:
                names = ', '.join(name for _, name, _ in self._handlers)
                raise RuntimeError('have pending handlers: [{}]'.format(names))
        self._received = list(self._protocol.parse_each(initial))
|