Refs #29444 -- Allowed returning multiple fields from INSERT statements on PostgreSQL.

Thanks Florian Apolloner, Tim Graham, Simon Charette, Nick Pope, and
Mariusz Felisiak for reviews.
Author:    Johannes Hoppe, 2019-07-24 08:42:41 +02:00
Committer: Mariusz Felisiak
Commit:    7254f1138d (parent 736e7d44de)
16 changed files with 209 additions and 89 deletions
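
Note on the mechanics: the net effect of the diff below is that the private
insert helpers stop returning a bare primary key (or a flat list of primary
keys) and instead return one row of column values per inserted object,
ordered like Meta.db_returning_fields. A rough, hypothetical sketch of the
two shapes (values made up; this is not code from the commit):

    # Before: _batched_insert() returned a flat list of primary keys.
    inserted_ids = [10, 11, 12]
    # After: one tuple of returned column values per inserted object, one
    # value per field in opts.db_returning_fields (typically just the
    # auto-created primary key at this point in the series).
    returned_columns = [(10,), (11,), (12,)]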

django/db/models/query.py

@@ -470,23 +470,33 @@ class QuerySet:
             return objs
         self._for_write = True
         connection = connections[self.db]
-        fields = self.model._meta.concrete_fields
+        opts = self.model._meta
+        fields = opts.concrete_fields
         objs = list(objs)
         self._populate_pk_values(objs)
         with transaction.atomic(using=self.db, savepoint=False):
             objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
             if objs_with_pk:
-                self._batched_insert(objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
+                returned_columns = self._batched_insert(
+                    objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                )
+                for obj_with_pk, results in zip(objs_with_pk, returned_columns):
+                    for result, field in zip(results, opts.db_returning_fields):
+                        if field != opts.pk:
+                            setattr(obj_with_pk, field.attname, result)
                 for obj_with_pk in objs_with_pk:
                     obj_with_pk._state.adding = False
                     obj_with_pk._state.db = self.db
             if objs_without_pk:
                 fields = [f for f in fields if not isinstance(f, AutoField)]
-                ids = self._batched_insert(objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
+                returned_columns = self._batched_insert(
+                    objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
+                )
                 if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
-                    assert len(ids) == len(objs_without_pk)
-                for obj_without_pk, pk in zip(objs_without_pk, ids):
-                    obj_without_pk.pk = pk
+                    assert len(returned_columns) == len(objs_without_pk)
+                for obj_without_pk, results in zip(objs_without_pk, returned_columns):
+                    for result, field in zip(results, opts.db_returning_fields):
+                        setattr(obj_without_pk, field.attname, result)
                     obj_without_pk._state.adding = False
                     obj_without_pk._state.db = self.db
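
Note: the nested zip above assigns each returned value to its field by
position, one row per object. A self-contained, hypothetical sketch of the
same pattern with stand-in classes (no Django involved; names are made up):

    class FakeField:
        def __init__(self, attname):
            self.attname = attname

    class FakeObj:
        pass

    # Stand-ins for opts.db_returning_fields and for the rows that
    # _batched_insert() now returns.
    db_returning_fields = [FakeField('id')]
    returned_columns = [(10,), (11,)]
    objs_without_pk = [FakeObj(), FakeObj()]

    for obj, results in zip(objs_without_pk, returned_columns):
        for result, field in zip(results, db_returning_fields):
            setattr(obj, field.attname, result)

    assert [o.id for o in objs_without_pk] == [10, 11]

For objects that already had a primary key set, the same loop runs but skips
the pk column (the `if field != opts.pk` check), so a caller-supplied pk is
never overwritten by the database-returned one.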
@@ -1181,7 +1191,7 @@ class QuerySet:
     # PRIVATE METHODS #
     ###################
 
-    def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False):
+    def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
         """
         Insert a new record for the given model. This provides an interface to
         the InsertQuery class and is how Model.save() is implemented.
@@ -1191,7 +1201,7 @@ class QuerySet:
             using = self.db
         query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
         query.insert_values(fields, objs, raw=raw)
-        return query.get_compiler(using=using).execute_sql(return_id)
+        return query.get_compiler(using=using).execute_sql(returning_fields)
     _insert.alters_data = True
     _insert.queryset_only = False
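
Note: returning_fields generalizes the old boolean return_id flag, so a
caller now names the fields it wants back from the INSERT rather than asking
for "the id". A hypothetical before/after of the call shape (illustrative
only; the corresponding Model.save() plumbing lives in
django/db/models/base.py, outside the hunks shown here):

    # pk = self._insert(objs, fields, return_id=True)            # old shape
    # rows = self._insert(
    #     objs, fields,
    #     returning_fields=self.model._meta.db_returning_fields,
    # )                                                           # new shape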
@@ -1203,21 +1213,22 @@ class QuerySet:
             raise NotSupportedError('This database backend does not support ignoring conflicts.')
         ops = connections[self.db].ops
         batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
-        inserted_ids = []
+        inserted_rows = []
         bulk_return = connections[self.db].features.can_return_rows_from_bulk_insert
         for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
             if bulk_return and not ignore_conflicts:
-                inserted_id = self._insert(
-                    item, fields=fields, using=self.db, return_id=True,
+                inserted_columns = self._insert(
+                    item, fields=fields, using=self.db,
+                    returning_fields=self.model._meta.db_returning_fields,
                     ignore_conflicts=ignore_conflicts,
                 )
-                if isinstance(inserted_id, list):
-                    inserted_ids.extend(inserted_id)
+                if isinstance(inserted_columns, list):
+                    inserted_rows.extend(inserted_columns)
                 else:
-                    inserted_ids.append(inserted_id)
+                    inserted_rows.append(inserted_columns)
             else:
                 self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
-        return inserted_ids
+        return inserted_rows
 
     def _chain(self, **kwargs):
         """