Fixed #30171 -- Fixed DatabaseError in servers tests.

Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.

The error appeared after 8c775391b7.
This commit is contained in:
Jon Dufresne 2019-02-14 07:04:55 -08:00 committed by Tim Graham
parent 21f9d43737
commit 76990cbbda
7 changed files with 100 additions and 66 deletions

View file

@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
connection = connections[DEFAULT_DB_ALIAS]
# Allow thread sharing so the connection can be closed by the
# main thread.
connection.allow_thread_sharing = True
connection.inc_thread_sharing()
connection.cursor()
connections_dict[id(connection)] = connection
for x in range(2):
t = threading.Thread(target=runner)
t.start()
t.join()
# Each created connection got different inner connection.
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_dict.values():
if conn is not connection:
conn.close()
try:
for x in range(2):
t = threading.Thread(target=runner)
t.start()
t.join()
# Each created connection got different inner connection.
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
finally:
# Finish by closing the connections opened by the other threads
# (the connection opened in the main thread will automatically be
# closed on teardown).
for conn in connections_dict.values():
if conn is not connection:
if conn.allow_thread_sharing:
conn.close()
conn.dec_thread_sharing()
def test_connections_thread_local(self):
"""
@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
for conn in connections.all():
# Allow thread sharing so the connection can be closed by the
# main thread.
conn.allow_thread_sharing = True
conn.inc_thread_sharing()
connections_dict[id(conn)] = conn
for x in range(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEqual(len(connections_dict), 6)
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
for conn in connections_dict.values():
if conn is not connection:
conn.close()
try:
for x in range(2):
t = threading.Thread(target=runner)
t.start()
t.join()
self.assertEqual(len(connections_dict), 6)
finally:
# Finish by closing the connections opened by the other threads
# (the connection opened in the main thread will automatically be
# closed on teardown).
for conn in connections_dict.values():
if conn is not connection:
if conn.allow_thread_sharing:
conn.close()
conn.dec_thread_sharing()
def test_pass_connection_between_threads(self):
"""
@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
t.start()
t.join()
# Without touching allow_thread_sharing, which should be False by default.
# Without touching thread sharing, which should be False by default.
exceptions = []
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to False
connections['default'].allow_thread_sharing = False
exceptions = []
do_thread()
# Forbidden!
self.assertIsInstance(exceptions[0], DatabaseError)
# If explicitly setting allow_thread_sharing to True
connections['default'].allow_thread_sharing = True
exceptions = []
do_thread()
# All good
self.assertEqual(exceptions, [])
# After calling inc_thread_sharing() on the connection.
connections['default'].inc_thread_sharing()
try:
exceptions = []
do_thread()
# All good
self.assertEqual(exceptions, [])
finally:
connections['default'].dec_thread_sharing()
def test_closing_non_shared_connections(self):
"""
@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
except DatabaseError as e:
exceptions.add(e)
# Enable thread sharing
connections['default'].allow_thread_sharing = True
t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.start()
t2.join()
connections['default'].inc_thread_sharing()
try:
t2 = threading.Thread(target=runner2, args=[connections['default']])
t2.start()
t2.join()
finally:
connections['default'].dec_thread_sharing()
t1 = threading.Thread(target=runner1)
t1.start()
t1.join()
# No exception was raised
self.assertEqual(len(exceptions), 0)
def test_thread_sharing_count(self):
self.assertIs(connection.allow_thread_sharing, False)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.inc_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, True)
connection.dec_thread_sharing()
self.assertIs(connection.allow_thread_sharing, False)
msg = 'Cannot decrement the thread sharing count below zero.'
with self.assertRaisesMessage(RuntimeError, msg):
connection.dec_thread_sharing()
class MySQLPKZeroTests(TestCase):
"""