mirror of
https://github.com/django/django.git
synced 2025-08-30 23:37:50 +00:00
Fixed #30171 -- Fixed DatabaseError in servers tests.
Made DatabaseWrapper thread sharing logic reentrant. Used a reference
counting like scheme to allow nested uses.
The error appeared after 8c775391b7
.
This commit is contained in:
parent
21f9d43737
commit
76990cbbda
7 changed files with 100 additions and 66 deletions
|
@ -605,21 +605,25 @@ class ThreadTests(TransactionTestCase):
|
|||
connection = connections[DEFAULT_DB_ALIAS]
|
||||
# Allow thread sharing so the connection can be closed by the
|
||||
# main thread.
|
||||
connection.allow_thread_sharing = True
|
||||
connection.inc_thread_sharing()
|
||||
connection.cursor()
|
||||
connections_dict[id(connection)] = connection
|
||||
for x in range(2):
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
# Each created connection got different inner connection.
|
||||
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
|
||||
# Finish by closing the connections opened by the other threads (the
|
||||
# connection opened in the main thread will automatically be closed on
|
||||
# teardown).
|
||||
for conn in connections_dict.values():
|
||||
if conn is not connection:
|
||||
conn.close()
|
||||
try:
|
||||
for x in range(2):
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
# Each created connection got different inner connection.
|
||||
self.assertEqual(len({conn.connection for conn in connections_dict.values()}), 3)
|
||||
finally:
|
||||
# Finish by closing the connections opened by the other threads
|
||||
# (the connection opened in the main thread will automatically be
|
||||
# closed on teardown).
|
||||
for conn in connections_dict.values():
|
||||
if conn is not connection:
|
||||
if conn.allow_thread_sharing:
|
||||
conn.close()
|
||||
conn.dec_thread_sharing()
|
||||
|
||||
def test_connections_thread_local(self):
|
||||
"""
|
||||
|
@ -636,19 +640,23 @@ class ThreadTests(TransactionTestCase):
|
|||
for conn in connections.all():
|
||||
# Allow thread sharing so the connection can be closed by the
|
||||
# main thread.
|
||||
conn.allow_thread_sharing = True
|
||||
conn.inc_thread_sharing()
|
||||
connections_dict[id(conn)] = conn
|
||||
for x in range(2):
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
self.assertEqual(len(connections_dict), 6)
|
||||
# Finish by closing the connections opened by the other threads (the
|
||||
# connection opened in the main thread will automatically be closed on
|
||||
# teardown).
|
||||
for conn in connections_dict.values():
|
||||
if conn is not connection:
|
||||
conn.close()
|
||||
try:
|
||||
for x in range(2):
|
||||
t = threading.Thread(target=runner)
|
||||
t.start()
|
||||
t.join()
|
||||
self.assertEqual(len(connections_dict), 6)
|
||||
finally:
|
||||
# Finish by closing the connections opened by the other threads
|
||||
# (the connection opened in the main thread will automatically be
|
||||
# closed on teardown).
|
||||
for conn in connections_dict.values():
|
||||
if conn is not connection:
|
||||
if conn.allow_thread_sharing:
|
||||
conn.close()
|
||||
conn.dec_thread_sharing()
|
||||
|
||||
def test_pass_connection_between_threads(self):
|
||||
"""
|
||||
|
@ -668,25 +676,21 @@ class ThreadTests(TransactionTestCase):
|
|||
t.start()
|
||||
t.join()
|
||||
|
||||
# Without touching allow_thread_sharing, which should be False by default.
|
||||
# Without touching thread sharing, which should be False by default.
|
||||
exceptions = []
|
||||
do_thread()
|
||||
# Forbidden!
|
||||
self.assertIsInstance(exceptions[0], DatabaseError)
|
||||
|
||||
# If explicitly setting allow_thread_sharing to False
|
||||
connections['default'].allow_thread_sharing = False
|
||||
exceptions = []
|
||||
do_thread()
|
||||
# Forbidden!
|
||||
self.assertIsInstance(exceptions[0], DatabaseError)
|
||||
|
||||
# If explicitly setting allow_thread_sharing to True
|
||||
connections['default'].allow_thread_sharing = True
|
||||
exceptions = []
|
||||
do_thread()
|
||||
# All good
|
||||
self.assertEqual(exceptions, [])
|
||||
# After calling inc_thread_sharing() on the connection.
|
||||
connections['default'].inc_thread_sharing()
|
||||
try:
|
||||
exceptions = []
|
||||
do_thread()
|
||||
# All good
|
||||
self.assertEqual(exceptions, [])
|
||||
finally:
|
||||
connections['default'].dec_thread_sharing()
|
||||
|
||||
def test_closing_non_shared_connections(self):
|
||||
"""
|
||||
|
@ -721,16 +725,33 @@ class ThreadTests(TransactionTestCase):
|
|||
except DatabaseError as e:
|
||||
exceptions.add(e)
|
||||
# Enable thread sharing
|
||||
connections['default'].allow_thread_sharing = True
|
||||
t2 = threading.Thread(target=runner2, args=[connections['default']])
|
||||
t2.start()
|
||||
t2.join()
|
||||
connections['default'].inc_thread_sharing()
|
||||
try:
|
||||
t2 = threading.Thread(target=runner2, args=[connections['default']])
|
||||
t2.start()
|
||||
t2.join()
|
||||
finally:
|
||||
connections['default'].dec_thread_sharing()
|
||||
t1 = threading.Thread(target=runner1)
|
||||
t1.start()
|
||||
t1.join()
|
||||
# No exception was raised
|
||||
self.assertEqual(len(exceptions), 0)
|
||||
|
||||
def test_thread_sharing_count(self):
|
||||
self.assertIs(connection.allow_thread_sharing, False)
|
||||
connection.inc_thread_sharing()
|
||||
self.assertIs(connection.allow_thread_sharing, True)
|
||||
connection.inc_thread_sharing()
|
||||
self.assertIs(connection.allow_thread_sharing, True)
|
||||
connection.dec_thread_sharing()
|
||||
self.assertIs(connection.allow_thread_sharing, True)
|
||||
connection.dec_thread_sharing()
|
||||
self.assertIs(connection.allow_thread_sharing, False)
|
||||
msg = 'Cannot decrement the thread sharing count below zero.'
|
||||
with self.assertRaisesMessage(RuntimeError, msg):
|
||||
connection.dec_thread_sharing()
|
||||
|
||||
|
||||
class MySQLPKZeroTests(TestCase):
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue