diff --git a/src/ptvsd/common/messaging.py b/src/ptvsd/common/messaging.py index 6e7675f9..d70c98f4 100644 --- a/src/ptvsd/common/messaging.py +++ b/src/ptvsd/common/messaging.py @@ -5,6 +5,7 @@ from __future__ import print_function, with_statement, absolute_import import contextlib +import inspect import itertools import json import sys @@ -156,7 +157,6 @@ class JsonIOStream(object): ptvsd.common.log.debug('{0} <-- {1!j}', self.name, value) - class Request(object): """Represents an incoming or an outgoing request. @@ -171,6 +171,7 @@ class Request(object): self.seq = seq self.command = command self.arguments = arguments + self.response = None class OutgoingRequest(Request): @@ -180,7 +181,6 @@ class OutgoingRequest(Request): def __init__(self, *args): super(OutgoingRequest, self).__init__(*args) - self.response = None self._lock = threading.Lock() self._got_response = threading.Event() self._callback = lambda _: None @@ -193,6 +193,7 @@ class OutgoingRequest(Request): callback = self._callback callback(response) self._got_response.set() + return response def wait_for_response(self, raise_if_failed=True): """Waits until a response is received for this request, records that @@ -330,7 +331,7 @@ class JsonMessageChannel(object): def send_request(self, command, arguments=None): d = {'command': command} - if arguments is not None: + if arguments is not None and arguments != {}: d['arguments'] = arguments with self._send_message('request', d) as seq: request = OutgoingRequest(self, seq, command, arguments) @@ -339,25 +340,27 @@ class JsonMessageChannel(object): def send_event(self, event, body=None): d = {'event': event} - if body is not None: + if body is not None and body != {}: d['body'] = body with self._send_message('event', d): pass - def _send_response(self, request_seq, success, command, error_message, body): + def _send_response(self, request, body): d = { - 'request_seq': request_seq, - 'success': success, - 'command': command, + 'request_seq': request.seq, + 'command': request.command, } - if success: - if body is not None: - d['body'] = body + if isinstance(body, Exception): + d['success'] = False + d['message'] = str(body) else: - if error_message is not None: - d['message'] = error_message - with self._send_message('response', d): + d['success'] = True + if body is not None and body != {}: + d['body'] = body + + with self._send_message('response', d) as seq: pass + return Response(self, seq, request, body) def on_message(self, message): seq = message['seq'] @@ -365,18 +368,18 @@ class JsonMessageChannel(object): if typ == 'request': command = message['command'] arguments = message.get('arguments', None) - self.on_request(seq, command, arguments) + return self.on_request(seq, command, arguments) elif typ == 'event': event = message['event'] body = message.get('body', None) - self.on_event(seq, event, body) + return self.on_event(seq, event, body) elif typ == 'response': request_seq = message['request_seq'] success = message['success'] command = message['command'] error_message = message.get('message', None) body = message.get('body', None) - self.on_response(seq, request_seq, success, command, error_message, body) + return self.on_response(seq, request_seq, success, command, error_message, body) else: raise IOError('Incoming message has invalid "type":\n%r' % message) @@ -388,16 +391,45 @@ class JsonMessageChannel(object): handler = getattr(self._handlers, 'request') except AttributeError: raise AttributeError('%r has no handler for request %r' % (self._handlers, command)) + request = Request(self, seq, command, arguments) try: - response_body = handler(request) + result = handler(request) except RequestFailure as ex: - self._send_response(seq, False, command, str(ex), None) + result = ex + + # A request handler can either be a simple function that returns the body of the + # response directly, or a generator that yields. If it is a generator, then every + # yield of None is treated as request to process another pending message recursively, + # after which the generator is resumed. Once any object other than None is yielded, + # that is the body of the response. If the generator stops before yielding a body, + # it is treated as if it had yielded {}. + if inspect.isgenerator(result): + gen = result else: - if isinstance(response_body, Exception): - self._send_response(seq, False, command, str(response_body), None) - else: - self._send_response(seq, True, command, None, response_body) + # Wrap a non-generator return into a generator, to unify processing below. + # Note that return None is the same as return {} in this case, unlike yield. + def gen(): + yield {} if result is None else result + gen = gen() + + last_message = None + while True: + try: + response_body = gen.send(last_message) + except RequestFailure as ex: + response_body = ex + break + except StopIteration: + response_body = {} + + if response_body is not None: + gen.close() + break + last_message = self._process_incoming_message() # re-entrant + + request.response = self._send_response(request, response_body) + return request def on_event(self, seq, event, body): handler_name = '%s_event' % event @@ -407,7 +439,10 @@ class JsonMessageChannel(object): handler = getattr(self._handlers, 'event') except AttributeError: raise AttributeError('%r has no handler for event %r' % (self._handlers, event)) - handler(Event(self, seq, event, body)) + + event = Event(self, seq, event, body) + handler(event) + return event def on_response(self, seq, request_seq, success, command, error_message, body): try: @@ -427,18 +462,21 @@ class JsonMessageChannel(object): request._handle_response(None, request.command, EOFError('No response')) getattr(self._handlers, 'disconnect', lambda: None)() + def _process_incoming_message(self): + message = self.stream.read_json() + try: + return self.on_message(message) + except Exception: + ptvsd.common.log.exception('Error while processing message for {0}:\n\n{1!r}', self.name, message) + raise + def _process_incoming_messages(self): try: while True: try: - message = self.stream.read_json() + self._process_incoming_message() except EOFError: - break - try: - self.on_message(message) - except Exception: - ptvsd.common.log.exception('Error while processing message for {0}:\n\n{1!r}', self.name, message) - raise + return False finally: try: self.on_disconnect() diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 2b630700..d783ef26 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -219,6 +219,183 @@ class TestJsonMessageChannel(object): assert response3 is request3.response assert response3.body == {'threadId': 5} + def test_yield(self): + REQUESTS = [ + {'seq': 10, 'type': 'request', 'command': 'launch', 'arguments': {'noDebug': False}}, + {'seq': 20, 'type': 'request', 'command': 'setBreakpoints', 'arguments': {'main.py': 1}}, + {'seq': 30, 'type': 'event', 'event': 'expected'}, + {'seq': 40, 'type': 'request', 'command': 'launch', 'arguments': {'noDebug': True}}, # test re-entrancy + {'seq': 50, 'type': 'request', 'command': 'setBreakpoints', 'arguments': {'main.py': 2}}, + {'seq': 60, 'type': 'event', 'event': 'unexpected'}, + {'seq': 80, 'type': 'request', 'command': 'configurationDone'}, + {'seq': 90, 'type': 'request', 'command': 'launch'}, # test handler not yielding body + ] + + class Handlers(object): + + received = { + 'launch': 0, + 'setBreakpoints': 0, + 'configurationDone': 0, + 'expected': 0, + 'unexpected': 0, + } + + def launch_request(self, request): + assert request.seq in (10, 40, 90) + self.received['launch'] += 1 + + if request.seq == 10: # launch #1 + assert self.received == { + 'launch': 1, + 'setBreakpoints': 0, + 'configurationDone': 0, + 'expected': 0, + 'unexpected': 0, + } + + msg = yield # setBreakpoints #1 + assert msg.seq == 20 + assert self.received == { + 'launch': 1, + 'setBreakpoints': 1, + 'configurationDone': 0, + 'expected': 0, + 'unexpected': 0, + } + + msg = yield # expected + assert msg.seq == 30 + assert self.received == { + 'launch': 1, + 'setBreakpoints': 1, + 'configurationDone': 0, + 'expected': 1, + 'unexpected': 0, + } + + msg = yield # launch #2 + nested messages + assert msg.seq == 40 + assert self.received == { + 'launch': 2, + 'setBreakpoints': 2, + 'configurationDone': 0, + 'expected': 1, + 'unexpected': 1, + } + + # We should see that it failed, but no exception bubbling up here. + assert not msg.response.success + assert msg.response.body == RequestFailure('test failure') + + msg = yield # configurationDone + assert msg.seq == 80 + assert self.received == { + 'launch': 2, + 'setBreakpoints': 2, + 'configurationDone': 1, + 'expected': 1, + 'unexpected': 1, + } + + yield {'answer': 42} + + elif request.seq == 40: # launch #1 + assert self.received == { + 'launch': 2, + 'setBreakpoints': 1, + 'configurationDone': 0, + 'expected': 1, + 'unexpected': 0, + } + + msg = yield # setBreakpoints #2 + assert msg.seq == 50 + assert self.received == { + 'launch': 2, + 'setBreakpoints': 2, + 'configurationDone': 0, + 'expected': 1, + 'unexpected': 0, + } + + msg = yield # unexpected + assert msg.seq == 60 + assert self.received == { + 'launch': 2, + 'setBreakpoints': 2, + 'configurationDone': 0, + 'expected': 1, + 'unexpected': 1, + } + + raise RequestFailure('test failure') + + elif request.seq == 90: # launch #3 + assert self.received == { + 'launch': 3, + 'setBreakpoints': 2, + 'configurationDone': 1, + 'expected': 1, + 'unexpected': 1, + } + + # Don't yield anything. + pass + + def setBreakpoints_request(self, request): + assert request.seq in (20, 50, 70) + self.received['setBreakpoints'] += 1 + return {'which': self.received['setBreakpoints']} + + def request(self, request): + assert request.seq == 80 + assert request.command == 'configurationDone' + self.received['configurationDone'] += 1 + + def expected_event(self, event): + assert event.seq == 30 + self.received['expected'] += 1 + + def event(self, event): + assert event.seq == 60 + assert event.event == 'unexpected' + self.received['unexpected'] += 1 + + input, input_exhausted = self.iter_with_event(REQUESTS) + output = [] + stream = LoggingJsonStream(JsonMemoryStream(input, output)) + channel = JsonMessageChannel(stream, Handlers()) + channel.start() + input_exhausted.wait() + + assert output == [ + { + 'seq': 1, 'type': 'response', 'request_seq': 20, 'command': 'setBreakpoints', + 'success': True, 'body': {'which': 1}, + }, + { + 'seq': 2, 'type': 'response', 'request_seq': 50, 'command': 'setBreakpoints', + 'success': True, 'body': {'which': 2}, + }, + { + 'seq': 3, 'type': 'response', 'request_seq': 40, 'command': 'launch', + 'success': False, 'message': 'test failure', + }, + { + 'seq': 4, 'type': 'response', 'request_seq': 80, 'command': 'configurationDone', + 'success': True, + }, + { + 'seq': 5, 'type': 'response', 'request_seq': 10, 'command': 'launch', + 'success': True, 'body': {'answer': 42}, + }, + { + 'seq': 6, 'type': 'response', 'request_seq': 90, 'command': 'launch', + 'success': True, + }, + ] + def test_fuzz(self): # Set up two channels over the same stream that send messages to each other # asynchronously, and record everything that they send and receive. @@ -345,4 +522,3 @@ class TestJsonMessageChannel(object): assert fuzzer2.sent == fuzzer1.received assert fuzzer1.responses_sent == fuzzer2.responses_received assert fuzzer2.responses_sent == fuzzer1.responses_received -