Merge pull request #97 from ericsnowcurrently/close-socket

Ensure that all the sockets get closed after tests.
This commit is contained in:
Eric Snow 2018-02-20 14:59:11 -07:00 committed by GitHub
commit 3b4d2f310a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 334 additions and 134 deletions

View file

@ -95,12 +95,15 @@ class EventLoop(object):
self._event = threading.Event()
self._event.set()
self._stop = False
def create_future(self):
return Future(self)
def run_forever(self):
while True:
self._event.wait()
while not self._stop:
if not self._event.wait(timeout=0.1):
continue
with self._lock:
queue = self._queue
self._queue = []
@ -108,6 +111,9 @@ class EventLoop(object):
for (f, args) in queue:
f(*args)
def stop(self):
self._stop = True
def call_soon(self, f, *args):
with self._lock:
self._queue.append((f, args))

View file

@ -11,9 +11,10 @@ from __future__ import print_function, with_statement, absolute_import
# the main thread. This will cause issues when the thread goes away
# after attach completes.
import errno
import itertools
import json
import os.path
import itertools
import socket
import sys
import traceback
@ -93,8 +94,11 @@ class SocketIO(object):
self.__logfile.write(content)
self.__logfile.write('\n'.encode('utf-8'))
self.__logfile.flush()
self.__socket.send(headers)
self.__socket.send(content)
try:
self.__socket.send(headers)
self.__socket.send(content)
except BrokenPipeError:
pass
def _buffered_read_line_as_ascii(self):
"""Return the next line from the buffer as a string.
@ -279,7 +283,12 @@ class IpcChannel(object):
try:
msg = self.__message.pop(0)
except IndexError:
self._wait_for_message()
try:
self._wait_for_message()
except OSError as exc:
if exc.errno == errno.EBADF: # socket closed
return self.__exit
raise
try:
msg = self.__message.pop(0)
except IndexError:

View file

@ -156,6 +156,8 @@ class PydevdSocket(object):
awaited.
"""
_vscprocessor = None
def __init__(self, event_handler):
#self.log = open('pydevd.log', 'w')
self.event_handler = event_handler
@ -164,13 +166,36 @@ class PydevdSocket(object):
self.pipe_r, self.pipe_w = os.pipe()
self.requests = {}
self._closed = False
self._closing = False
def close(self):
# TODO: docstring
pass
"""Mark the socket as closed and release any resources."""
if self._closing:
return
with self.lock:
if self._closed:
return
self._closing = True
if self.pipe_w is not None:
pipe_w = self.pipe_w
self.pipe_w = None
os.close(pipe_w)
if self.pipe_r is not None:
pipe_r = self.pipe_r
self.pipe_r = None
os.close(pipe_r)
if self._vscprocessor is not None:
proc = self._vscprocessor
self._vscprocessor = None
proc.close()
self._closed = True
self._closing = False
def shutdown(self, mode):
# TODO: docstring
pass
"""Called when pydevd has stopped."""
def recv(self, count):
"""Return the requested number of bytes.
@ -358,17 +383,33 @@ class VSCodeMessageProcessor(ipcjson.SocketIO, ipcjson.IpcChannel):
self.loop = futures.EventLoop()
self.exceptions_mgr = ExceptionsManager(self)
pydevd._vscprocessor = self
self._closed = False
t = threading.Thread(target=self.loop.run_forever,
name='ptvsd.EventLoop')
t.daemon = True
t.start()
def close(self):
# TODO: docstring
"""Stop the message processor and release its resources."""
if self._closed:
return
self._closed = True
pydevd = self.pydevd
self.pydevd = None
pydevd.shutdown(socket.SHUT_RDWR)
pydevd.close()
global ptvsd_sys_exit_code
self.send_event('exited', exitCode=ptvsd_sys_exit_code)
self.send_event('terminated')
self.loop.stop()
if self.socket:
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
def pydevd_notify(self, cmd_id, args):
@ -841,6 +882,45 @@ class VSCodeMessageProcessor(ipcjson.SocketIO, ipcjson.IpcChannel):
pass
########################
# lifecycle
def _create_server(port):
server = _new_sock()
server.bind(('127.0.0.1', port))
server.listen(1)
return server
def _create_client():
return _new_sock()
def _new_sock():
sock = socket.socket(socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return sock
def _start(client, server):
name = 'ptvsd.Client' if server is None else 'ptvsd.Server'
pydevd = PydevdSocket(lambda *args: proc.on_pydevd_event(*args))
proc = VSCodeMessageProcessor(client, pydevd)
server_thread = threading.Thread(target=proc.process_messages,
name=name)
server_thread.daemon = True
server_thread.start()
return pydevd, proc, server_thread
########################
# pydevd hooks
def start_server(port):
"""Return a socket to a (new) local pydevd-handling daemon.
@ -849,24 +929,10 @@ def start_server(port):
This is a replacement fori _pydevd_bundle.pydevd_comm.start_server.
"""
server = socket.socket(socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind(('127.0.0.1', port))
server.listen(1)
client, addr = server.accept()
pydevd = PydevdSocket(lambda *args: proc.on_pydevd_event(*args))
proc = VSCodeMessageProcessor(client, pydevd)
server_thread = threading.Thread(target=proc.process_messages,
name='ptvsd.Server')
server_thread.daemon = True
server_thread.start()
server = _create_server(port)
client, _ = server.accept()
pydevd, proc, _ = _start(client, server)
atexit.register(proc.close)
return pydevd
@ -878,22 +944,10 @@ def start_client(host, port):
This is a replacement fori _pydevd_bundle.pydevd_comm.start_client.
"""
client = socket.socket(socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP)
client.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
client = _create_client()
client.connect((host, port))
pydevd = PydevdSocket(lambda *args: proc.on_pydevd_event(*args))
proc = VSCodeMessageProcessor(client, pydevd)
server_thread = threading.Thread(target=proc.process_messages,
name='ptvsd.Client')
server_thread.daemon = True
server_thread.start()
pydevd, proc, _ = _start(client, None)
atexit.register(proc.close)
return pydevd

