Refs #27222 -- Adapted RETURNING handling to be usable for UPDATE queries.

Renamed existing methods and abstractions used for INSERT … RETURNING
to be generic enough to be used in the context of UPDATEs as well.

This also consolidates SQL compliant implementations on
BaseDatabaseOperations.
This commit is contained in:
Simon Charette 2025-03-21 21:50:54 -04:00 committed by Mariusz Felisiak
parent dc4ee99152
commit 292b9e6fe8
5 changed files with 51 additions and 53 deletions

View File

@ -208,13 +208,6 @@ class BaseDatabaseOperations:
else: else:
return ["DISTINCT"], [] return ["DISTINCT"], []
def fetch_returned_insert_columns(self, cursor, returning_params):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the newly created data.
"""
return cursor.fetchone()
def force_group_by(self): def force_group_by(self):
""" """
Return a GROUP BY clause to use with a HAVING clause when no grouping Return a GROUP BY clause to use with a HAVING clause when no grouping
@ -358,11 +351,12 @@ class BaseDatabaseOperations:
""" """
return value return value
def return_insert_columns(self, fields): def returning_columns(self, fields):
""" """
For backends that support returning columns as part of an insert query, For backends that support returning columns as part of an insert or
return the SQL and params to append to the INSERT query. The returned update query, return the SQL and params to append to the query.
fragment should contain a format string to hold the appropriate column. The returned fragment should contain a format string to hold the
appropriate column.
""" """
if not fields: if not fields:
return "", () return "", ()
@ -376,10 +370,10 @@ class BaseDatabaseOperations:
] ]
return "RETURNING %s" % ", ".join(columns), () return "RETURNING %s" % ", ".join(columns), ()
def fetch_returned_insert_rows(self, cursor): def fetch_returned_rows(self, cursor, returning_params):
""" """
Given a cursor object that has just performed an INSERT...RETURNING Given a cursor object for a DML query with a RETURNING statement,
statement into a table, return the tuple of returned data. return the selected returning rows of tuples.
""" """
return cursor.fetchall() return cursor.fetchall()

View File

@ -22,7 +22,7 @@ from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile from django.utils.regex_helper import _lazy_re_compile
from .base import Database from .base import Database
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime from .utils import BoundVar, BulkInsertMapper, Oracle_datetime
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
@ -298,12 +298,27 @@ END;
def deferrable_sql(self): def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED" return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_columns(self, cursor, returning_params): def returning_columns(self, fields):
columns = [] if not fields:
for param in returning_params: return "", ()
value = param.get_value() field_names = []
columns.append(value[0]) params = []
return tuple(columns) for field in fields:
field_names.append(
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
)
params.append(BoundVar(field))
return "RETURNING %s INTO %s" % (
", ".join(field_names),
", ".join(["%s"] * len(params)),
), tuple(params)
def fetch_returned_rows(self, cursor, returning_params):
return list(zip(*(param.get_value() for param in returning_params)))
def no_limit_value(self): def no_limit_value(self):
return None return None
@ -391,25 +406,6 @@ END;
match_option = "'i'" match_option = "'i'"
return "REGEXP_LIKE(%%s, %%s, %s)" % match_option return "REGEXP_LIKE(%%s, %%s, %s)" % match_option
def return_insert_columns(self, fields):
if not fields:
return "", ()
field_names = []
params = []
for field in fields:
field_names.append(
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
)
params.append(InsertVar(field))
return "RETURNING %s INTO %s" % (
", ".join(field_names),
", ".join(["%s"] * len(params)),
), tuple(params)
def __foreign_key_constraints(self, table_name, recursive): def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
if recursive: if recursive:

View File

@ -4,7 +4,7 @@ import decimal
from .base import Database from .base import Database
class InsertVar: class BoundVar:
""" """
A late-binding cursor variable that can be passed to Cursor.execute A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an as a parameter, in order to receive the id of the row created by an

View File

@ -1890,7 +1890,7 @@ class SQLInsertCompiler(SQLCompiler):
result.append(on_conflict_suffix_sql) result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for # Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096. # 3rd party backends. Refs #19096.
r_sql, self.returning_params = self.connection.ops.return_insert_columns( r_sql, self.returning_params = self.connection.ops.returning_columns(
self.returning_fields self.returning_fields
) )
if r_sql: if r_sql:
@ -1925,20 +1925,16 @@ class SQLInsertCompiler(SQLCompiler):
cursor.execute(sql, params) cursor.execute(sql, params)
if not self.returning_fields: if not self.returning_fields:
return [] return []
obj_len = len(self.query.objs)
if ( if (
self.connection.features.can_return_rows_from_bulk_insert self.connection.features.can_return_rows_from_bulk_insert
and len(self.query.objs) > 1 and obj_len > 1
) or (
self.connection.features.can_return_columns_from_insert and obj_len == 1
): ):
rows = self.connection.ops.fetch_returned_insert_rows(cursor) rows = self.connection.ops.fetch_returned_rows(
cols = [field.get_col(opts.db_table) for field in self.returning_fields] cursor, self.returning_params
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
rows = [
self.connection.ops.fetch_returned_insert_columns(
cursor,
self.returning_params,
) )
]
cols = [field.get_col(opts.db_table) for field in self.returning_fields] cols = [field.get_col(opts.db_table) for field in self.returning_fields]
elif returning_fields and isinstance( elif returning_fields and isinstance(
returning_field := returning_fields[0], AutoField returning_field := returning_fields[0], AutoField

View File

@ -402,6 +402,18 @@ backends.
* :class:`~django.db.backends.base.schema.BaseDatabaseSchemaEditor` and * :class:`~django.db.backends.base.schema.BaseDatabaseSchemaEditor` and
PostgreSQL backends no longer use ``CASCADE`` when dropping a column. PostgreSQL backends no longer use ``CASCADE`` when dropping a column.
* ``DatabaseOperations.return_insert_columns()`` and
``DatabaseOperations.fetch_returned_insert_rows()`` methods are renamed to
``returning_columns()`` and ``fetch_returned_rows()``, respectively, to
denote they can be used in the context ``UPDATE … RETURNING`` statements as
well as ``INSERT … RETURNING``.
* The ``DatabaseOperations.fetch_returned_insert_columns()`` method is removed
and the ``fetch_returned_rows()`` method replacing
``fetch_returned_insert_rows()`` expects both a ``cursor`` and
``returning_params`` to be provided just like
``fetch_returned_insert_columns()`` did.
Dropped support for MariaDB 10.5 Dropped support for MariaDB 10.5
-------------------------------- --------------------------------