mirror of
https://github.com/python/cpython.git
synced 2025-08-30 21:48:47 +00:00
Fix issue 9794: adds context manager protocol to socket.socket so that socket.create_connection() can be used with the 'with' statement.
This commit is contained in:
parent
7c9cf01238
commit
b383dbb45e
4 changed files with 60 additions and 0 deletions
|
@ -213,6 +213,9 @@ The module :mod:`socket` exports the following constants and functions:
|
||||||
.. versionchanged:: 3.2
|
.. versionchanged:: 3.2
|
||||||
*source_address* was added.
|
*source_address* was added.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.2
|
||||||
|
support for the :keyword:`with` statement was added.
|
||||||
|
|
||||||
|
|
||||||
.. function:: getaddrinfo(host, port, family=0, type=0, proto=0, flags=0)
|
.. function:: getaddrinfo(host, port, family=0, type=0, proto=0, flags=0)
|
||||||
|
|
||||||
|
|
|
@ -389,6 +389,12 @@ New, Improved, and Deprecated Modules
|
||||||
|
|
||||||
(Contributed by Giampaolo Rodolà; :issue:`8807`.)
|
(Contributed by Giampaolo Rodolà; :issue:`8807`.)
|
||||||
|
|
||||||
|
* :func:`socket.create_connection` now supports the context manager protocol
|
||||||
|
to unconditionally consume :exc:`socket.error` exceptions and to close the
|
||||||
|
socket when done.
|
||||||
|
|
||||||
|
(Contributed by Giampaolo Rodolà; :issue:`9794`.)
|
||||||
|
|
||||||
|
|
||||||
Multi-threading
|
Multi-threading
|
||||||
===============
|
===============
|
||||||
|
|
|
@ -93,6 +93,13 @@ class socket(_socket.socket):
|
||||||
self._io_refs = 0
|
self._io_refs = 0
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
if not self._closed:
|
||||||
|
self.close()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Wrap __repr__() to reveal the real class name."""
|
"""Wrap __repr__() to reveal the real class name."""
|
||||||
s = _socket.socket.__repr__(self)
|
s = _socket.socket.__repr__(self)
|
||||||
|
|
|
@ -1595,6 +1595,49 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
|
||||||
self.cli.close()
|
self.cli.close()
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipUnless(thread, 'Threading required for this test.')
|
||||||
|
class ContextManagersTest(ThreadedTCPSocketTest):
|
||||||
|
|
||||||
|
def _testSocketClass(self):
|
||||||
|
# base test
|
||||||
|
with socket.socket() as sock:
|
||||||
|
self.assertFalse(sock._closed)
|
||||||
|
self.assertTrue(sock._closed)
|
||||||
|
# close inside with block
|
||||||
|
with socket.socket() as sock:
|
||||||
|
sock.close()
|
||||||
|
self.assertTrue(sock._closed)
|
||||||
|
# exception inside with block
|
||||||
|
with socket.socket() as sock:
|
||||||
|
self.assertRaises(socket.error, sock.sendall, b'foo')
|
||||||
|
self.assertTrue(sock._closed)
|
||||||
|
|
||||||
|
def testCreateConnectionBase(self):
|
||||||
|
conn, addr = self.serv.accept()
|
||||||
|
data = conn.recv(1024)
|
||||||
|
conn.sendall(data)
|
||||||
|
|
||||||
|
def _testCreateConnectionBase(self):
|
||||||
|
address = self.serv.getsockname()
|
||||||
|
with socket.create_connection(address) as sock:
|
||||||
|
self.assertFalse(sock._closed)
|
||||||
|
sock.sendall(b'foo')
|
||||||
|
self.assertEqual(sock.recv(1024), b'foo')
|
||||||
|
self.assertTrue(sock._closed)
|
||||||
|
|
||||||
|
def testCreateConnectionClose(self):
|
||||||
|
conn, addr = self.serv.accept()
|
||||||
|
data = conn.recv(1024)
|
||||||
|
conn.sendall(data)
|
||||||
|
|
||||||
|
def _testCreateConnectionClose(self):
|
||||||
|
address = self.serv.getsockname()
|
||||||
|
with socket.create_connection(address) as sock:
|
||||||
|
sock.close()
|
||||||
|
self.assertTrue(sock._closed)
|
||||||
|
self.assertRaises(socket.error, sock.sendall, b'foo')
|
||||||
|
|
||||||
|
|
||||||
def test_main():
|
def test_main():
|
||||||
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
|
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
|
||||||
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
|
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
|
||||||
|
@ -1609,6 +1652,7 @@ def test_main():
|
||||||
NetworkConnectionNoServer,
|
NetworkConnectionNoServer,
|
||||||
NetworkConnectionAttributesTest,
|
NetworkConnectionAttributesTest,
|
||||||
NetworkConnectionBehaviourTest,
|
NetworkConnectionBehaviourTest,
|
||||||
|
ContextManagersTest,
|
||||||
])
|
])
|
||||||
if hasattr(socket, "socketpair"):
|
if hasattr(socket, "socketpair"):
|
||||||
tests.append(BasicSocketPairTest)
|
tests.append(BasicSocketPairTest)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue