From 87fffa7bdb8a5458e6c730995ad34766b248ea34 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Thu, 1 Mar 2018 20:18:17 +0000 Subject: [PATCH] Factor out Request, Response, and Event. --- tests/helpers/vsc/_fake.py | 18 ++- tests/helpers/vsc/_vsc.py | 185 ++++++++++++++++++++++++- tests/ptvsd/highlevel/__init__.py | 10 +- tests/ptvsd/highlevel/test_messages.py | 6 +- 4 files changed, 203 insertions(+), 16 deletions(-) diff --git a/tests/helpers/vsc/_fake.py b/tests/helpers/vsc/_fake.py index c7cecaa0..4c65bf17 100644 --- a/tests/helpers/vsc/_fake.py +++ b/tests/helpers/vsc/_fake.py @@ -89,18 +89,26 @@ class FakeVSC(protocol.Daemon): command = req['command'] def match(msg): - msg = msg.data - if msg['type'] != 'response' or msg['request_seq'] != reqseq: + #msg = parse_message(msg) + try: + actual = msg.request_seq + except AttributeError: return False - assert(msg['command'] == command) + if actual != reqseq: + return False + assert(msg.command == command) return True return self._wait_for_message(match, req, **kwargs) def wait_for_event(self, event, **kwargs): def match(msg): - msg = msg.data - if msg['type'] != 'event' or msg['event'] != event: + #msg = parse_message(msg) + try: + actual = msg.event + except AttributeError: + return False + if actual != event: return False return True diff --git a/tests/helpers/vsc/_vsc.py b/tests/helpers/vsc/_vsc.py index 9b214a74..6aa5d6a1 100644 --- a/tests/helpers/vsc/_vsc.py +++ b/tests/helpers/vsc/_vsc.py @@ -7,6 +7,12 @@ from tests.helpers.protocol import StreamFailure # TODO: Use more of the code from debugger_protocol. +class ProtocolMessageError(Exception): pass # noqa +class MalformedMessageError(ProtocolMessageError): pass # noqa +class IncompleteMessageError(MalformedMessageError): pass # noqa +class UnsupportedMessageTypeError(ProtocolMessageError): pass # noqa + + def parse_message(msg): """Return a message object for the given "msg" data.""" if type(msg) is str: @@ -14,10 +20,29 @@ def parse_message(msg): elif isinstance(msg, bytes): data = json.loads(msg.decode('utf-8')) elif type(msg) is RawMessage: - return msg + try: + msg.data['seq'] + msg.data['type'] + except KeyError: + return msg + return parse_message(msg.data) + elif isinstance(msg, ProtocolMessage): + if msg.TYPE is not None: + return msg + try: + ProtocolMessage._look_up(msg.type) + except UnsupportedMessageTypeError: + return msg + data = msg.as_data() else: data = msg - return RawMessage.from_data(**data) + + cls = look_up(data) + try: + return cls.from_data(**data) + except IncompleteMessageError: + # TODO: simply fail? + return RawMessage.from_data(**data) def encode_message(msg): @@ -29,7 +54,8 @@ def iter_messages(stream, stop=lambda: False): """Yield the correct message for each line-formatted one found.""" while not stop(): try: - msg = wireformat.read(stream, lambda _: RawMessage) + #msg = wireformat.read(stream, lambda _: RawMessage) + msg = wireformat.read(stream, look_up) if msg is None: # EOF break yield msg @@ -37,6 +63,20 @@ def iter_messages(stream, stop=lambda: False): yield StreamFailure('recv', None, exc) +def look_up(data): + """Return the message type to use.""" + try: + msgtype = data['type'] + except KeyError: + # TODO: return RawMessage? + ProtocolMessage._check_data(data) + try: + return ProtocolMessage._look_up(msgtype) + except UnsupportedMessageTypeError: + # TODO: return Message? + raise + + class RawMessage(namedtuple('RawMessage', 'data')): """A wrapper around a line-formatted debugger protocol message.""" @@ -54,3 +94,142 @@ class RawMessage(namedtuple('RawMessage', 'data')): def as_data(self): """Return the corresponding data, ready to be JSON-encoded.""" return self.data + + +class ProtocolMessage(object): + """The base type for VSC debug adapter protocol message.""" + + TYPE = None + + @classmethod + def from_data(cls, **data): + """Return a message for the given JSON-decoded data.""" + try: + return cls(**data) + except TypeError: + cls._check_data(data) + raise + + @classmethod + def _check_data(cls, data): + missing = set(cls._fields) - set(data) + if missing: + raise IncompleteMessageError(','.join(missing)) + + @classmethod + def _look_up(cls, msgtype): + if msgtype == 'request': + return Request + elif msgtype == 'response': + return Response + elif msgtype == 'event': + return Event + else: + raise UnsupportedMessageTypeError(msgtype) + + def __new__(cls, seq, type, **kwargs): + if cls is ProtocolMessage: + return Message(seq, type, **kwargs) + seq = int(seq) + type = str(type) if type else None + unused = {k: kwargs.pop(k) + for k in tuple(kwargs) + if k not in cls._fields} + self = super(ProtocolMessage, cls).__new__(cls, seq, type, **kwargs) + self._unused = unused + return self + + def __init__(self, *args, **kwargs): + if self.TYPE is None: + if self.type is None: + raise TypeError('missing type') + elif self.type != self.TYPE: + msg = 'wrong type (expected {!r}, go {!r}' + raise ValueError(msg.format(self.TYPE, self.type)) + + def __repr__(self): + raw = super(ProtocolMessage, self).__repr__() + if self.TYPE is None: + return raw + return ', '.join(part + for part in raw.split(', ') + if not part.startswith('type=')) + + @property + def unused(self): + return dict(self._unused) + + def as_data(self): + """Return the corresponding data, ready to be JSON-encoded.""" + data = self._asdict() + data.update(self._unused) + return data + + +class Message(ProtocolMessage, namedtuple('Message', 'seq type')): + """A generic DAP message.""" + + def __getattr__(self, name): + try: + return self._unused[name] + except KeyError: + raise AttributeError(name) + + +class Request(ProtocolMessage, + namedtuple('Request', 'seq type command arguments')): + """A DAP request message.""" + + TYPE = 'request' + + def __new__(cls, seq, type, command, arguments, **unused): + # TODO: Make "arguments" immutable? + return super(Request, cls).__new__( + cls, + seq, + type, + command=command, + arguments=arguments, + **unused + ) + + +class Response(ProtocolMessage, + namedtuple('Response', + 'seq type request_seq command success message body'), + ): + """A DAP response message.""" + + TYPE = 'response' + + def __new__(cls, seq, type, request_seq, command, success, message, body, + **unused): + # TODO: Make "body" immutable? + return super(Response, cls).__new__( + cls, + seq, + type, + request_seq=request_seq, + command=command, + success=success, + message=message, + body=body, + **unused + ) + + +class Event(ProtocolMessage, namedtuple('Event', 'seq type event body')): + """A DAP event message.""" + + TYPE = 'event' + + def __new__(cls, seq, type, event, body, **unused): + # TODO: Make "body" immutable? + return super(Event, cls).__new__( + cls, + seq, + type, + event=event, + body=body, + **unused + ) diff --git a/tests/ptvsd/highlevel/__init__.py b/tests/ptvsd/highlevel/__init__.py index 32a92a6f..93a51037 100644 --- a/tests/ptvsd/highlevel/__init__.py +++ b/tests/ptvsd/highlevel/__init__.py @@ -249,7 +249,7 @@ class VSCLifecycle(object): See https://code.visualstudio.com/docs/extensionAPI/api-debugging#_the-vs-code-debug-protocol-in-a-nutshell """ # noqa def handle_response(resp, _): - self._capabilities = resp.data['body'] + self._capabilities = resp.body version = self._fix.debugger.VERSION self._fix.set_debugger_response(CMD_VERSION, version) self._fix.send_request( @@ -411,13 +411,13 @@ class HighlevelFixture(object): next(self.vsc_msgs.event_seq) for msg in reversed(self.vsc.received): - if msg.data['type'] == 'response': - if msg.data['command'] == 'threads': + if msg.type == 'response': + if msg.command == 'threads': break else: assert False, 'we waited for the response in send_request()' - for tinfo in msg.data['body']['threads']: + for tinfo in msg.body['threads']: try: thread = request[tinfo['name']] except KeyError: @@ -541,7 +541,7 @@ class HighlevelTest(object): failure = received[-1] expected = self.vsc.protocol.parse( - self.fix.vsc_msgs.new_failure(req, failure.data['message'])) + self.fix.vsc_msgs.new_failure(req, failure.message)) self.assertEqual(failure, expected) def assert_received(self, daemon, expected): diff --git a/tests/ptvsd/highlevel/test_messages.py b/tests/ptvsd/highlevel/test_messages.py index cae2469a..9cdef6ff 100644 --- a/tests/ptvsd/highlevel/test_messages.py +++ b/tests/ptvsd/highlevel/test_messages.py @@ -1507,7 +1507,7 @@ class ThreadEventTest(PyDevdEventTest): def send_event(self, *args, **kwargs): def handler(msg, _): - self._tid = msg.data['body']['threadId'] + self._tid = msg.body['threadId'] kwargs['handler'] = handler super(ThreadEventTest, self).send_event(*args, **kwargs) return self._tid @@ -1844,8 +1844,8 @@ class SendCurrExcTraceTests(PyDevdEventTest, unittest.TestCase): self.assert_vsc_received(received, []) self.assert_received(self.debugger, []) - self.assertTrue(resp.data['success'], resp.data['message']) - self.assertEqual(resp.data['body'], dict( + self.assertTrue(resp.success, resp.message) + self.assertEqual(resp.body, dict( exceptionId='RuntimeError', description='something went wrong', breakMode='unhandled',