Escape table and column names in cursor.copy_from() and .copy_to()

This commit is contained in:
Daniele Varrazzo 2021-05-24 10:56:47 +02:00
parent e5ad0ab2d9
commit 8a2deb39ed
3 changed files with 63 additions and 11 deletions

2
NEWS
View File

@ -7,6 +7,8 @@ What's new in psycopg 2.9
- Dropped support for Python 2.7, 3.4, 3.5 (:tickets:`#1198, #1000, #1197`).
- ``with connection`` starts a transaction on autocommit transactions too
(:ticket:`#941`).
- Escape table and column names in `~cursor.copy_from()` and
`~cursor.copy_to()`.
- Connection exceptions with sqlstate ``08XXX`` reclassified as
`~psycopg2.OperationalError` (a subclass of the previously used
`~psycopg2.DatabaseError`) (:ticket:`#1148`).

View File

@ -1303,11 +1303,9 @@ exit:
/* Return a newly allocated buffer containing the list of columns to be
* copied. On error return NULL and set an exception.
*/
static char *_psyco_curs_copy_columns(PyObject *columns)
static char *_psyco_curs_copy_columns(cursorObject *self, PyObject *columns)
{
PyObject *col, *coliter;
Py_ssize_t collen;
char *colname;
char *columnlist = NULL;
Py_ssize_t bufsize = 512;
Py_ssize_t offset = 1;
@ -1333,15 +1331,28 @@ static char *_psyco_curs_copy_columns(PyObject *columns)
columnlist[0] = '(';
while ((col = PyIter_Next(coliter)) != NULL) {
Py_ssize_t collen;
char *colname;
char *quoted_colname;
if (!(col = psyco_ensure_bytes(col))) {
Py_DECREF(coliter);
goto error;
}
Bytes_AsStringAndSize(col, &colname, &collen);
if (!(quoted_colname = psyco_escape_identifier(
self->conn, colname, collen))) {
Py_DECREF(col);
Py_DECREF(coliter);
goto error;
}
collen = strlen(quoted_colname);
while (offset + collen > bufsize - 2) {
char *tmp;
bufsize *= 2;
if (NULL == (tmp = PyMem_Realloc(columnlist, bufsize))) {
PQfreemem(quoted_colname);
Py_DECREF(col);
Py_DECREF(coliter);
PyErr_NoMemory();
@ -1349,10 +1360,11 @@ static char *_psyco_curs_copy_columns(PyObject *columns)
}
columnlist = tmp;
}
strncpy(&columnlist[offset], colname, collen);
strncpy(&columnlist[offset], quoted_colname, collen);
offset += collen;
columnlist[offset++] = ',';
Py_DECREF(col);
PQfreemem(quoted_colname);
}
Py_DECREF(coliter);
@ -1399,8 +1411,9 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
char *columnlist = NULL;
char *quoted_delimiter = NULL;
char *quoted_null = NULL;
char *quoted_table_name = NULL;
const char *table_name;
Py_ssize_t bufsize = DEFAULT_COPYBUFF;
PyObject *file, *columns = NULL, *res = NULL;
@ -1421,8 +1434,9 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
EXC_IF_GREEN(copy_from);
EXC_IF_TPC_PREPARED(self->conn, copy_from);
if (NULL == (columnlist = _psyco_curs_copy_columns(columns)))
if (!(columnlist = _psyco_curs_copy_columns(self, columns))) {
goto exit;
}
if (!(quoted_delimiter = psyco_escape_string(
self->conn, sep, -1, NULL, NULL))) {
@ -1434,7 +1448,12 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
goto exit;
}
query_size = strlen(command) + strlen(table_name) + strlen(columnlist)
if (!(quoted_table_name = psyco_escape_identifier(
self->conn, table_name, -1))) {
goto exit;
}
query_size = strlen(command) + strlen(quoted_table_name) + strlen(columnlist)
+ strlen(quoted_delimiter) + strlen(quoted_null) + 1;
if (!(query = PyMem_New(char, query_size))) {
PyErr_NoMemory();
@ -1442,7 +1461,7 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
}
PyOS_snprintf(query, query_size, command,
table_name, columnlist, quoted_delimiter, quoted_null);
quoted_table_name, columnlist, quoted_delimiter, quoted_null);
Dprintf("curs_copy_from: query = %s", query);
@ -1469,6 +1488,9 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs)
Py_CLEAR(self->copyfile);
exit:
if (quoted_table_name) {
PQfreemem(quoted_table_name);
}
PyMem_Free(columnlist);
PyMem_Free(quoted_delimiter);
PyMem_Free(quoted_null);
@ -1499,6 +1521,7 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
char *quoted_null = NULL;
const char *table_name;
char *quoted_table_name = NULL;
PyObject *file = NULL, *columns = NULL, *res = NULL;
if (!PyArg_ParseTupleAndKeywords(
@ -1518,8 +1541,14 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
EXC_IF_GREEN(copy_to);
EXC_IF_TPC_PREPARED(self->conn, copy_to);
if (NULL == (columnlist = _psyco_curs_copy_columns(columns)))
if (!(quoted_table_name = psyco_escape_identifier(
self->conn, table_name, -1))) {
goto exit;
}
if (!(columnlist = _psyco_curs_copy_columns(self, columns))) {
goto exit;
}
if (!(quoted_delimiter = psyco_escape_string(
self->conn, sep, -1, NULL, NULL))) {
@ -1531,7 +1560,7 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
goto exit;
}
query_size = strlen(command) + strlen(table_name) + strlen(columnlist)
query_size = strlen(command) + strlen(quoted_table_name) + strlen(columnlist)
+ strlen(quoted_delimiter) + strlen(quoted_null) + 1;
if (!(query = PyMem_New(char, query_size))) {
PyErr_NoMemory();
@ -1539,7 +1568,7 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
}
PyOS_snprintf(query, query_size, command,
table_name, columnlist, quoted_delimiter, quoted_null);
quoted_table_name, columnlist, quoted_delimiter, quoted_null);
Dprintf("curs_copy_to: query = %s", query);
@ -1560,6 +1589,9 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs)
Py_CLEAR(self->copyfile);
exit:
if (quoted_table_name) {
PQfreemem(quoted_table_name);
}
PyMem_Free(columnlist);
PyMem_Free(quoted_delimiter);
PyMem_Free(quoted_null);

View File

@ -263,6 +263,24 @@ class CopyTests(ConnectingTestCase):
curs.execute("select count(*) from manycols;")
self.assertEqual(curs.fetchone()[0], 2)
def test_copy_funny_names(self):
cols = ["select", "insert", "group"]
curs = self.conn.cursor()
curs.execute('CREATE TEMPORARY TABLE "select" (%s)' % ',\n'.join(
['"%s" int' % c for c in cols]))
curs.execute('INSERT INTO "select" DEFAULT VALUES')
f = StringIO()
curs.copy_to(f, "select", columns=cols)
f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0)
curs.copy_from(f, "select", columns=cols)
curs.execute('select count(*) from "select";')
self.assertEqual(curs.fetchone()[0], 2)
@skip_before_postgres(8, 2) # they don't send the count
def test_copy_rowcount(self):
curs = self.conn.cursor()