mirror of
https://github.com/django/django.git
synced 2025-11-17 18:48:15 +00:00
Merge 4ac58bcd10 into 1ce6e78dd4
This commit is contained in:
commit
fa8182884d
4 changed files with 141 additions and 79 deletions
|
|
@ -128,6 +128,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
"Oracle doesn't support casting filters to NUMBER.": {
|
||||
"lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup",
|
||||
},
|
||||
"Oracle doesn't support JSON null scalar extraction.": {
|
||||
"model_fields.test_jsonfield.JSONNullTests.test_filter_in",
|
||||
},
|
||||
}
|
||||
if self.connection.oracle_version < (23,):
|
||||
skips.update(
|
||||
|
|
|
|||
|
|
@ -375,6 +375,92 @@ class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
|||
pass
|
||||
|
||||
|
||||
class ProcessJSONLHSMixin:
|
||||
def _get_json_path(self, connection, key_transforms):
|
||||
if key_transforms is None:
|
||||
return "$"
|
||||
return connection.ops.compile_json_path(key_transforms)
|
||||
|
||||
def _process_as_oracle(self, sql, params, connection, key_transforms=None):
|
||||
json_path = self._get_json_path(connection, key_transforms)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
template = (
|
||||
"COALESCE("
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
template = (
|
||||
"COALESCE("
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff')"
|
||||
")"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle. Use a custom delimiter to prevent the
|
||||
# JSON path from escaping the SQL literal. Each key in the JSON path is
|
||||
# passed through json.dumps() with ensure_ascii=True (the default),
|
||||
# which converts the delimiter into the escaped \uffff format. This
|
||||
# ensures that the delimiter is not present in the JSON path.
|
||||
sql = template % ((sql, json_path) * 2)
|
||||
return sql, params * 2
|
||||
|
||||
def _process_as_sqlite(self, sql, params, connection, key_transforms=None):
|
||||
json_path = self._get_json_path(connection, key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(value) for value in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (sql, datatype_values, sql, sql), (*params, json_path) * 3
|
||||
|
||||
def _process_as_mysql(self, sql, params, connection, key_transforms=None):
|
||||
json_path = self._get_json_path(connection, key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % sql, (*params, json_path)
|
||||
|
||||
|
||||
class JSONIn(ProcessJSONLHSMixin, lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if not connection.features.has_native_json_field and (
|
||||
not hasattr(param, "as_sql") or isinstance(param, expressions.Value)
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = param.value if hasattr(param, "value") else json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
sql, params = super().process_lhs(compiler, connection)
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
return sql, params
|
||||
if connection.vendor == "mysql":
|
||||
return self._process_as_mysql(sql, params, connection)
|
||||
elif connection.vendor == "oracle":
|
||||
return self._process_as_oracle(sql, params, connection)
|
||||
elif connection.vendor == "sqlite":
|
||||
return self._process_as_sqlite(sql, params, connection)
|
||||
return sql, params
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
|
|
@ -382,9 +468,10 @@ JSONField.register_lookup(HasKeys)
|
|||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
JSONField.register_lookup(JSONIn)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
class KeyTransform(ProcessJSONLHSMixin, Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
|
|
@ -406,33 +493,11 @@ class KeyTransform(Transform):
|
|||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = connection.ops.compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, json_path)
|
||||
return self._process_as_mysql(lhs, params, connection, key_transforms)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = connection.ops.compile_json_path(key_transforms)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff')"
|
||||
")"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle. Use a custom delimiter to prevent the
|
||||
# JSON path from escaping the SQL literal. Each key in the JSON path is
|
||||
# passed through json.dumps() with ensure_ascii=True (the default),
|
||||
# which converts the delimiter into the escaped \uffff format. This
|
||||
# ensures that the delimiter is not present in the JSON path.
|
||||
return sql % ((lhs, json_path) * 2), tuple(params) * 2
|
||||
return self._process_as_oracle(lhs, params, connection, key_transforms)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
|
|
@ -447,14 +512,7 @@ class KeyTransform(Transform):
|
|||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = connection.ops.compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (*params, json_path) * 3
|
||||
return self._process_as_sqlite(lhs, params, connection, key_transforms)
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
|
|
@ -535,33 +593,8 @@ class KeyTransformIsNull(lookups.IsNull):
|
|||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
class KeyTransformIn(JSONIn):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
|
|
|
|||
|
|
@ -287,33 +287,26 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
|||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
return self.rhs
|
||||
contains_expr = False
|
||||
if any(hasattr(value, "resolve_expression") for value in self.rhs):
|
||||
return ExpressionList(
|
||||
*[
|
||||
(
|
||||
value
|
||||
if hasattr(value, "resolve_expression")
|
||||
else Value(value, self.lhs.output_field)
|
||||
)
|
||||
for value in self.rhs
|
||||
]
|
||||
)
|
||||
prepared_values = []
|
||||
for rhs_value in self.rhs:
|
||||
if hasattr(rhs_value, "resolve_expression"):
|
||||
# An expression will be handled by the database but can coexist
|
||||
# alongside real values.
|
||||
contains_expr = True
|
||||
elif (
|
||||
if (
|
||||
self.prepare_rhs
|
||||
and hasattr(self.lhs, "output_field")
|
||||
and hasattr(self.lhs.output_field, "get_prep_value")
|
||||
):
|
||||
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
|
||||
prepared_values.append(rhs_value)
|
||||
if contains_expr:
|
||||
return ExpressionList(
|
||||
*[
|
||||
# Expression defaults `str` to field references while
|
||||
# lookups default them to literal values.
|
||||
(
|
||||
Value(prep_value, self.lhs.output_field)
|
||||
if isinstance(prep_value, str)
|
||||
else prep_value
|
||||
)
|
||||
for prep_value in prepared_values
|
||||
]
|
||||
)
|
||||
return prepared_values
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
|
|
|
|||
|
|
@ -1010,6 +1010,19 @@ class TestQuerying(TestCase):
|
|||
NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False
|
||||
)
|
||||
|
||||
def test_in(self):
|
||||
tests = [
|
||||
([[]], [self.objs[1]]),
|
||||
([{}], [self.objs[2]]),
|
||||
([{"a": "b", "c": 14}], [self.objs[3]]),
|
||||
([[1, [2]]], [self.objs[5]]),
|
||||
]
|
||||
for lookup_value, expected in tests:
|
||||
with self.subTest(value__in=lookup_value):
|
||||
self.assertCountEqual(
|
||||
NullableJSONModel.objects.filter(value__in=lookup_value), expected
|
||||
)
|
||||
|
||||
def test_key_in(self):
|
||||
tests = [
|
||||
("value__c__in", [14], self.objs[3:5]),
|
||||
|
|
@ -1023,6 +1036,7 @@ class TestQuerying(TestCase):
|
|||
[self.objs[7]],
|
||||
),
|
||||
("value__foo__in", [F("value__bax__foo")], [self.objs[7]]),
|
||||
("value__foo__in", [F("value__bax__foo"), {}], [self.objs[7]]),
|
||||
(
|
||||
"value__foo__in",
|
||||
[KeyTransform("foo", KeyTransform("bax", "value")), "baz"],
|
||||
|
|
@ -1031,6 +1045,17 @@ class TestQuerying(TestCase):
|
|||
("value__foo__in", [F("value__bax__foo"), "baz"], [self.objs[7]]),
|
||||
("value__foo__in", ["bar", "baz"], [self.objs[7]]),
|
||||
("value__bar__in", [["foo", "bar"]], [self.objs[7]]),
|
||||
("value__bar__in", [Value(["foo", "bar"], JSONField())], [self.objs[7]]),
|
||||
(
|
||||
"value__bar__in",
|
||||
[["foo", "bar"], Value({}, JSONField())],
|
||||
[self.objs[7]],
|
||||
),
|
||||
(
|
||||
"value__bar__in",
|
||||
[Value(["foo", "bar"], JSONField()), {"a": "b"}],
|
||||
[self.objs[7]],
|
||||
),
|
||||
("value__bar__in", [["foo", "bar"], ["a"]], [self.objs[7]]),
|
||||
("value__bax__in", [{"foo": "bar"}, {"a": "b"}], [self.objs[7]]),
|
||||
("value__h__in", [True, "foo"], [self.objs[4]]),
|
||||
|
|
@ -1297,6 +1322,14 @@ class JSONNullTests(TestCase):
|
|||
NullableJSONModel.objects.filter(value__isnull=True), [sql_null]
|
||||
)
|
||||
|
||||
def test_filter_in(self):
|
||||
obj = NullableJSONModel.objects.create(value=JSONNull())
|
||||
obj2 = NullableJSONModel.objects.create(value=[1])
|
||||
self.assertSequenceEqual(
|
||||
NullableJSONModel.objects.filter(value__in=[JSONNull(), [1], "foo"]),
|
||||
[obj, obj2],
|
||||
)
|
||||
|
||||
def test_bulk_update(self):
|
||||
obj1 = NullableJSONModel.objects.create(value={"k": "1st"})
|
||||
obj2 = NullableJSONModel.objects.create(value={"k": "2nd"})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue