Refactor messaging to consistently use Request/Response/Event objects to represent both incoming and outgoing messages.

Fix messaging issues uncovered by test_fuzz, and fix the test itself.
This commit is contained in:
Pavel Minaev 2018-10-14 15:12:42 -07:00
parent 5065397489
commit 50a250f1fc
4 changed files with 180 additions and 149 deletions

View file

@ -125,12 +125,89 @@ class JsonIOStream(object):
self._writer.write(body)
class Request(object):
"""Represents an incoming or an outgoing request.
Incoming requests are represented by instances of this class.
Outgoing requests are represented by instances of OutgoingRequest, which
provides additional functionality to handle responses.
"""
def __init__(self, channel, seq, command, arguments):
self.channel = channel
self.seq = seq
self.command = command
self.arguments = arguments
class OutgoingRequest(Request):
"""Represents an outgoing request, for which it is possible to wait for a
response to be received, and register a response callback.
"""
def __init__(self, *args):
super(OutgoingRequest, self).__init__(*args)
self.response = None
self._lock = threading.Lock()
self._got_response = threading.Event()
self._callback = lambda _: None
def _handle_response(self, seq, command, body):
assert self.response is None
with self._lock:
response = Response(self.channel, seq, self, body)
self.response = response
callback = self._callback
callback(response)
self._got_response.set()
def wait_for_response(self, raise_if_failed=True):
"""Waits until a response is received for this request, records that
response as a new Response object accessible via self.response, and
returns self.response.body.
If raise_if_failed is True, and the received response does not indicate
success, raises RequestFailure. Otherwise, self.response.body has to be
inspected to determine whether the request failed or succeeded.
"""
self._got_response.wait()
if raise_if_failed and not self.response.success:
raise self.response.body
return self.response.body
def on_response(self, callback):
"""Registers a callback to invoke when a response is received for this
request. If response was already received, invokes callback immediately.
Callback is invoked with Response as the sole arugment.
To get access to the entire Response object in the callback, the callback
should be a lambda capturing the request on which on_response was called.
Then, request.response can be inspected inside the callback.
The callback is invoked on an unspecified background thread that performs
processing of incoming messages; therefore, no further message processing
occurs until the callback returns.
"""
with self._lock:
response = self.response
if response is None:
self._callback = callback
return
callback(response)
class Response(object):
"""Represents a response to a Request.
"""
def __init__(self, request, body):
self.request = None
def __init__(self, channel, seq, request, body):
self.channel = channel
self.seq = seq
self.request = request
"""Request object that this is a response to.
"""
@ -150,13 +227,16 @@ class Response(object):
def success(self):
return not isinstance(self.body, Exception)
def __eq__(self, other):
if not isinstance(other, Response):
return NotImplemented
return self.request is other.request and self.body == other.body
def __ne__(self, other):
return not self == other
class Event(object):
"""Represents a received event.
"""
def __init__(self, channel, seq, event, body):
self.channel = channel
self.seq = seq
self.event = event
self.body = body
class RequestFailure(Exception):
@ -164,71 +244,18 @@ class RequestFailure(Exception):
self.message = message
def __eq__(self, other):
if isinstance(other, RequestFailure) and other.message == self.message:
return True
return NotImplemented
if not isinstance(other, RequestFailure):
return NotImplemented
return self.message == other.message
def __ne__(self, other):
return not self == other
def __repr__(self):
return 'RequestFailure(%r)' % self.message
class Request(object):
"""Represents a request that was sent to the other party, and is awaiting or has
already received a response.
"""
def __init__(self, channel, seq, command, arguments):
self.channel = channel
self.seq = seq
self.command = command
self.arguments = arguments
self.response = None
self._lock = threading.Lock()
self._got_response = threading.Event()
self._callback = lambda _: None
def _handle_response(self, command, body):
assert self.response is None
with self._lock:
response = Response(self, body)
self.response = response
callback = self._callback
callback(response.body)
self._got_response.set()
def wait_for_response(self, raise_if_failed=True):
"""Waits until a response is received for this request, records that
response as a new Response object accessible via self.response,
and returns self.response.body.
If raise_if_failed is True, and the received response does not indicate
success, raises RequestFailure. Otherwise, self.response.success has to
be inspected to determine whether the request failed or succeeded, since
self.response.body can be None in either case.
"""
self._got_response.wait()
if raise_if_failed and not self.response.success:
raise self.response.body
return self.response.body
def on_response(self, callback):
"""Registers a callback to invoke when a response is received for this
request. If response was already received, invokes callback immediately.
Callback is invoked with Response object as the sole argument.
The callback is invoked on an unspecified background thread that performs
processing of incoming messages; therefore, no further message processing
occurs until the callback returns.
"""
with self._lock:
response = self.response
if response is None:
self._callback = callback
return
callback(response.body)
def __str__(self):
return self.message
class JsonMessageChannel(object):
@ -277,7 +304,7 @@ class JsonMessageChannel(object):
if arguments is not None:
d['arguments'] = arguments
with self._send_message('request', d) as seq:
request = Request(self, seq, command, arguments)
request = OutgoingRequest(self, seq, command, arguments)
self._requests[seq] = request
return request
@ -327,34 +354,32 @@ class JsonMessageChannel(object):
def on_request(self, seq, command, arguments):
handler_name = '%s_request' % command
specific_handler = getattr(self._handlers, handler_name, None)
if specific_handler is not None:
handler = lambda: specific_handler(self, arguments)
else:
handler = getattr(self._handlers, handler_name, None)
if handler is None:
try:
generic_handler = getattr(self._handlers, 'request')
handler = getattr(self._handlers, 'request')
except AttributeError:
raise AttributeError('%r has no handler for request %r' % (self._handlers, command))
handler = lambda: generic_handler(self, command, arguments)
request = Request(self, seq, command, arguments)
try:
response_body = handler()
response_body = handler(request)
except Exception as ex:
self._send_response(seq, False, command, str(ex), None)
else:
self._send_response(seq, True, command, None, response_body)
if isinstance(response_body, Exception):
self._send_response(seq, False, command, str(response_body), None)
else:
self._send_response(seq, True, command, None, response_body)
def on_event(self, seq, event, body):
handler_name = '%s_event' % event
specific_handler = getattr(self._handlers, handler_name, None)
if specific_handler is not None:
handler = lambda: specific_handler(self, body)
else:
handler = getattr(self._handlers, handler_name, None)
if handler is None:
try:
generic_handler = getattr(self._handlers, 'event')
handler = getattr(self._handlers, 'event')
except AttributeError:
raise AttributeError('%r has no handler for event %r' % (self._handlers, event))
handler = lambda: generic_handler(self, event, body)
handler()
handler(Event(self, seq, event, body))
def on_response(self, seq, request_seq, success, command, error_message, body):
try:
@ -364,7 +389,7 @@ class JsonMessageChannel(object):
raise KeyError('Received response to unknown request %d', request_seq)
if not success:
body = RequestFailure(error_message)
return request._handle_response(command, body)
return request._handle_response(seq, command, body)
def _process_incoming_messages(self):
try:
@ -384,7 +409,7 @@ class JsonMessageChannel(object):
# must be marked as failed to unblock anyone waiting on them.
with self._lock:
for request in self._requests.values():
request._handle_response(request.command, EOFError('No response'))
request._handle_response(None, request.command, EOFError('No response'))
class MessageHandlers(object):

View file

@ -11,6 +11,7 @@ import re
import socket
import sys
import time
import traceback
try:
import queue
@ -99,18 +100,19 @@ def _subprocess_listener():
def _handle_subprocess(n, stream):
class Handlers(object):
def ptvsd_subprocess_request(self, channel, body):
def ptvsd_subprocess_request(self, request):
# When child process is spawned, the notification it sends only
# contains information about itself and its immediate parent.
# Add information about the root process before passing it on.
body.update({
arguments = dict(request.arguments)
arguments.update({
'rootProcessId': os.getpid(),
'rootStartRequest': root_start_request,
})
debug('ptvsd_subprocess: %r' % body)
debug('ptvsd_subprocess: %r' % arguments)
response = {'incomingConnection': False}
subprocess_queue.put((body, response))
subprocess_queue.put((arguments, response))
subprocess_queue.join()
return response
@ -145,7 +147,8 @@ def notify_root(port):
try:
response = request.wait_for_response()
except Exception:
debug('Failed to send subprocess notification; exiting')
print('Failed to send subprocess notification; exiting', file=sys.__stderr__)
traceback.print_exc()
sys.exit(0)
if not response['incomingConnection']:

View file

@ -310,7 +310,7 @@ class DebugSession(object):
request = self.timeline.record_request(command, arguments)
request.sent = self.channel.send_request(command, arguments)
request.sent.on_response(lambda body: self._process_response(request, body))
request.sent.on_response(lambda response: self._process_response(request, response))
def causing(*expectations):
for exp in expectations:
@ -372,13 +372,13 @@ class DebugSession(object):
return start
def _process_event(self, channel, event, body):
self.timeline.record_event(event, body, block=False)
def _process_event(self, event):
self.timeline.record_event(event.event, event.body, block=False)
def _process_response(self, request, body):
self.timeline.record_response(request, body, block=False)
def _process_response(self, request_occ, response):
self.timeline.record_response(request_occ, response.body, block=False)
def _process_request(self, channel, command, arguments):
def _process_request(self, request):
assert False, 'ptvsd should not be sending requests.'
def setup_backchannel(self):

View file

@ -12,7 +12,7 @@ import socket
import threading
import time
from ptvsd.messaging import JsonIOStream, JsonMessageChannel, Response, RequestFailure
from ptvsd.messaging import JsonIOStream, JsonMessageChannel, RequestFailure
from .helpers.messaging import JsonMemoryStream, LoggingJsonStream
@ -96,11 +96,12 @@ class TestJsonMessageChannel(object):
events_received = []
class Handlers(object):
def stopped_event(self, channel, body):
events_received.append((channel, body))
def stopped_event(self, event):
assert event.event == 'stopped'
events_received.append((event.channel, event.body))
def event(self, channel, event, body):
events_received.append((channel, event, body))
def event(self, event):
events_received.append((event.channel, event.event, event.body))
input, input_exhausted = self.iter_with_event(EVENTS)
stream = LoggingJsonStream(JsonMemoryStream(input, []))
@ -123,15 +124,17 @@ class TestJsonMessageChannel(object):
requests_received = []
class Handlers(object):
def next_request(self, channel, arguments):
requests_received.append((channel, arguments))
def next_request(self, request):
assert request.command == 'next'
requests_received.append((request.channel, request.arguments))
return {'threadId': 7}
def request(self, channel, command, arguments):
requests_received.append((channel, command, arguments))
def request(self, request):
requests_received.append((request.channel, request.command, request.arguments))
def pause_request(self, channel, arguments):
requests_received.append((channel, arguments))
def pause_request(self, request):
assert request.command == 'pause'
requests_received.append((request.channel, request.arguments))
raise RuntimeError('pause error')
input, input_exhausted = self.iter_with_event(REQUESTS)
@ -176,44 +179,45 @@ class TestJsonMessageChannel(object):
response1_body = request1.wait_for_response()
response1 = request1.response
assert response1 == Response(request1, body={'threadId': 3})
assert response1.success
assert response1.request is request1
assert response1.body == response1_body
assert response1.body == {'threadId': 3}
# Async callback, registered before response is received.
request2 = channel.send_request('pause')
response2_body = []
response2 = []
response2_received = threading.Event()
def response2_handler(body):
response2_body.append(body)
def response2_handler(resp):
response2.append(resp)
response2_received.set()
request2.on_response(response2_handler)
request2_sent.set()
response2_received.wait()
response2_body, = response2_body
response2 = request2.response
response2, = response2
assert response2 == Response(request2, RequestFailure('pause error'))
assert not response2.success
assert response2.body == response2_body
assert response2.request is request2
assert response2 is request2.response
assert response2.body == RequestFailure('pause error')
# Async callback, registered after response is received.
request3 = channel.send_request('next')
request3_sent.set()
request3.wait_for_response()
response3_body = []
response3 = []
response3_received = threading.Event()
def response3_handler(body):
response3_body.append(body)
def response3_handler(resp):
response3.append(resp)
response3_received.set()
request3.on_response(response3_handler)
response3_received.wait()
response3_body, = response3_body
response3 = request3.response
response3, = response3
assert response3 == Response(request3, body={'threadId': 5})
assert response3.success
assert response3.body == response3_body
assert response3.request is request3
assert response3 is request3.response
assert response3.body == {'threadId': 5}
def test_fuzz(self):
# Set up two channels over the same stream that send messages to each other
@ -237,45 +241,44 @@ class TestJsonMessageChannel(object):
def wait(self):
self._worker.join()
def fizz_event(self, channel, body):
def fizz_event(self, event):
assert event.event == 'fizz'
with self.lock:
self.received.append(('event', 'fizz', body))
self.received.append(('event', 'fizz', event.body))
def buzz_event(self, channel, body):
def buzz_event(self, event):
assert event.event == 'buzz'
with self.lock:
self.received.append(('event', 'buzz', body))
self.received.append(('event', 'buzz', event.body))
def event(self, channel, event, body):
def event(self, event):
with self.lock:
self.received.append(('event', event, body))
self.received.append(('event', event.event, event.body))
def make_and_log_response(self, request):
x = random.randint(-100, 100)
if x >= 0:
response = Response(request, x)
else:
response = Response(request, RequestFailure(str(x)))
if x < 0:
x = RequestFailure(str(x))
with self.lock:
self.responses_sent.append(response)
if response.success:
return x
else:
raise response.body
self.responses_sent.append((request.seq, x))
return x
def fizz_request(self, channel, arguments):
def fizz_request(self, request):
assert request.command == 'fizz'
with self.lock:
self.received.append(('request', 'fizz', arguments))
return self.make_and_log_response('fizz')
self.received.append(('request', 'fizz', request.arguments))
return self.make_and_log_response(request)
def buzz_request(self, channel, arguments):
def buzz_request(self, request):
assert request.command == 'buzz'
with self.lock:
self.received.append(('request', 'buzz', arguments))
return self.make_and_log_response('buzz')
self.received.append(('request', 'buzz', request.arguments))
return self.make_and_log_response(request)
def request(self, channel, command, arguments):
def request(self, request):
with self.lock:
self.received.append(('request', command, arguments))
return self.make_and_log_response(command)
self.received.append(('request', request.command, request.arguments))
return self.make_and_log_response(request)
def _send_requests_and_events(self, channel):
pending_requests = [0]
@ -291,9 +294,9 @@ class TestJsonMessageChannel(object):
with self.lock:
pending_requests[0] += 1
req = channel.send_request(name, body)
def response_handler(body):
def response_handler(response):
with self.lock:
self.responses_received.append(Response(req, body))
self.responses_received.append((response.request.seq, response.body))
pending_requests[0] -= 1
req.on_response(response_handler)
# Spin until we get responses to all requests.