View file

@ -1,4 +1,5 @@
from collections import namedtuple
import errno
import threading
from . import socket
@ -42,19 +43,33 @@ class MessageProtocol(namedtuple('Protocol', 'parse encode iter')):
class Started(object):
"""A simple wrapper around a started message protocol daemon."""
def __init__(self, fake):
def __init__(self, fake, address, starting=None):
self.fake = fake
self.address = address
self._starting = starting
def __enter__(self):
self.wait_until_connected()
return self
def __exit__(self, *args):
self.close()
def wait_until_connected(self, timeout=None):
starting = self._starting
if starting is None:
return
starting.join(timeout=timeout)
if starting.is_alive():
raise RuntimeError('timed out')
self._starting = None
def send_message(self, msg):
return self.fake.send_response(msg)
self.wait_until_connected()
return self.fake.send_message(msg)
def close(self):
self.wait_until_connected()
self.fake.close()
@ -68,8 +83,8 @@ class Daemon(object):
"""Ensure the message is legitimate."""
# By default check nothing.
def __init__(self, connect, protocol, handler):
self._connect = connect
def __init__(self, bind, protocol, handler):
self._bind = bind
self._protocol = protocol
self._closed = False
@ -80,10 +95,8 @@ class Daemon(object):
self._default_handler = handler
# These are set when we start.
self._host = None
self._port = None
self._address = None
self._sock = None
self._server = None
self._listener = None
@property
@ -102,18 +115,17 @@ class Daemon(object):
"""All send/recv failures thus far."""
return list(self._failures)
def start(self, host, port):
def start(self, address):
"""Start the fake daemon.
This calls the earlier provided connect() function.
This calls the earlier provided bind() function.
A listener loop is started in another thread to handle incoming
messages from the socket.
"""
self._host = host or None
self._port = port
self._start()
return self.STARTED(self)
self._address = address
addr, starting = self._start(address)
return self.STARTED(self, addr, starting)
def send_message(self, msg):
"""Serialize msg to the line format and send it to the socket."""
@ -149,23 +161,38 @@ class Daemon(object):
# internal methods
def _start(self, host=None):
self._sock, self._server = self._connect(
host or self._host,
self._port,
)
def _start(self, address):
connect, addr = self._bind(address)
# TODO: make it a daemon thread?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
def run():
self._sock = connect()
# TODO: make it a daemon thread?
self._listener = threading.Thread(target=self._listen)
self._listener.start()
t = threading.Thread(target=run)
t.start()
return addr, t
def _listen(self):
with self._sock.makefile('rb') as sockfile:
for msg in self._protocol.iter(sockfile, lambda: self._closed):
if isinstance(msg, StreamFailure):
self._failures.append(msg)
else:
self._add_received(msg)
try:
with self._sock.makefile('rb') as sockfile:
for msg in self._protocol.iter(sockfile, lambda: self._closed):
if isinstance(msg, StreamFailure):
self._failures.append(msg)
else:
self._add_received(msg)
except BrokenPipeError:
if self._closed:
return
# TODO: try reconnecting?
raise
except OSError as exc:
if exc.errno == 9: # socket closed
return
if exc.errno == errno.EBADF: # socket closed
return
# TODO: try reconnecting?
raise
def _add_received(self, msg):
self._received.append(msg)
@ -196,6 +223,7 @@ class Daemon(object):
try:
self._send(raw)
except Exception as exc:
raise
failure = StreamFailure('send', msg, exc)
self._failures.append(failure)
@ -208,9 +236,6 @@ class Daemon(object):
if self._sock is not None:
socket.close(self._sock)
self._sock = None
if self._server is not None:
socket.close(self._server)
self._server = None
if self._listener is not None:
self._listener.join(timeout=1)
# TODO: the listener isn't stopping!

View file

@ -2,9 +2,9 @@ from _pydevd_bundle.pydevd_comm import (
CMD_VERSION,
)
from ptvsd.wrapper import start_server, start_client
import ptvsd.wrapper as _ptvsd
from ._pydevd import parse_message, encode_message, iter_messages, Message
from tests.helpers import protocol
from tests.helpers import protocol, socket
PROTOCOL = protocol.MessageProtocol(
@ -14,19 +14,24 @@ PROTOCOL = protocol.MessageProtocol(
)
def _connect(host, port):
if host is None:
return start_server(port), None
else:
return start_client(host, port), None
def _bind(address):
connect, remote = socket.bind(address)
def connect(_connect=connect):
client, server = _connect()
pydevd, _, _ = _ptvsd._start(client, server)
return socket.Connection(pydevd, server)
return connect, remote
class Started(protocol.Started):
def send_response(self, msg):
self.wait_until_connected()
return self.fake.send_response(msg)
def send_event(self, msg):
self.wait_until_connected()
return self.fake.send_event(msg)
@ -92,7 +97,7 @@ class FakePyDevd(protocol.Daemon):
def __init__(self, handler=None):
super(FakePyDevd, self).__init__(
_connect,
_bind,
PROTOCOL,
(lambda msg, send: self.handle_request(msg, send, handler)),
)

View file

@ -1,38 +1,107 @@
from __future__ import absolute_import
from collections import namedtuple
import socket
import ptvsd.wrapper as _ptvsd
def connect(host, port):
"""Return (client, server) after connecting.
If host is None then it's a server, so it will wait for a connection
on localhost. Otherwise it will connect to the remote host.
def create_server(address):
"""Return a server socket after binding."""
host, port = address
return _ptvsd._create_server(port)
def create_client():
"""Return a new (unconnected) client socket."""
return _ptvsd._create_client()
def connect(sock, address):
"""Return a client socket after connecting.
If address is None then it's a server, so it will wait for a
connection. Otherwise it will connect to the remote host.
"""
sock = socket.socket(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
)
sock.setsockopt(
socket.SOL_SOCKET,
socket.SO_REUSEADDR,
1,
)
return _connect(sock, address)
def bind(address):
"""Return (connect, remote addr) for the given address.
"connect" is a function with no args that returns (client, server),
which are sockets. If the host is None then a server socket will
be created bound to localhost, and that server socket will be
returned from connect(). Otherwise a client socket is connected to
the remote address and None is returned from connect() for the
server.
"""
host, _ = address
if host is None:
addr = ('127.0.0.1', port)
sock = create_server(address)
server = sock
server.bind(addr)
server.listen(1)
sock, _ = server.accept()
connect_to = None
remote = sock.getsockname()
else:
addr = (host, port)
sock.connect(addr)
sock = create_client()
server = None
return sock, server
connect_to = address
remote = address
def connect():
client = _connect(sock, connect_to)
return client, server
return connect, remote
def close(sock):
"""Shutdown and close the socket."""
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class Connection(namedtuple('Connection', 'client server')):
"""A wrapper around a client socket.
If a server socket is provided then it will be closed when the
client is closed.
"""
def __new__(cls, client, server=None):
self = super(Connection, cls).__new__(
cls,
client,
server,
)
return self
def send(self, *args, **kwargs):
return self.client.send(*args, **kwargs)
def recv(self, *args, **kwargs):
return self.client.recv(*args, **kwargs)
def makefile(self, *args, **kwargs):
return self.client.makefile(*args, **kwargs)
def shutdown(self, *args, **kwargs):
if self.server is not None:
self.server.shutdown(*args, **kwargs)
self.client.shutdown(*args, **kwargs)
def close(self):
if self.server is not None:
self.server.close()
self.client.close()
########################
# internal functions
def _connect(sock, address):
if address is None:
client, _ = sock.accept()
else:
sock.connect(address)
client = sock
return client

View file

@ -12,9 +12,19 @@ PROTOCOL = protocol.MessageProtocol(
)
def _bind(address):
connect, remote = socket.bind(address)
def connect(_connect=connect):
client, server = _connect()
return socket.Connection(client, server)
return connect, remote
class Started(protocol.Started):
def send_request(self, msg):
self.wait_until_connected()
return self.fake.send_request(msg)
@ -52,19 +62,23 @@ class FakeVSC(protocol.Daemon):
PROTOCOL = PROTOCOL
def __init__(self, start_adapter, handler=None):
super(FakeVSC, self).__init__(socket.connect, PROTOCOL, handler)
def start_adapter(host, port, _start_adapter=start_adapter):
self._adapter = _start_adapter(host, port)
super(FakeVSC, self).__init__(
_bind,
PROTOCOL,
handler,
)
def start_adapter(address, start=start_adapter):
self._adapter = start(address)
return self._adapter
self._start_adapter = start_adapter
self._adapter = None
def start(self, host, port):
def start(self, address):
"""Start the fake and the adapter."""
if self._adapter is not None:
raise RuntimeError('already started')
return super(FakeVSC, self).start(host, port)
return super(FakeVSC, self).start(address)
def send_request(self, req):
"""Send the given Request object."""
@ -94,26 +108,20 @@ class FakeVSC(protocol.Daemon):
# internal methods
def _start(self, host=None):
start_adapter = (lambda: self._start_adapter(self._host, self._port))
if not self._host:
def _start(self, address):
host, port = address
if host is None:
# The adapter is the server so start it first.
t = threading.Thread(target=start_adapter)
t.start()
super(FakeVSC, self)._start('127.0.0.1')
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
adapter = self._start_adapter((None, port))
return super(FakeVSC, self)._start(adapter.address)
else:
# The adapter is the client so start it last.
# TODO: For now don't use this.
raise NotImplementedError
t = threading.Thread(target=super(FakeVSC, self)._start)
t.start()
start_adapter()
t.join(timeout=1)
if t.is_alive():
raise RuntimeError('timed out')
addr, starting = super(FakeVSC, self)._start(address)
self._start_adapter(addr)
# TODO Wait for adapter to be ready?
return addr, starting
def _close(self):
if self._adapter is not None:

View file

@ -113,6 +113,8 @@ class VSCMessages(object):
class VSCLifecycle(object):
PORT = 8888
MIN_INITIALIZE_ARGS = {
'adapterID': '<an adapter ID>',
}
@ -120,12 +122,12 @@ class VSCLifecycle(object):
def __init__(self, fix):
self._fix = fix
def launched(self, port=8888, **kwargs):
def launched(self, port=None, **kwargs):
def start():
self.launch(**kwargs)
return self._started(start, port)
def attached(self, port=8888, **kwargs):
def attached(self, port=None, **kwargs):
def start():
self.attach(**kwargs)
return self._started(start, port)
@ -147,7 +149,10 @@ class VSCLifecycle(object):
@contextlib.contextmanager
def _started(self, start, port):
with self._fix.vsc.start(None, port):
if port is None:
port = self.PORT
addr = (None, port)
with self._fix.vsc.start(addr):
with self._fix.disconnect_when_done():
start()
yield
@ -167,8 +172,8 @@ class VSCLifecycle(object):
self._send_request('configurationDone')
next(self._fix.vsc_msgs.event_seq)
assert(self._fix.vsc.failures == [])
assert(self._fix.debugger.failures == [])
assert self._fix.vsc.failures == [], self._fix.vsc.failures
assert self._fix.debugger.failures == [], self._fix.debugger.failures
if reset:
self._fix.vsc.reset()
self._fix.debugger.reset()
@ -274,7 +279,7 @@ class HighlevelFixture(object):
yield
finally:
self.send_request('disconnect')
self.vsc._received.pop(-1)
#self.vsc._received.pop(-1)
class HighlevelTest(object):
@ -293,6 +298,8 @@ class HighlevelTest(object):
return vsc
self.fix = self.FIXTURE(new_daemon)
self.maxDiff = None
def __getattr__(self, name):
return getattr(self.fix, name)
@ -313,6 +320,11 @@ class HighlevelTest(object):
vsc, debugger = self.fix.new_fake(debugger, handler)
return vsc, debugger
def assert_vsc_received(self, received, expected):
received = list(self.vsc.protocol.parse_each(received))
expected = list(self.vsc.protocol.parse_each(expected))
self.assertEqual(received, expected)
def assert_received(self, daemon, expected):
"""Ensure that the received messages match the expected ones."""
received = list(daemon.protocol.parse_each(daemon.received))

View file

@ -27,7 +27,8 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
def test_attach(self):
version = self.debugger.VERSION
with self.vsc.start(None, 8888):
addr = (None, 8888)
with self.vsc.start(addr):
with self.vsc.wait_for_event('initialized'):
# initialize
self.set_debugger_response(CMD_VERSION, version)
@ -46,6 +47,7 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
# end
req_disconnect = self.send_request('disconnect')
# An "exited" event comes once self.vsc closes.
self.assert_received(self.vsc, [
self.new_response(req_initialize, **dict(
@ -77,6 +79,7 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
startMethod='attach',
)),
self.new_response(req_disconnect),
self.new_event('exited', exitCode=0),
])
self.assert_received(self.debugger, [
self.debugger_msgs.new_request(CMD_VERSION,
@ -86,7 +89,8 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
def test_launch(self):
version = self.debugger.VERSION
with self.vsc.start(None, 8888):
addr = (None, 8888)
with self.vsc.start(addr):
with self.vsc.wait_for_event('initialized'):
# initialize
self.set_debugger_response(CMD_VERSION, version)
@ -105,6 +109,7 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
# end
req_disconnect = self.send_request('disconnect')
# An "exited" event comes once self.vsc closes.
self.assert_received(self.vsc, [
self.new_response(req_initialize, **dict(
@ -136,6 +141,7 @@ class LifecycleTests(HighlevelTest, unittest.TestCase):
startMethod='launch',
)),
self.new_response(req_disconnect),
self.new_event('exited', exitCode=0),
])
self.assert_received(self.debugger, [
self.debugger_msgs.new_request(CMD_VERSION,

View file

@ -83,14 +83,16 @@ class InitializeTests(LifecycleTest, unittest.TestCase):
@unittest.skip('tested via test_lifecycle.py')
def test_basic(self):
version = self.debugger.VERSION
with self.vsc.start(None, 8888):
addr = (None, 8888)
with self.vsc.start(addr):
with self.disconnect_when_done():
self.set_debugger_response(CMD_VERSION, version)
req = self.send_request('initialize', {
'adapterID': 'spam',
})
received = self.vsc.received
self.assert_received(self.vsc, [
self.assert_vsc_received(received, [
self.new_response(req, **dict(
supportsExceptionInfoRequest=True,
supportsConfigurationDoneRequest=True,
@ -151,7 +153,10 @@ class NormalRequestTest(RunningTest):
)
def expected_pydevd_request(self, *args):
return self.debugger_msgs.new_request(self.PYDEVD_CMD, *args)
if self.PYDEVD_REQ is not None:
return self.debugger_msgs.new_request(self.PYDEVD_REQ, *args)
else:
return self.debugger_msgs.new_request(self.PYDEVD_CMD, *args)
class ThreadsTests(NormalRequestTest, unittest.TestCase):
@ -174,8 +179,9 @@ class ThreadsTests(NormalRequestTest, unittest.TestCase):
(12, ''),
)
self.send_request()
received = self.vsc.received
self.assert_received(self.vsc, [
self.assert_vsc_received(received, [
self.expected_response(
threads=[
{'id': 1, 'name': 'spam'},