diff --git a/lib/extras.py b/lib/extras.py index 0764edfc..3e113b06 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1198,7 +1198,7 @@ def execute_batch(cur, sql, argslist, page_size=100): cur.execute(b";".join(sqls)) -def execute_values(cur, sql, argslist, template=None, page_size=100): +def execute_values(cur, sql, argslist, template=None, page_size=100, fetch_result=False): '''Execute a statement using :sql:`VALUES` with a sequence of parameters. :param cur: the cursor to use to execute the query. @@ -1229,6 +1229,9 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): statement. If there are more items the function will execute more than one statement. + :param fetch_result: flag indicating that results of query execution should + be returned. Useful for queries with `RETURNING` clause + .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will @@ -1265,6 +1268,7 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): sql = sql.encode(_ext.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) + result = [] for page in _paginate(argslist, page_size=page_size): if template is None: template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' @@ -1274,6 +1278,11 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): parts.append(b',') parts[-1:] = post cur.execute(b''.join(parts)) + if fetch_result: + result.extend(cur.fetchall()) + + if fetch_result: + return result def _split_sql(sql):