diff --git a/NEWS b/NEWS index 3b4a11cc..d97c23de 100644 --- a/NEWS +++ b/NEWS @@ -8,6 +8,8 @@ What's new in psycopg 2.8.6 (:ticket:`#1101`). - Fixed search of mxDateTime headers in virtualenvs (:ticket:`#996`). - Added missing values from errorcodes (:ticket:`#1133`). +- `cursor.query` reports the query of the last :sql:`COPY` opearation too + (:ticket:`#1141`). - `~psycopg2.errorcodes` map and `~psycopg2.errors` classes updated to PostgreSQL 13. - Wheel package compiled against OpenSSL 1.1.1g. diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index f2dd379a..c290c715 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1446,6 +1446,11 @@ curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs) Dprintf("curs_copy_from: query = %s", query); + Py_CLEAR(self->query); + if (!(self->query = Bytes_FromString(query))) { + goto exit; + } + /* This routine stores a borrowed reference. Although it is only held * for the duration of curs_copy_from, nested invocations of * Py_BEGIN_ALLOW_THREADS could surrender control to another thread, @@ -1538,6 +1543,11 @@ curs_copy_to(cursorObject *self, PyObject *args, PyObject *kwargs) Dprintf("curs_copy_to: query = %s", query); + Py_CLEAR(self->query); + if (!(self->query = Bytes_FromString(query))) { + goto exit; + } + self->copysize = 0; Py_INCREF(file); self->copyfile = file; @@ -1615,6 +1625,10 @@ curs_copy_expert(cursorObject *self, PyObject *args, PyObject *kwargs) Py_INCREF(file); self->copyfile = file; + Py_CLEAR(self->query); + Py_INCREF(sql); + self->query = sql; + /* At this point, the SQL statement must be str, not unicode */ if (pq_execute(self, Bytes_AS_STRING(sql), 0, 0, 0) >= 0) { res = Py_None; diff --git a/tests/test_copy.py b/tests/test_copy.py index 05bef213..9274f1d1 100755 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -307,6 +307,28 @@ class CopyTests(ConnectingTestCase): curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy') self.assertEqual(curs.rowcount, -1) + def test_copy_query(self): + curs = self.conn.cursor() + + curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) + self.assert_(b"copy " in curs.query.lower()) + self.assert_(b" from stdin" in curs.query.lower()) + + curs.copy_expert( + "copy tcopy (data) from stdin", + StringIO('ddd\neee\n')) + self.assert_(b"copy " in curs.query.lower()) + self.assert_(b" from stdin" in curs.query.lower()) + + curs.copy_to(StringIO(), "tcopy") + self.assert_(b"copy " in curs.query.lower()) + self.assert_(b" to stdout" in curs.query.lower()) + + curs.execute("insert into tcopy (data) values ('fff')") + curs.copy_expert("copy tcopy to stdout", StringIO()) + self.assert_(b"copy " in curs.query.lower()) + self.assert_(b" to stdout" in curs.query.lower()) + @slow def test_copy_from_segfault(self): # issue #219