This commit is contained in:
Clifford Gama 2025-11-17 13:45:10 +01:00 committed by GitHub
commit fa8182884d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 141 additions and 79 deletions

View file

@@ -128,6 +128,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"Oracle doesn't support casting filters to NUMBER.": {
"lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup",
},
"Oracle doesn't support JSON null scalar extraction.": {
"model_fields.test_jsonfield.JSONNullTests.test_filter_in",
},
}
if self.connection.oracle_version < (23,):
skips.update(

View file

@@ -375,6 +375,92 @@ class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
class ProcessJSONLHSMixin:
    """
    Mixin with per-backend helpers that wrap a JSON lhs SQL fragment in the
    expression needed to extract the value at a JSON path: JSON_VALUE /
    JSON_QUERY on Oracle, JSON_TYPE / JSON_EXTRACT on SQLite, and
    JSON_EXTRACT on MySQL.
    """

    def _get_json_path(self, connection, key_transforms):
        # "$" targets the root of the JSON document when there is no key
        # path to compile.
        if key_transforms is None:
            return "$"
        return connection.ops.compile_json_path(key_transforms)

    def _process_as_oracle(self, sql, params, connection, key_transforms=None):
        # Wrap the lhs in COALESCE(JSON_VALUE(...), JSON_QUERY(...)) (order
        # depends on native primitive support) so both scalar and
        # object/array values are extracted.
        json_path = self._get_json_path(connection, key_transforms)
        if connection.features.supports_primitives_in_json_field:
            template = (
                "COALESCE("
                "JSON_VALUE(%s, q'\uffff%s\uffff'),"
                "JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
                ")"
            )
        else:
            template = (
                "COALESCE("
                "JSON_QUERY(%s, q'\uffff%s\uffff'),"
                "JSON_VALUE(%s, q'\uffff%s\uffff')"
                ")"
            )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle. Use a custom delimiter to prevent the
        # JSON path from escaping the SQL literal. Each key in the JSON path is
        # passed through json.dumps() with ensure_ascii=True (the default),
        # which converts the delimiter into the escaped \uffff format. This
        # ensures that the delimiter is not present in the JSON path.
        sql = template % ((sql, json_path) * 2)
        # The lhs fragment appears twice in the template, so its bind
        # parameters must be duplicated to match.
        return sql, params * 2

    def _process_as_sqlite(self, sql, params, connection, key_transforms=None):
        json_path = self._get_json_path(connection, key_transforms)
        datatype_values = ",".join(
            [repr(value) for value in connection.ops.jsonfield_datatype_values]
        )
        # The lhs fragment appears three times in the CASE expression, so
        # (params + json_path) is repeated three times to match the %s/%%s
        # placeholders in order.
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (sql, datatype_values, sql, sql), (*params, json_path) * 3

    def _process_as_mysql(self, sql, params, connection, key_transforms=None):
        json_path = self._get_json_path(connection, key_transforms)
        # The path is appended as a bind parameter (unlike Oracle, where it
        # must be inlined into the SQL).
        return "JSON_EXTRACT(%s, %%s)" % sql, (*params, json_path)
class JSONIn(ProcessJSONLHSMixin, lookups.In):
    """
    __in lookup for JSONField: coerce both the lhs and each rhs member to a
    comparable form on backends without a native JSON column type.
    """

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        # Only plain values (non-expressions, or explicit Value wrappers)
        # need coercion, and only on backends without native JSON fields.
        is_plain_value = not hasattr(param, "as_sql") or isinstance(
            param, expressions.Value
        )
        if connection.features.has_native_json_field or not is_plain_value:
            return sql, params
        if connection.vendor == "oracle":
            if hasattr(param, "value"):
                value = param.value
            else:
                value = json.loads(param)
            # JSON_QUERY extracts objects/arrays, JSON_VALUE scalars.
            func = "JSON_QUERY" if isinstance(value, (list, dict)) else "JSON_VALUE"
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')" % func
        elif connection.vendor == "mysql" or (
            connection.vendor == "sqlite"
            and params[0] not in connection.ops.jsonfield_datatype_values
        ):
            sql = "JSON_EXTRACT(%s, '$')"
            if connection.vendor == "mysql" and connection.mysql_is_mariadb:
                sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params

    def process_lhs(self, compiler, connection):
        lhs_sql, lhs_params = super().process_lhs(compiler, connection)
        # Key transforms compile their own extraction SQL; leave them alone.
        if isinstance(self.lhs, KeyTransform):
            return lhs_sql, lhs_params
        vendor_handlers = {
            "mysql": self._process_as_mysql,
            "oracle": self._process_as_oracle,
            "sqlite": self._process_as_sqlite,
        }
        handler = vendor_handlers.get(connection.vendor)
        if handler is None:
            return lhs_sql, lhs_params
        return handler(lhs_sql, lhs_params, connection)
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
@ -382,9 +468,10 @@ JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
JSONField.register_lookup(JSONIn)
class KeyTransform(Transform):
class KeyTransform(ProcessJSONLHSMixin, Transform):
postgres_operator = "->"
postgres_nested_operator = "#>"
@@ -406,33 +493,11 @@ class KeyTransform(Transform):
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = connection.ops.compile_json_path(key_transforms)
return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, json_path)
return self._process_as_mysql(lhs, params, connection, key_transforms)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = connection.ops.compile_json_path(key_transforms)
if connection.features.supports_primitives_in_json_field:
sql = (
"COALESCE("
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
")"
)
else:
sql = (
"COALESCE("
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
"JSON_VALUE(%s, q'\uffff%s\uffff')"
")"
)
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle. Use a custom delimiter to prevent the
# JSON path from escaping the SQL literal. Each key in the JSON path is
# passed through json.dumps() with ensure_ascii=True (the default),
# which converts the delimiter into the escaped \uffff format. This
# ensures that the delimiter is not present in the JSON path.
return sql % ((lhs, json_path) * 2), tuple(params) * 2
return self._process_as_oracle(lhs, params, connection, key_transforms)
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
@ -447,14 +512,7 @@ class KeyTransform(Transform):
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = connection.ops.compile_json_path(key_transforms)
datatype_values = ",".join(
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
)
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (*params, json_path) * 3
return self._process_as_sqlite(lhs, params, connection, key_transforms)
class KeyTextTransform(KeyTransform):
@ -535,33 +593,8 @@ class KeyTransformIsNull(lookups.IsNull):
)
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler,
connection,
sql,
param,
)
if (
not hasattr(param, "as_sql")
and not connection.features.has_native_json_field
):
if connection.vendor == "oracle":
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
sql %= "JSON_QUERY"
else:
sql %= "JSON_VALUE"
elif connection.vendor == "mysql" or (
connection.vendor == "sqlite"
and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
sql = "JSON_UNQUOTE(%s)" % sql
return sql, params
class KeyTransformIn(JSONIn):
    # JSONIn already short-circuits its lhs processing for KeyTransform
    # instances (process_lhs returns early), so no extra behavior is needed
    # for key-transform __in lookups.
    pass
class KeyTransformExact(JSONExact):

View file

@@ -287,33 +287,26 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
def get_prep_lookup(self):
if hasattr(self.rhs, "resolve_expression"):
return self.rhs
contains_expr = False
if any(hasattr(value, "resolve_expression") for value in self.rhs):
return ExpressionList(
*[
(
value
if hasattr(value, "resolve_expression")
else Value(value, self.lhs.output_field)
)
for value in self.rhs
]
)
prepared_values = []
for rhs_value in self.rhs:
if hasattr(rhs_value, "resolve_expression"):
# An expression will be handled by the database but can coexist
# alongside real values.
contains_expr = True
elif (
if (
self.prepare_rhs
and hasattr(self.lhs, "output_field")
and hasattr(self.lhs.output_field, "get_prep_value")
):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
if contains_expr:
return ExpressionList(
*[
# Expression defaults `str` to field references while
# lookups default them to literal values.
(
Value(prep_value, self.lhs.output_field)
if isinstance(prep_value, str)
else prep_value
)
for prep_value in prepared_values
]
)
return prepared_values
def process_rhs(self, compiler, connection):

View file

@@ -1010,6 +1010,19 @@ class TestQuerying(TestCase):
NullableJSONModel.objects.filter(value__foo__iexact='"BaR"').exists(), False
)
def test_in(self):
    """Whole-value __in lookups match objects whose JSON value equals a member."""
    cases = [
        ([[]], [self.objs[1]]),
        ([{}], [self.objs[2]]),
        ([{"a": "b", "c": 14}], [self.objs[3]]),
        ([[1, [2]]], [self.objs[5]]),
    ]
    for candidates, expected in cases:
        with self.subTest(value__in=candidates):
            matched = NullableJSONModel.objects.filter(value__in=candidates)
            self.assertCountEqual(matched, expected)
def test_key_in(self):
tests = [
("value__c__in", [14], self.objs[3:5]),
@ -1023,6 +1036,7 @@ class TestQuerying(TestCase):
[self.objs[7]],
),
("value__foo__in", [F("value__bax__foo")], [self.objs[7]]),
("value__foo__in", [F("value__bax__foo"), {}], [self.objs[7]]),
(
"value__foo__in",
[KeyTransform("foo", KeyTransform("bax", "value")), "baz"],
@ -1031,6 +1045,17 @@ class TestQuerying(TestCase):
("value__foo__in", [F("value__bax__foo"), "baz"], [self.objs[7]]),
("value__foo__in", ["bar", "baz"], [self.objs[7]]),
("value__bar__in", [["foo", "bar"]], [self.objs[7]]),
("value__bar__in", [Value(["foo", "bar"], JSONField())], [self.objs[7]]),
(
"value__bar__in",
[["foo", "bar"], Value({}, JSONField())],
[self.objs[7]],
),
(
"value__bar__in",
[Value(["foo", "bar"], JSONField()), {"a": "b"}],
[self.objs[7]],
),
("value__bar__in", [["foo", "bar"], ["a"]], [self.objs[7]]),
("value__bax__in", [{"foo": "bar"}, {"a": "b"}], [self.objs[7]]),
("value__h__in", [True, "foo"], [self.objs[4]]),
@ -1297,6 +1322,14 @@ class JSONNullTests(TestCase):
NullableJSONModel.objects.filter(value__isnull=True), [sql_null]
)
def test_filter_in(self):
    """__in matches both a JSON null value and a list value; unmatched members are ignored."""
    json_null_obj = NullableJSONModel.objects.create(value=JSONNull())
    list_obj = NullableJSONModel.objects.create(value=[1])
    candidates = [JSONNull(), [1], "foo"]
    matched = NullableJSONModel.objects.filter(value__in=candidates)
    self.assertSequenceEqual(matched, [json_null_obj, list_obj])
def test_bulk_update(self):
obj1 = NullableJSONModel.objects.create(value={"k": "1st"})
obj2 = NullableJSONModel.objects.create(value={"k": "2nd"})