Fixed #15648 -- Allowed QuerySet.values_list() to return a namedtuple.

This commit is contained in:
Sergey Fedoseev 2017-08-04 15:28:39 +05:00 committed by Tim Graham
parent a027447f56
commit f3c9562143
4 changed files with 101 additions and 5 deletions

View file

@ -6,8 +6,9 @@ import copy
import operator
import sys
import warnings
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from contextlib import suppress
from functools import lru_cache
from itertools import chain
from django.conf import settings
@ -137,6 +138,34 @@ class ValuesListIterable(BaseIterable):
return compiler.results_iter(tuple_expected=True)
class NamedValuesListIterable(ValuesListIterable):
"""
Iterable returned by QuerySet.values_list(named=True) that yields a
namedtuple for each row.
"""
@staticmethod
@lru_cache()
def create_namedtuple_class(*names):
# Cache namedtuple() with @lru_cache() since it's too slow to be
# called for every QuerySet evaluation.
return namedtuple('Row', names)
def __iter__(self):
queryset = self.queryset
if queryset._fields:
names = queryset._fields
else:
query = queryset.query
names = list(query.extra_select)
names.extend(query.values_select)
names.extend(query.annotation_select)
tuple_class = self.create_namedtuple_class(*names)
new = tuple.__new__
for row in super().__iter__():
yield new(tuple_class, row)
class FlatValuesListIterable(BaseIterable):
"""
Iterable returned by QuerySet.values_list(flat=True) that yields single
@ -712,22 +741,35 @@ class QuerySet:
clone._iterable_class = ValuesIterable
return clone
def values_list(self, *fields, flat=False):
def values_list(self, *fields, flat=False, named=False):
if flat and named:
raise TypeError("'flat' and 'named' can't be used together.")
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
field_names = {f for f in fields if not hasattr(f, 'resolve_expression')}
_fields = []
expressions = {}
counter = 1
for field in fields:
if hasattr(field, 'resolve_expression'):
field_id = str(id(field))
field_id_prefix = getattr(field, 'default_alias', field.__class__.__name__.lower())
while True:
field_id = field_id_prefix + str(counter)
counter += 1
if field_id not in field_names:
break
expressions[field_id] = field
_fields.append(field_id)
else:
_fields.append(field)
clone = self._values(*_fields, **expressions)
clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
clone._iterable_class = (
NamedValuesListIterable if named
else FlatValuesListIterable if flat
else ValuesListIterable
)
return clone
def dates(self, field_name, kind, order='ASC'):