Factor out Request, Response, and Event.

This commit is contained in:
Eric Snow 2018-03-01 20:18:17 +00:00
parent b0a71d0105
commit 87fffa7bdb
4 changed files with 203 additions and 16 deletions

View file

@ -89,18 +89,26 @@ class FakeVSC(protocol.Daemon):
command = req['command']
def match(msg):
msg = msg.data
if msg['type'] != 'response' or msg['request_seq'] != reqseq:
#msg = parse_message(msg)
try:
actual = msg.request_seq
except AttributeError:
return False
assert(msg['command'] == command)
if actual != reqseq:
return False
assert(msg.command == command)
return True
return self._wait_for_message(match, req, **kwargs)
def wait_for_event(self, event, **kwargs):
def match(msg):
msg = msg.data
if msg['type'] != 'event' or msg['event'] != event:
#msg = parse_message(msg)
try:
actual = msg.event
except AttributeError:
return False
if actual != event:
return False
return True

View file

@ -7,6 +7,12 @@ from tests.helpers.protocol import StreamFailure
# TODO: Use more of the code from debugger_protocol.
class ProtocolMessageError(Exception): pass # noqa
class MalformedMessageError(ProtocolMessageError): pass # noqa
class IncompleteMessageError(MalformedMessageError): pass # noqa
class UnsupportedMessageTypeError(ProtocolMessageError): pass # noqa
def parse_message(msg):
"""Return a message object for the given "msg" data."""
if type(msg) is str:
@ -14,10 +20,29 @@ def parse_message(msg):
elif isinstance(msg, bytes):
data = json.loads(msg.decode('utf-8'))
elif type(msg) is RawMessage:
return msg
try:
msg.data['seq']
msg.data['type']
except KeyError:
return msg
return parse_message(msg.data)
elif isinstance(msg, ProtocolMessage):
if msg.TYPE is not None:
return msg
try:
ProtocolMessage._look_up(msg.type)
except UnsupportedMessageTypeError:
return msg
data = msg.as_data()
else:
data = msg
return RawMessage.from_data(**data)
cls = look_up(data)
try:
return cls.from_data(**data)
except IncompleteMessageError:
# TODO: simply fail?
return RawMessage.from_data(**data)
def encode_message(msg):
@ -29,7 +54,8 @@ def iter_messages(stream, stop=lambda: False):
"""Yield the correct message for each line-formatted one found."""
while not stop():
try:
msg = wireformat.read(stream, lambda _: RawMessage)
#msg = wireformat.read(stream, lambda _: RawMessage)
msg = wireformat.read(stream, look_up)
if msg is None: # EOF
break
yield msg
@ -37,6 +63,20 @@ def iter_messages(stream, stop=lambda: False):
yield StreamFailure('recv', None, exc)
def look_up(data):
"""Return the message type to use."""
try:
msgtype = data['type']
except KeyError:
# TODO: return RawMessage?
ProtocolMessage._check_data(data)
try:
return ProtocolMessage._look_up(msgtype)
except UnsupportedMessageTypeError:
# TODO: return Message?
raise
class RawMessage(namedtuple('RawMessage', 'data')):
"""A wrapper around a line-formatted debugger protocol message."""
@ -54,3 +94,142 @@ class RawMessage(namedtuple('RawMessage', 'data')):
def as_data(self):
"""Return the corresponding data, ready to be JSON-encoded."""
return self.data
class ProtocolMessage(object):
"""The base type for VSC debug adapter protocol message."""
TYPE = None
@classmethod
def from_data(cls, **data):
"""Return a message for the given JSON-decoded data."""
try:
return cls(**data)
except TypeError:
cls._check_data(data)
raise
@classmethod
def _check_data(cls, data):
missing = set(cls._fields) - set(data)
if missing:
raise IncompleteMessageError(','.join(missing))
@classmethod
def _look_up(cls, msgtype):
if msgtype == 'request':
return Request
elif msgtype == 'response':
return Response
elif msgtype == 'event':
return Event
else:
raise UnsupportedMessageTypeError(msgtype)
def __new__(cls, seq, type, **kwargs):
if cls is ProtocolMessage:
return Message(seq, type, **kwargs)
seq = int(seq)
type = str(type) if type else None
unused = {k: kwargs.pop(k)
for k in tuple(kwargs)
if k not in cls._fields}
self = super(ProtocolMessage, cls).__new__(cls, seq, type, **kwargs)
self._unused = unused
return self
def __init__(self, *args, **kwargs):
if self.TYPE is None:
if self.type is None:
raise TypeError('missing type')
elif self.type != self.TYPE:
msg = 'wrong type (expected {!r}, go {!r}'
raise ValueError(msg.format(self.TYPE, self.type))
def __repr__(self):
raw = super(ProtocolMessage, self).__repr__()
if self.TYPE is None:
return raw
return ', '.join(part
for part in raw.split(', ')
if not part.startswith('type='))
@property
def unused(self):
return dict(self._unused)
def as_data(self):
"""Return the corresponding data, ready to be JSON-encoded."""
data = self._asdict()
data.update(self._unused)
return data
class Message(ProtocolMessage, namedtuple('Message', 'seq type')):
"""A generic DAP message."""
def __getattr__(self, name):
try:
return self._unused[name]
except KeyError:
raise AttributeError(name)
class Request(ProtocolMessage,
namedtuple('Request', 'seq type command arguments')):
"""A DAP request message."""
TYPE = 'request'
def __new__(cls, seq, type, command, arguments, **unused):
# TODO: Make "arguments" immutable?
return super(Request, cls).__new__(
cls,
seq,
type,
command=command,
arguments=arguments,
**unused
)
class Response(ProtocolMessage,
namedtuple('Response',
'seq type request_seq command success message body'),
):
"""A DAP response message."""
TYPE = 'response'
def __new__(cls, seq, type, request_seq, command, success, message, body,
**unused):
# TODO: Make "body" immutable?
return super(Response, cls).__new__(
cls,
seq,
type,
request_seq=request_seq,
command=command,
success=success,
message=message,
body=body,
**unused
)
class Event(ProtocolMessage, namedtuple('Event', 'seq type event body')):
"""A DAP event message."""
TYPE = 'event'
def __new__(cls, seq, type, event, body, **unused):
# TODO: Make "body" immutable?
return super(Event, cls).__new__(
cls,
seq,
type,
event=event,
body=body,
**unused
)

View file

@ -249,7 +249,7 @@ class VSCLifecycle(object):
See https://code.visualstudio.com/docs/extensionAPI/api-debugging#_the-vs-code-debug-protocol-in-a-nutshell
""" # noqa
def handle_response(resp, _):
self._capabilities = resp.data['body']
self._capabilities = resp.body
version = self._fix.debugger.VERSION
self._fix.set_debugger_response(CMD_VERSION, version)
self._fix.send_request(
@ -411,13 +411,13 @@ class HighlevelFixture(object):
next(self.vsc_msgs.event_seq)
for msg in reversed(self.vsc.received):
if msg.data['type'] == 'response':
if msg.data['command'] == 'threads':
if msg.type == 'response':
if msg.command == 'threads':
break
else:
assert False, 'we waited for the response in send_request()'
for tinfo in msg.data['body']['threads']:
for tinfo in msg.body['threads']:
try:
thread = request[tinfo['name']]
except KeyError:
@ -541,7 +541,7 @@ class HighlevelTest(object):
failure = received[-1]
expected = self.vsc.protocol.parse(
self.fix.vsc_msgs.new_failure(req, failure.data['message']))
self.fix.vsc_msgs.new_failure(req, failure.message))
self.assertEqual(failure, expected)
def assert_received(self, daemon, expected):

View file

@ -1507,7 +1507,7 @@ class ThreadEventTest(PyDevdEventTest):
def send_event(self, *args, **kwargs):
def handler(msg, _):
self._tid = msg.data['body']['threadId']
self._tid = msg.body['threadId']
kwargs['handler'] = handler
super(ThreadEventTest, self).send_event(*args, **kwargs)
return self._tid
@ -1844,8 +1844,8 @@ class SendCurrExcTraceTests(PyDevdEventTest, unittest.TestCase):
self.assert_vsc_received(received, [])
self.assert_received(self.debugger, [])
self.assertTrue(resp.data['success'], resp.data['message'])
self.assertEqual(resp.data['body'], dict(
self.assertTrue(resp.success, resp.message)
self.assertEqual(resp.body, dict(
exceptionId='RuntimeError',
description='something went wrong',
breakMode='unhandled',