mirror of
https://github.com/django/django.git
synced 2025-07-07 21:35:15 +00:00
Fixed #36380 -- Deferred SQL formatting when running tests with --debug-sql.
Thanks to Jacob Walls for the report and previous iterations of this
fix, to Simon Charette for the logging formatter idea, and to Tim Graham
for testing and ensuring that 3rd party backends remain compatible.
This partially reverts d8f093908c
.
Refs #36112, #35448.
Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
This commit is contained in:
parent
104cbfd44b
commit
1a03a984ab
4 changed files with 216 additions and 28 deletions
|
@ -151,7 +151,7 @@ class CursorDebugWrapper(CursorWrapper):
|
|||
logger.debug(
|
||||
"(%.3f) %s; args=%s; alias=%s",
|
||||
duration,
|
||||
self.db.ops.format_debug_sql(sql),
|
||||
sql,
|
||||
params,
|
||||
self.db.alias,
|
||||
extra={
|
||||
|
|
|
@ -16,7 +16,6 @@ import unittest.suite
|
|||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
|
||||
import django
|
||||
from django.core.management import call_command
|
||||
|
@ -41,16 +40,47 @@ except ImportError:
|
|||
tblib = None
|
||||
|
||||
|
||||
class QueryFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
if (alias := getattr(record, "alias", None)) in connections:
|
||||
format_sql = connections[alias].ops.format_debug_sql
|
||||
|
||||
sql = None
|
||||
formatted_sql = None
|
||||
if args := record.args:
|
||||
if isinstance(args, tuple) and len(args) > 1 and (sql := args[1]):
|
||||
record.args = (args[0], formatted_sql := format_sql(sql), *args[2:])
|
||||
elif isinstance(record.args, dict) and (sql := record.args.get("sql")):
|
||||
record.args["sql"] = formatted_sql = format_sql(sql)
|
||||
|
||||
if extra_sql := getattr(record, "sql", None):
|
||||
if extra_sql == sql:
|
||||
record.sql = formatted_sql
|
||||
else:
|
||||
record.sql = format_sql(extra_sql)
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class DebugSQLTextTestResult(unittest.TextTestResult):
|
||||
def __init__(self, stream, descriptions, verbosity):
|
||||
self.logger = logging.getLogger("django.db.backends")
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
self.debug_sql_stream = None
|
||||
self.handler = None
|
||||
super().__init__(stream, descriptions, verbosity)
|
||||
|
||||
def _read_logger_stream(self):
|
||||
if self.handler is None:
|
||||
# Error before tests e.g. in setUpTestData().
|
||||
sql = ""
|
||||
else:
|
||||
self.handler.stream.seek(0)
|
||||
sql = self.handler.stream.read()
|
||||
return sql
|
||||
|
||||
def startTest(self, test):
|
||||
self.debug_sql_stream = StringIO()
|
||||
self.handler = logging.StreamHandler(self.debug_sql_stream)
|
||||
self.handler = logging.StreamHandler(io.StringIO())
|
||||
self.handler.setFormatter(QueryFormatter())
|
||||
self.logger.addHandler(self.handler)
|
||||
super().startTest(test)
|
||||
|
||||
|
@ -58,35 +88,26 @@ class DebugSQLTextTestResult(unittest.TextTestResult):
|
|||
super().stopTest(test)
|
||||
self.logger.removeHandler(self.handler)
|
||||
if self.showAll:
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.stream.write(self.debug_sql_stream.read())
|
||||
self.stream.write(self._read_logger_stream())
|
||||
self.stream.writeln(self.separator2)
|
||||
|
||||
def addError(self, test, err):
|
||||
super().addError(test, err)
|
||||
if self.debug_sql_stream is None:
|
||||
# Error before tests e.g. in setUpTestData().
|
||||
sql = ""
|
||||
else:
|
||||
self.debug_sql_stream.seek(0)
|
||||
sql = self.debug_sql_stream.read()
|
||||
self.errors[-1] = self.errors[-1] + (sql,)
|
||||
self.errors[-1] = self.errors[-1] + (self._read_logger_stream(),)
|
||||
|
||||
def addFailure(self, test, err):
|
||||
super().addFailure(test, err)
|
||||
self.debug_sql_stream.seek(0)
|
||||
self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)
|
||||
self.failures[-1] = self.failures[-1] + (self._read_logger_stream(),)
|
||||
|
||||
def addSubTest(self, test, subtest, err):
|
||||
super().addSubTest(test, subtest, err)
|
||||
if err is not None:
|
||||
self.debug_sql_stream.seek(0)
|
||||
errors = (
|
||||
self.failures
|
||||
if issubclass(err[0], test.failureException)
|
||||
else self.errors
|
||||
)
|
||||
errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)
|
||||
errors[-1] = errors[-1] + (self._read_logger_stream(),)
|
||||
|
||||
def printErrorList(self, flavour, errors):
|
||||
for test, err, sql_debug in errors:
|
||||
|
|
|
@ -83,12 +83,7 @@ class LastExecutedQueryTest(TestCase):
|
|||
connection.ops.last_executed_query(cursor, "SELECT %s" + suffix, (1,))
|
||||
|
||||
def test_debug_sql(self):
|
||||
qs = Reporter.objects.filter(first_name="test")
|
||||
ops = connections[qs.db].ops
|
||||
with mock.patch.object(ops, "format_debug_sql") as format_debug_sql:
|
||||
list(qs)
|
||||
# Queries are formatted with DatabaseOperations.format_debug_sql().
|
||||
format_debug_sql.assert_called()
|
||||
list(Reporter.objects.filter(first_name="test"))
|
||||
sql = connection.queries[-1]["sql"].lower()
|
||||
self.assertIn("select", sql)
|
||||
self.assertIn(Reporter._meta.db_table, sql)
|
||||
|
@ -580,13 +575,13 @@ class BackendTestCase(TransactionTestCase):
|
|||
@mock.patch("django.db.backends.utils.logger")
|
||||
@override_settings(DEBUG=True)
|
||||
def test_queries_logger(self, mocked_logger):
|
||||
sql = "SELECT 1" + connection.features.bare_select_suffix
|
||||
sql = connection.ops.format_debug_sql(sql)
|
||||
sql = "select 1" + connection.features.bare_select_suffix
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
params, kwargs = mocked_logger.debug.call_args
|
||||
self.assertIn("; alias=%s", params[0])
|
||||
self.assertEqual(params[2], sql)
|
||||
self.assertNotEqual(params[2], connection.ops.format_debug_sql(sql))
|
||||
self.assertIsNone(params[3])
|
||||
self.assertEqual(params[4], connection.alias)
|
||||
self.assertEqual(
|
||||
|
|
|
@ -1,12 +1,184 @@
|
|||
import logging
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from time import time
|
||||
from unittest import mock
|
||||
|
||||
from django.db import connection
|
||||
from django.db import DEFAULT_DB_ALIAS, connection, connections
|
||||
from django.test import TestCase
|
||||
from django.test.runner import DiscoverRunner
|
||||
from django.test.runner import DiscoverRunner, QueryFormatter
|
||||
|
||||
from .models import Person
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryFormatterTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.format_sql_calls = []
|
||||
|
||||
def new_format_sql(self, sql):
|
||||
# Use time() to introduce some uniqueness.
|
||||
formatted = "Formatted! %s at %s" % (sql.upper(), time())
|
||||
self.format_sql_calls.append({sql: formatted})
|
||||
return formatted
|
||||
|
||||
def make_handler(self, **formatter_kwargs):
|
||||
formatter = QueryFormatter(**formatter_kwargs)
|
||||
|
||||
handler = logging.StreamHandler(StringIO())
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
self.addCleanup(logger.setLevel, original_level)
|
||||
logger.addHandler(handler)
|
||||
self.addCleanup(logger.removeHandler, handler)
|
||||
|
||||
return handler
|
||||
|
||||
def do_log(self, msg, *logger_args, alias=DEFAULT_DB_ALIAS, extra=None):
|
||||
if extra is None:
|
||||
extra = {}
|
||||
if alias and "alias" not in extra:
|
||||
extra["alias"] = alias
|
||||
# Patch connection's format_debug_sql to ensure it was properly called.
|
||||
with mock.patch.object(
|
||||
connections[alias].ops, "format_debug_sql", side_effect=self.new_format_sql
|
||||
):
|
||||
logger.info(msg, *logger_args, extra=extra)
|
||||
|
||||
def assertLogRecord(self, handler, expected):
|
||||
handler.stream.seek(0)
|
||||
self.assertEqual(handler.stream.read().strip(), expected)
|
||||
|
||||
def assertSQLFormatted(self, handler, sql, total_calls=1):
|
||||
self.assertEqual(len(self.format_sql_calls), total_calls)
|
||||
formatted_sql = self.format_sql_calls[0][sql]
|
||||
expected = f"=> Executing query duration=3.142 sql={formatted_sql}"
|
||||
self.assertLogRecord(handler, expected)
|
||||
|
||||
def test_formats_sql_bracket_format_style(self):
|
||||
handler = self.make_handler(
|
||||
fmt="{message} duration={duration:.3f} sql={sql}", style="{"
|
||||
)
|
||||
msg = "=> Executing query"
|
||||
sql = "select * from foo"
|
||||
|
||||
self.do_log(msg, extra={"sql": sql, "duration": 3.1416})
|
||||
self.assertSQLFormatted(handler, sql)
|
||||
|
||||
def test_formats_sql_named_fmt_format_style(self):
|
||||
handler = self.make_handler(
|
||||
fmt="%(message)s duration=%(duration).3f sql=%(sql)s"
|
||||
)
|
||||
msg = "=> Executing query"
|
||||
sql = "select * from foo"
|
||||
|
||||
self.do_log(msg, extra={"sql": sql, "duration": 3.1416})
|
||||
self.assertSQLFormatted(handler, sql)
|
||||
|
||||
def test_formats_sql_named_percent_format_style(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%(duration).3f sql=%(sql)s"
|
||||
sql = "select * from foo"
|
||||
|
||||
self.do_log(msg, {"duration": 3.1416, "sql": sql})
|
||||
self.assertSQLFormatted(handler, sql)
|
||||
|
||||
def test_formats_sql_default_percent_format_style(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%.3f sql=%s"
|
||||
sql = "select * from foo"
|
||||
|
||||
self.do_log(msg, 3.1416, sql)
|
||||
self.assertSQLFormatted(handler, sql)
|
||||
|
||||
def test_formats_sql_multiple_matching_sql(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%.3f sql=%s"
|
||||
sql = "select * from foo"
|
||||
|
||||
self.do_log(msg, 3.1416, sql, extra={"duration": 3.1416, "sql": sql})
|
||||
self.assertSQLFormatted(handler, sql)
|
||||
|
||||
def test_formats_sql_multiple_non_matching_sql(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%.3f sql=%s"
|
||||
sql1 = "select * from foo"
|
||||
sql2 = "select * from other"
|
||||
|
||||
self.do_log(msg, 3.1416, sql1, extra={"duration": 3.1416, "sql": sql2})
|
||||
self.assertSQLFormatted(handler, sql1, total_calls=2)
|
||||
# Second format call is triggered since the sql are different.
|
||||
self.assertEqual(list(self.format_sql_calls[1].keys()), [sql2])
|
||||
|
||||
def test_log_record_no_args(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query no args"
|
||||
|
||||
self.do_log(msg)
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg)
|
||||
|
||||
def test_log_record_not_enough_args(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query one args %r"
|
||||
args = "not formatted"
|
||||
|
||||
self.do_log(msg, args)
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg % args)
|
||||
|
||||
def test_log_record_not_key_in_dict_args(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query missing sql key %(foo)r"
|
||||
args = {"foo": "bar"}
|
||||
|
||||
self.do_log(msg, args)
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg % args)
|
||||
|
||||
def test_log_record_no_alias(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%.3f sql=%s"
|
||||
args = (3.1416, "select * from foo")
|
||||
|
||||
self.do_log(msg, *args, extra={"alias": "does not exist"})
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg % args)
|
||||
|
||||
def test_log_record_sql_arg_none(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%.3f sql=%s"
|
||||
args = (3.1416, None)
|
||||
|
||||
self.do_log(msg, *args)
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg % args)
|
||||
|
||||
def test_log_record_sql_key_none(self):
|
||||
handler = self.make_handler()
|
||||
msg = "=> Executing query duration=%(duration).3f sql=%(sql)s"
|
||||
args = {"duration": 3.1416, "sql": None}
|
||||
|
||||
self.do_log(msg, args)
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, msg % args)
|
||||
|
||||
def test_log_record_sql_extra_none(self):
|
||||
handler = self.make_handler(
|
||||
fmt="{message} duration={duration:.3f} sql={sql}", style="{"
|
||||
)
|
||||
msg = "=> Executing query"
|
||||
|
||||
self.do_log(msg, extra={"sql": None, "duration": 3.1416})
|
||||
self.assertEqual(self.format_sql_calls, [])
|
||||
self.assertLogRecord(handler, f"{msg} duration=3.142 sql=None")
|
||||
|
||||
|
||||
@unittest.skipUnless(
|
||||
connection.vendor == "sqlite", "Only run on sqlite so we can check output SQL."
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue