mirror of
https://github.com/python/cpython.git
synced 2025-11-27 13:45:25 +00:00
asyncio: sync with Tulip
* _SelectorTransport constructor: extra parameter is now optional * Fix _SelectorDatagramTransport constructor. Only start reading after connection_made() has been called. * Fix _SelectorSslTransport.close(). Don't call protocol.connection_lost() if protocol.connection_made() was not called yet: if the SSL handshake failed or is still in progress. The close() method can be called if the creation of the connection is cancelled, by a timeout for example.
This commit is contained in:
parent
7b5a900e88
commit
47bbea7124
2 changed files with 24 additions and 4 deletions
|
|
@ -467,7 +467,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
|
||||
_buffer_factory = bytearray # Constructs initial value for self._buffer.
|
||||
|
||||
def __init__(self, loop, sock, protocol, extra, server=None):
|
||||
def __init__(self, loop, sock, protocol, extra=None, server=None):
|
||||
super().__init__(extra, loop)
|
||||
self._extra['socket'] = sock
|
||||
self._extra['sockname'] = sock.getsockname()
|
||||
|
|
@ -479,6 +479,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
self._sock = sock
|
||||
self._sock_fd = sock.fileno()
|
||||
self._protocol = protocol
|
||||
self._protocol_connected = True
|
||||
self._server = server
|
||||
self._buffer = self._buffer_factory()
|
||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||
|
|
@ -555,6 +556,7 @@ class _SelectorTransport(transports._FlowControlMixin,
|
|||
|
||||
def _call_connection_lost(self, exc):
|
||||
try:
|
||||
if self._protocol_connected:
|
||||
self._protocol.connection_lost(exc)
|
||||
finally:
|
||||
self._sock.close()
|
||||
|
|
@ -718,6 +720,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
|
||||
|
||||
super().__init__(loop, sslsock, protocol, extra, server)
|
||||
# the protocol connection is only made after the SSL handshake
|
||||
self._protocol_connected = False
|
||||
|
||||
self._server_hostname = server_hostname
|
||||
self._waiter = waiter
|
||||
|
|
@ -797,6 +801,7 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._read_wants_write = False
|
||||
self._write_wants_read = False
|
||||
self._loop.add_reader(self._sock_fd, self._read_ready)
|
||||
self._protocol_connected = True
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
# only wake up the waiter when connection_made() has been called
|
||||
self._loop.call_soon(self._wakeup_waiter)
|
||||
|
|
@ -928,8 +933,10 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
waiter=None, extra=None):
|
||||
super().__init__(loop, sock, protocol, extra)
|
||||
self._address = address
|
||||
self._loop.add_reader(self._sock_fd, self._read_ready)
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
# only start reading when connection_made() has been called
|
||||
self._loop.call_soon(self._loop.add_reader,
|
||||
self._sock_fd, self._read_ready)
|
||||
if waiter is not None:
|
||||
# only wake up the waiter when connection_made() has been called
|
||||
self._loop.call_soon(waiter._set_result_unless_cancelled, None)
|
||||
|
|
|
|||
|
|
@ -1427,7 +1427,7 @@ class SelectorSslTransportTests(test_utils.TestCase):
|
|||
self.assertFalse(tr.can_write_eof())
|
||||
self.assertRaises(NotImplementedError, tr.write_eof)
|
||||
|
||||
def test_close(self):
|
||||
def check_close(self):
|
||||
tr = self._make_one()
|
||||
tr.close()
|
||||
|
||||
|
|
@ -1439,6 +1439,19 @@ class SelectorSslTransportTests(test_utils.TestCase):
|
|||
self.assertEqual(tr._conn_lost, 1)
|
||||
self.assertEqual(1, self.loop.remove_reader_count[1])
|
||||
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
def test_close(self):
|
||||
self.check_close()
|
||||
self.assertTrue(self.protocol.connection_made.called)
|
||||
self.assertTrue(self.protocol.connection_lost.called)
|
||||
|
||||
def test_close_not_connected(self):
|
||||
self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
|
||||
self.check_close()
|
||||
self.assertFalse(self.protocol.connection_made.called)
|
||||
self.assertFalse(self.protocol.connection_lost.called)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No SSL support')
|
||||
def test_server_hostname(self):
|
||||
self.ssl_transport(server_hostname='localhost')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue