Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
This commit is contained in:
Anssi Kääriäinen 2012-11-09 20:21:46 +02:00
parent 92d7f541da
commit f51e409a5f
5 changed files with 203 additions and 43 deletions

View file

@ -1300,7 +1300,7 @@ class EmptyQuerySet(QuerySet):
value_annotation = False
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
only_load=None, local_only=False):
only_load=None, from_parent=None):
"""
Helper function that recursively returns an information for a klass, to be
used in get_cached_row. It exists just to compute this information only
@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
* local_only - Only populate local fields. This is used when
following reverse select-related relations
* from_parent - the parent model used to get to this model
Note that when travelling from parent to child, we will only load child
fields which aren't in the parent.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
elif local_only and model is not None:
elif from_parent and issubclass(from_parent, model.__class__):
# Avoid loading fields already loaded for parent model for
# child models.
continue
else:
init_list.append(field.attname)
@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
# We trying to not populate field_names variable for perfomance reason.
# If field_names variable is set, it is used to instantiate desired fields,
# by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
# But kwargs version of Model.__init__ is slower, so we should avoid using
# it when it is not really neccesary.
if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
field_count = len(klass._meta.local_fields)
field_names = [f.attname for f in klass._meta.local_fields]
else:
field_count = len(klass._meta.fields)
field_count = len(klass._meta.fields)
# Check if we need to skip some parent fields.
if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
# Only load those fields which haven't been already loaded into
# 'from_parent'.
non_seen_models = [p for p in klass._meta.get_parent_list()
if not issubclass(from_parent, p)]
# Load local fields, too...
non_seen_models.append(klass)
field_names = [f.attname for f in klass._meta.fields
if f.model in non_seen_models]
field_count = len(field_names)
# Try to avoid populating field_names variable for perfomance reasons.
# If field_names variable is set, we use **kwargs based model init
# which is slower than normal init.
if field_count == len(klass._meta.fields):
field_names = ()
restricted = requested is not None
@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
if o.field.unique and select_related_descend(o.field, restricted, requested,
only_load.get(o.model), reverse=True):
next = requested[o.field.related_query_name()]
parent = klass if issubclass(o.model, klass) else None
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
requested=next, only_load=only_load, local_only=True)
requested=next, only_load=only_load, from_parent=parent)
reverse_related_fields.append((o.field, klass_info))
if field_names:
pk_idx = field_names.index(klass._meta.pk.attname)
@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
def get_cached_row(row, index_start, using, klass_info, offset=0):
def get_cached_row(row, index_start, using, klass_info, offset=0,
parent_data=()):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
@ -1418,13 +1430,16 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
* offset - the number of additional fields that are known to
exist in row for `klass`. This usually means the number of
annotated results on `klass`.
* using - the database alias on which the query is being executed.
* using - the database alias on which the query is being executed.
* klass_info - result of the get_klass_info function
* parent_data - parent model data in format (field, value). Used
to populate the non-local fields of child models.
"""
if klass_info is None:
return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
fields = row[index_start : index_start + field_count]
# If the pk column is None (or the Oracle equivalent ''), then the related
# object must be non-existent - set the relation to None.
@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
obj = klass(**dict(zip(field_names, fields)))
else:
obj = klass(*fields)
# If an object was retrieved, set the database state.
if obj:
obj._state.db = using
@ -1464,34 +1478,35 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
# Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested
for f, klass_info in reverse_related_fields:
# Transfer data from this object to childs.
parent_data = []
for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
if rel_model is not None and isinstance(obj, rel_model):
parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object
cached_row = get_cached_row(row, index_end, using, klass_info)
cached_row = get_cached_row(row, index_end, using, klass_info,
parent_data=parent_data)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
# If the field is unique, populate the
# reverse descriptor cache
# populate the reverse descriptor cache
setattr(obj, f.related.get_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
# Now populate all the non-local field values
# on the related object
for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
if rel_model is not None:
# Populate related object caches using parent data.
for rel_field, _ in parent_data:
if rel_field.rel:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
# populate the field cache for any related object
# that has already been retrieved
if rel_field.rel:
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
try:
cached_obj = getattr(obj, rel_field.get_cache_name())
setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
except AttributeError:
# Related object hasn't been cached yet
pass
return obj, index_end