diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 5031033c..b2c0cd23 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1439,7 +1439,7 @@ exit: /* extension: copy_from - implements COPY FROM */ #define psyco_curs_copy_from_doc \ -"copy_from(file, table, sep='\\t', null='\\\\N', size=8192, columns=None) -- Copy table from file." +"copy_from(file, table, sep='\\t', null='\\\\N', size=8192, columns=None, quote='\"', format='TXT') -- Copy table from file." STEALS(1) static int _psyco_curs_has_read_check(PyObject *o, PyObject **var) @@ -1466,30 +1466,35 @@ static PyObject * psyco_curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = { - "file", "table", "sep", "null", "size", "columns", NULL}; + "file", "table", "sep", "null", "size", "columns", "quote", "format", NULL}; const char *sep = "\t"; const char *null = "\\N"; - const char *command = - "COPY %s%s FROM stdin WITH DELIMITER AS %s NULL AS %s"; + const char *quote = "\""; + const char *format = "TXT"; + Py_ssize_t query_size; char *query = NULL; char *columnlist = NULL; char *quoted_delimiter = NULL; char *quoted_null = NULL; + char *quoted_quote = NULL; const char *table_name; + const char *command; + Py_ssize_t bufsize = DEFAULT_COPYBUFF; PyObject *file, *columns = NULL, *res = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwargs, - "O&s|ssnO", kwlist, + "O&s|ssnOss", kwlist, _psyco_curs_has_read_check, &file, &table_name, &sep, &null, &bufsize, - &columns)) + &columns, &quote, &format)) { return NULL; } + EXC_IF_CURS_CLOSED(self); EXC_IF_CURS_ASYNC(self, copy_from); @@ -1508,19 +1513,45 @@ psyco_curs_copy_from(cursorObject *self, PyObject *args, PyObject *kwargs) self->conn, null, -1, NULL, NULL))) { goto exit; } - - query_size = strlen(command) + strlen(table_name) + strlen(columnlist) - + strlen(quoted_delimiter) + strlen(quoted_null) + 1; - if (!(query = PyMem_New(char, query_size))) { - PyErr_NoMemory(); + + if (!(quoted_quote = 
psycopg_escape_string( + self->conn, quote, 0, NULL, NULL))) { goto exit; } - PyOS_snprintf(query, query_size, command, - table_name, columnlist, quoted_delimiter, quoted_null); + + if(strcmp("TXT", format) == 0){ + // Load by default TXT file type + command = + "COPY %s%s FROM stdin WITH DELIMITER AS %s NULL AS %s"; + query_size = strlen(command) + strlen(table_name) + strlen(columnlist) + + strlen(quoted_delimiter) + strlen(quoted_null) + 1; + + if (!(query = PyMem_New(char, query_size))) { + PyErr_NoMemory(); + goto exit; + } + + PyOS_snprintf(query, query_size, command, + table_name, columnlist, quoted_delimiter, quoted_null); + }else{ + // Load from .CSV + command = + "COPY %s%s FROM stdin WITH DELIMITER AS %s NULL AS %s QUOTE AS %s %s"; + query_size = strlen(command) + strlen(table_name) + strlen(columnlist) + + strlen(quoted_delimiter) + strlen(quoted_null) + strlen(quoted_quote) + + strlen(format) + 1; + + if (!(query = PyMem_New(char, query_size))) { + PyErr_NoMemory(); + goto exit; + } + + PyOS_snprintf(query, query_size, command, + table_name, columnlist, quoted_delimiter, quoted_null, quoted_quote, format); + } Dprintf("psyco_curs_copy_from: query = %s", query); - self->copysize = bufsize; Py_INCREF(file); self->copyfile = file; @@ -1536,6 +1567,7 @@ exit: PyMem_Free(columnlist); PyMem_Free(quoted_delimiter); PyMem_Free(quoted_null); + PyMem_Free(quoted_quote); PyMem_Free(query); return res; diff --git a/tests/test_copy.py b/tests/test_copy.py index 3aa509b5..d15f40f0 100755 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -225,6 +225,29 @@ class CopyTests(ConnectingTestCase): curs.execute("select data from tcopy;") self.assertEqual(curs.fetchone()[0], abin) + def _copy_from_csv(self, curs, nrecs, srec, copykw, mock_columns_enclosed_by): + + f = StringIO() + for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)): + l = c * srec + # Enclose '{1}' and '{2}' in the quote char '{0}' (Defaults to '"') + f.write("%s,%s%s%s\n" % 
(i,mock_columns_enclosed_by,l,mock_columns_enclosed_by)) + + f.seek(0) + copykw['format'] = 'CSV' + + copykw['sep'] = "," + curs.copy_from(MinimalRead(f), "tcopy", **copykw) + + curs.execute("select count(*) from tcopy") + self.assertEqual(nrecs, curs.fetchone()[0]) + + curs.execute("select data from tcopy where id < %s order by id", + (len(string.ascii_letters),)) + for i, (l,) in enumerate(curs): + self.assertEqual(l, string.ascii_letters[i] * srec) + + def _copy_from(self, curs, nrecs, srec, copykw): f = StringIO() for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)): @@ -376,7 +399,29 @@ conn.close() curs.execute("insert into tcopy values (10, 'hi')") self.assertRaises(ZeroDivisionError, curs.copy_to, BrokenWrite(), "tcopy") - + def test_copy_from_csv(self): + curs = self.conn.cursor() + try: + # 'Quote' should default to '"' + self._copy_from_csv(curs, nrecs=1024, srec=10*1024, copykw={}, mock_columns_enclosed_by='"') + finally: + curs.close() + def test_copy_from_csv_specify_column_enclosure(self): + curs = self.conn.cursor() + try: + self._copy_from_csv(curs, nrecs=1024, srec=10*1024, copykw={'quote': "'"}, mock_columns_enclosed_by="'") + finally: + curs.close() + def test_copy_txt_and_set_quote(self): + # this shouldn't return an error... + # b/c format = 'TXT' by default, and we do not + # override this default here, quote does not get included + # in the COPY ... FROM ... command. + curs = self.conn.cursor() + try: + self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={'quote': "'"}) + finally: + curs.close() decorate_all_tests(CopyTests, skip_copy_if_green)