From 8743c0c24673b52f3cdb06a1dcadab7c1bb1c66c Mon Sep 17 00:00:00 2001 From: Luan-233 <2533556772@qq.com> Date: Fri, 11 Aug 2023 10:50:27 +0800 Subject: [PATCH] Add a new kind of place holder --- lib/extras.py | 54 +++--- psycopg/bytes_format.c | 399 +++++++++++++++++++++++---------------- psycopg/cursor_type.c | 411 +++++++++++++++++++++++++---------------- psycopg/utils.h | 4 +- 4 files changed, 523 insertions(+), 345 deletions(-) diff --git a/lib/extras.py b/lib/extras.py index 36e8ef9a..4853ce63 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -140,15 +140,15 @@ class DictCursor(DictCursorBase): super().__init__(*args, **kwargs) self._prefetch = True - def execute(self, query, vars=None): + def execute(self, query, vars=None, place_holder = '%'): self.index = OrderedDict() self._query_executed = True - return super().execute(query, vars) + return super().execute(query, vars, place_holder) - def callproc(self, procname, vars=None): + def callproc(self, procname, vars=None, place_holder = '%'): self.index = OrderedDict() self._query_executed = True - return super().callproc(procname, vars) + return super().callproc(procname, vars, place_holder) def _build_index(self): if self._query_executed and self.description: @@ -230,15 +230,15 @@ class RealDictCursor(DictCursorBase): kwargs['row_factory'] = RealDictRow super().__init__(*args, **kwargs) - def execute(self, query, vars=None): + def execute(self, query, vars=None, place_holder = '%'): self.column_mapping = [] self._query_executed = True - return super().execute(query, vars) + return super().execute(query, vars, place_holder) - def callproc(self, procname, vars=None): + def callproc(self, procname, vars=None, place_holder = '%'): self.column_mapping = [] self._query_executed = True - return super().callproc(procname, vars) + return super().callproc(procname, vars, place_holder) def _build_index(self): if self._query_executed and self.description: @@ -307,17 +307,17 @@ class NamedTupleCursor(_cursor): Record = None MAX_CACHE = 1024 - def execute(self, query, vars=None): + def execute(self, query, vars=None, place_holder = '%'): self.Record = None - return super().execute(query, vars) + return super().execute(query, vars, place_holder) - def executemany(self, query, vars): + def executemany(self, query, vars, place_holder = '%'): self.Record = None - return super().executemany(query, vars) + return super().executemany(query, vars, place_holder) - def callproc(self, procname, vars=None): + def callproc(self, procname, vars=None, place_holder = '%'): self.Record = None - return super().callproc(procname, vars) + return super().callproc(procname, vars, place_holder) def fetchone(self): t = super().fetchone() @@ -440,15 +440,15 @@ class LoggingConnection(_connection): class LoggingCursor(_cursor): """A cursor that logs queries using its connection logging facilities.""" - def execute(self, query, vars=None): + def execute(self, query, vars=None, place_holder = '%'): try: - return super().execute(query, vars) + return super().execute(query, vars, place_holder) finally: self.connection.log(self.query, self) - def callproc(self, procname, vars=None): + def callproc(self, procname, vars=None, place_holder = '%'): try: - return super().callproc(procname, vars) + return super().callproc(procname, vars, place_holder) finally: self.connection.log(self.query, self) @@ -484,13 +484,13 @@ class MinTimeLoggingConnection(LoggingConnection): class MinTimeLoggingCursor(LoggingCursor): """The cursor sub-class companion to `MinTimeLoggingConnection`.""" - def execute(self, query, vars=None): + def execute(self, query, vars=None, place_holder = '%'): self.timestamp = _time.time() - return LoggingCursor.execute(self, query, vars) + return LoggingCursor.execute(self, query, vars, place_holder) - def callproc(self, procname, vars=None): + def callproc(self, procname, vars=None, place_holder = '%'): self.timestamp = _time.time() - return LoggingCursor.callproc(self, procname, vars) + return LoggingCursor.callproc(self, procname, vars, place_holder) class LogicalReplicationConnection(_replicationConnection): @@ -1191,7 +1191,7 @@ def _paginate(seq, page_size): return -def execute_batch(cur, sql, argslist, page_size=100): +def execute_batch(cur, sql, argslist, page_size=100, place_holder = '%'): r"""Execute groups of statements in fewer server roundtrips. Execute *sql* several times, against all parameters set (sequences or @@ -1212,11 +1212,11 @@ def execute_batch(cur, sql, argslist, page_size=100): """ for page in _paginate(argslist, page_size=page_size): - sqls = [cur.mogrify(sql, args) for args in page] + sqls = [cur.mogrify(sql, args, place_holder) for args in page] cur.execute(b";".join(sqls)) -def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False): +def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False, place_holder = '%'): '''Execute a statement using :sql:`VALUES` with a sequence of parameters. :param cur: the cursor to use to execute the query. @@ -1293,7 +1293,7 @@ def execute_values(cur, sql, argslist, template=None, page_size=100, fetch=False template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' parts = pre[:] for args in page: - parts.append(cur.mogrify(template, args)) + parts.append(cur.mogrify(template, args, place_holder)) parts.append(b',') parts[-1:] = post cur.execute(b''.join(parts)) @@ -1337,4 +1337,4 @@ def _split_sql(sql): # ascii except alnum and underscore _re_clean = _re.compile( - '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']') + '[' + _re.escape(' !"#$%&\'()*+,-./:;<=>?@[\\]^`{|}~') + ']') \ No newline at end of file diff --git a/psycopg/bytes_format.c b/psycopg/bytes_format.c index d34a0171..38eb8213 100644 --- a/psycopg/bytes_format.c +++ b/psycopg/bytes_format.c @@ -85,225 +85,314 @@ /* Helpers for formatstring */ BORROWED Py_LOCAL_INLINE(PyObject *) -getnextarg(PyObject *args, Py_ssize_t arglen, Py_ssize_t *p_argidx) -{ +getnextarg(PyObject *args, Py_ssize_t arglen, Py_ssize_t *p_argidx) { Py_ssize_t argidx = *p_argidx; if (argidx < arglen) { (*p_argidx)++; - if (arglen < 0) - return args; - else - return PyTuple_GetItem(args, argidx); + if (arglen < 0) return args; + else return PyTuple_GetItem(args, argidx); } - PyErr_SetString(PyExc_TypeError, - "not enough arguments for format string"); return NULL; } +/* + for function getnextarg: + I delete the line including 'raise error', for making this func a iterator + just used to fill in the arguments array +*/ + /* wrapper around _Bytes_Resize offering normal Python call semantics */ STEALS(1) Py_LOCAL_INLINE(PyObject *) resize_bytes(PyObject *b, Py_ssize_t newsize) { - if (0 == _Bytes_Resize(&b, newsize)) { - return b; - } - else { - return NULL; - } + if (0 == _Bytes_Resize(&b, newsize)) return b; + else return NULL; } -/* fmt%(v1,v2,...) is roughly equivalent to sprintf(fmt, v1, v2, ...) */ - -PyObject * -Bytes_Format(PyObject *format, PyObject *args) -{ - char *fmt, *res; - Py_ssize_t arglen, argidx; - Py_ssize_t reslen, rescnt, fmtcnt; - int args_owned = 0; - PyObject *result; - PyObject *dict = NULL; - if (format == NULL || !Bytes_Check(format) || args == NULL) { +PyObject *Bytes_Format(PyObject *format, PyObject *args, char place_holder) { + char *fmt, *res; //array pointer of format, and array pointer of result + Py_ssize_t arglen, argidx; //length of arguments array, and index of arguments(when processing args_list) + Py_ssize_t reslen, rescnt, fmtcnt; //rescnt: blank space in result; reslen: the total length of result; fmtcnt: length of format + int args_owned = 0; //args is valid or invalid(or maybe refcnt), 0 for invalid,1 otherwise + PyObject *result; //function's return value + PyObject *dict = NULL; //dictionary + PyObject *args_value = NULL; //every argument store in it after parse + char **args_list = NULL; //arguments list as char ** + char *args_buffer = NULL; //Bytes_AS_STRING(args_value) + Py_ssize_t * args_len = NULL; //every argument's length in args_list + int args_id = 0; //index of arguments(when generating result) + int index_type = 0; //if exists $number, it will be 1, otherwise 0 + + if (format == NULL || !Bytes_Check(format) || args == NULL) { //check if arguments are valid PyErr_SetString(PyExc_SystemError, "bad argument to internal function"); return NULL; } - fmt = Bytes_AS_STRING(format); - fmtcnt = Bytes_GET_SIZE(format); - reslen = rescnt = fmtcnt + 100; - result = Bytes_FromStringAndSize((char *)NULL, reslen); - if (result == NULL) - return NULL; + fmt = Bytes_AS_STRING(format); //get pointer of format + fmtcnt = Bytes_GET_SIZE(format); //get length of format + reslen = rescnt = 1; + while (reslen <= fmtcnt) { //when space is not enough, double it's size + reslen *= 2; + rescnt *= 2; + } + result = Bytes_FromStringAndSize((char *)NULL, reslen); + if (result == NULL) return NULL; res = Bytes_AS_STRING(result); - if (PyTuple_Check(args)) { + if (PyTuple_Check(args)) { //check if arguments are sequences arglen = PyTuple_GET_SIZE(args); argidx = 0; } - else { + else { //if no, then this two are of no importance arglen = -1; argidx = -2; } - if (Py_TYPE(args)->tp_as_mapping && !PyTuple_Check(args) && - !PyObject_TypeCheck(args, &Bytes_Type)) + if (Py_TYPE(args)->tp_as_mapping && !PyTuple_Check(args) && !PyObject_TypeCheck(args, &Bytes_Type)) { //check if args is dict dict = args; - while (--fmtcnt >= 0) { - if (*fmt != '%') { + //Py_INCREF(dict); + } + while (--fmtcnt >= 0) { //scan the format + if (*fmt != '%') { //if not %, pass it(for the special format '%(name)s') if (--rescnt < 0) { - rescnt = fmtcnt + 100; - reslen += rescnt; + rescnt = reslen; //double the space + reslen *= 2; if (!(result = resize_bytes(result, reslen))) { return NULL; } - res = Bytes_AS_STRING(result) + reslen - rescnt; + res = Bytes_AS_STRING(result) + reslen - rescnt;//calculate offset --rescnt; } - *res++ = *fmt++; + *res++ = *fmt++; //copy } else { /* Got a format specifier */ - Py_ssize_t width = -1; - int c = '\0'; - PyObject *v = NULL; - PyObject *temp = NULL; - char *pbuf; - Py_ssize_t len; fmt++; if (*fmt == '(') { - char *keystart; - Py_ssize_t keylen; + char *keystart; //begin pos of left bracket + Py_ssize_t keylen; //length of content in bracket PyObject *key; - int pcount = 1; - + int pcount = 1; //counter of left bracket + Py_ssize_t length = 0; if (dict == NULL) { - PyErr_SetString(PyExc_TypeError, - "format requires a mapping"); + PyErr_SetString(PyExc_TypeError, "format requires a mapping"); goto error; } ++fmt; --fmtcnt; keystart = fmt; /* Skip over balanced parentheses */ - while (pcount > 0 && --fmtcnt >= 0) { - if (*fmt == ')') - --pcount; - else if (*fmt == '(') - ++pcount; + while (pcount > 0 && --fmtcnt >= 0) { //find the matching right bracket + if (*fmt == ')') --pcount; + else if (*fmt == '(') ++pcount; fmt++; } keylen = fmt - keystart - 1; - if (fmtcnt < 0 || pcount > 0) { - PyErr_SetString(PyExc_ValueError, - "incomplete format key"); + if (fmtcnt < 0 || pcount > 0 || *(fmt++) != 's') { //not found, raise an error + PyErr_SetString(PyExc_ValueError, "incomplete format key"); goto error; } - key = Text_FromUTF8AndSize(keystart, keylen); - if (key == NULL) - goto error; - if (args_owned) { + --fmtcnt; + key = Text_FromUTF8AndSize(keystart, keylen);//get key + if (key == NULL) goto error; + if (args_owned) { //if refcnt > 0, then release Py_DECREF(args); args_owned = 0; } - args = PyObject_GetItem(dict, key); + args = PyObject_GetItem(dict, key); //get value with key Py_DECREF(key); - if (args == NULL) { + if (args == NULL) goto error; + if (!Bytes_CheckExact(args)) { + PyErr_Format(PyExc_ValueError, "only bytes values expected, got %s", Py_TYPE(args)->tp_name); //raise error, but may have bug goto error; } + args_buffer = Bytes_AS_STRING(args); //temporary buffer + length = Bytes_GET_SIZE(args); + if (rescnt < length) { + while (rescnt < length) { + rescnt += reslen; + reslen *= 2; + } + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + } + res = Bytes_AS_STRING(result) + reslen - rescnt; + Py_MEMCPY(res, args_buffer, length); + rescnt -= length; + res += length; args_owned = 1; - arglen = -1; + arglen = -1; //exists place holder as "%(name)s", set these arguments to invalid argidx = -2; } - while (--fmtcnt >= 0) { - c = *fmt++; - break; - } - if (fmtcnt < 0) { - PyErr_SetString(PyExc_ValueError, - "incomplete format"); - goto error; - } - switch (c) { - case '%': - pbuf = "%"; - len = 1; - break; - case 's': - /* only bytes! */ - if (!(v = getnextarg(args, arglen, &argidx))) - goto error; - if (!Bytes_CheckExact(v)) { - PyErr_Format(PyExc_ValueError, - "only bytes values expected, got %s", - Py_TYPE(v)->tp_name); - goto error; - } - temp = v; - Py_INCREF(v); - pbuf = Bytes_AS_STRING(temp); - len = Bytes_GET_SIZE(temp); - break; - default: - PyErr_Format(PyExc_ValueError, - "unsupported format character '%c' (0x%x) " - "at index " FORMAT_CODE_PY_SSIZE_T, - c, c, - (Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format))); - goto error; - } - if (width < len) - width = len; - if (rescnt < width) { - reslen -= rescnt; - rescnt = width + fmtcnt + 100; - reslen += rescnt; - if (reslen < 0) { - Py_DECREF(result); - Py_XDECREF(temp); - if (args_owned) - Py_DECREF(args); - return PyErr_NoMemory(); - } - if (!(result = resize_bytes(result, reslen))) { - Py_XDECREF(temp); - if (args_owned) - Py_DECREF(args); - return NULL; - } - res = Bytes_AS_STRING(result) - + reslen - rescnt; - } - Py_MEMCPY(res, pbuf, len); - res += len; - rescnt -= len; - while (--width >= len) { - --rescnt; - *res++ = ' '; - } - if (dict && (argidx < arglen) && c != '%') { - PyErr_SetString(PyExc_TypeError, - "not all arguments converted during string formatting"); - Py_XDECREF(temp); - goto error; - } - Py_XDECREF(temp); } /* '%' */ } /* until end */ - if (argidx < arglen && !dict) { - PyErr_SetString(PyExc_TypeError, - "not all arguments converted during string formatting"); + + if (dict) { //if args' type is dict, the func ends + if (args_owned) Py_DECREF(args); + if (!(result = resize_bytes(result, reslen - rescnt))) return NULL; //resize and return + if (place_holder != '%') { + PyErr_SetString(PyExc_TypeError, "place holder only expect %% when using dict"); + goto error; + } + return result; + } + + args_list = (char **)malloc(sizeof(char *) * arglen); //buffer + args_len = (Py_ssize_t *)malloc(sizeof(Py_ssize_t *) * arglen); //length of every argument + while ((args_value = getnextarg(args, arglen, &argidx)) != NULL) { //stop when receive NULL + Py_ssize_t length = 0; + if (!Bytes_CheckExact(args_value)) { + PyErr_Format(PyExc_ValueError, "only bytes values expected, got %s", Py_TYPE(args_value)->tp_name); //may have bug + goto error; + } + Py_INCREF(args_value); //increase refcnt + args_buffer = Bytes_AS_STRING(args_value); + length = Bytes_GET_SIZE(args_value); + //printf("type: %s, len: %d, value: %s\n", Py_TYPE(args_value)->tp_name, length, args_buffer); + args_len[argidx - 1] = length; + args_list[argidx - 1] = (char *)malloc(sizeof(char *) * (length + 1)); + Py_MEMCPY(args_list[argidx - 1], args_buffer, length); + args_list[argidx - 1][length] = '\0'; + Py_XDECREF(args_value); + } + + fmt = Bytes_AS_STRING(format); //get pointer of format + fmtcnt = Bytes_GET_SIZE(format); //get length of format + reslen = rescnt = 1; + while (reslen <= fmtcnt) { + reslen *= 2; + rescnt *= 2; + } + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result); + memset(res, 0, sizeof(char) * reslen); + + while (*fmt != '\0') { + if (*fmt != place_holder) { //not place holder, pass it + if (!rescnt) { + rescnt += reslen; + reslen *= 2; + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result) + reslen - rescnt; + } + *(res++) = *(fmt++); + --rescnt; + continue; + } + if (*fmt == '%') { + char c = *(++fmt); + if (c == '\0') { //if there is nothing after '%', raise an error + PyErr_SetString(PyExc_ValueError, "incomplete format"); + goto error; + } + else if (c == '%') { //'%%' will be transfered to '%' + if (!rescnt) { + rescnt += reslen; + reslen *= 2; + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result) + reslen - rescnt; + } + *res = c; + --rescnt; + ++res; + ++fmt; + } + else if (c == 's') { //'%s', replace it with corresponding string + if (args_id >= arglen) { //index is out of bound + PyErr_SetString(PyExc_TypeError, "arguments not enough during string formatting"); + goto error; + } + if (rescnt < args_len[args_id]) { + while (rescnt < args_len[args_id]) { + rescnt += reslen; + reslen *= 2; + } + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result) + reslen - rescnt; + } + Py_MEMCPY(res, args_list[args_id], args_len[args_id]); + rescnt -= args_len[args_id]; + res += args_len[args_id]; + ++args_id; + ++fmt; + } + else { //not support the character currently + PyErr_Format(PyExc_ValueError, "unsupported format character '%c' (0x%x) " + "at index " FORMAT_CODE_PY_SSIZE_T, + c, c, + (Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format))); + goto error; + } + continue; + } + if (*fmt == '$') { + char c = *(++fmt); + if (c == '\0') { //if there is nothing after '$', raise an error + PyErr_SetString(PyExc_ValueError, "incomplete format"); + goto error; + } + else if (c == '$') { //'$$' will be transfered to'$' + if (!rescnt) { //resize buffer + rescnt += reslen; + reslen *= 2; + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result) + reslen - rescnt; + } + *res = c; + --rescnt; + ++res; + ++fmt; + } + else if (isdigit(c)) { //represents '$number' + int index = 0; + index_type = 1; + while (isdigit(*fmt)) { + index = index * 10 + (*fmt) -'0'; + ++fmt; + } + if ((index > arglen) || (index <= 0)) { //invalid index + PyErr_SetString(PyExc_ValueError, "invalid index"); + goto error; + } + --index; + if (rescnt < args_len[index]) { + while (rescnt < args_len[index]) { + rescnt += reslen; + reslen *= 2; + } + if ((result = resize_bytes(result, reslen)) == NULL) goto error; + res = Bytes_AS_STRING(result) + reslen - rescnt; + } + Py_MEMCPY(res, args_list[index], args_len[index]); + rescnt -= args_len[index]; + res += args_len[index]; + } + else { //invalid place holder + PyErr_Format(PyExc_ValueError, "unsupported format character '%c' (0x%x) " + "at index " FORMAT_CODE_PY_SSIZE_T, + c, c, + (Py_ssize_t)(fmt - 1 - Bytes_AS_STRING(format))); + goto error; + } + } + } + if ((args_id < arglen) && (!dict) && (!index_type)) { //not all arguments are used + PyErr_SetString(PyExc_TypeError, "not all arguments converted during string formatting"); goto error; } - if (args_owned) { - Py_DECREF(args); - } - if (!(result = resize_bytes(result, reslen - rescnt))) { - return NULL; + if (args_list != NULL) { + while (--argidx >= 0) free(args_list[argidx]); + free(args_list); + free(args_len); } + if (args_owned) Py_DECREF(args); + if (!(result = resize_bytes(result, reslen - rescnt))) return NULL; //resize return result; - error: - Py_DECREF(result); - if (args_owned) { - Py_DECREF(args); +error: + if (args_list != NULL) { //release all the refcnt + while (--argidx >= 0) free(args_list[argidx]); + free(args_list); + free(args_len); } + Py_DECREF(result); + if (args_owned) Py_DECREF(args); return NULL; -} +} \ No newline at end of file diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index efdeefcc..74ab1e71 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -127,12 +127,13 @@ exit: /* mogrify a query string and build argument array or dict */ RAISES_NEG static int -_mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) +_mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new, char place_holder) { PyObject *key, *value, *n; const char *d, *c; Py_ssize_t index = 0; int force = 0, kind = 0; + int max_index = 0; /* from now on we'll use n and replace its value in *new only at the end, just before returning. we also init *new to NULL to exit with an error @@ -141,164 +142,199 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) c = Bytes_AsString(fmt); while(*c) { - if (*c++ != '%') { - /* a regular character */ - continue; - } - - switch (*c) { - - /* handle plain percent symbol in format string */ - case '%': + while ((*c != '\0') && (*c != place_holder)) ++c; + if (*c == '%') { ++c; force = 1; - break; + if (*c == '(') { + if ((kind == 2) || (kind == 3)) { + Py_XDECREF(n); + psyco_set_error(ProgrammingError, curs, + "argument formats can't be mixed"); + return -1; + } + kind = 1; - /* if we find '%(' then this is a dictionary, we: - 1/ find the matching ')' and extract the key name - 2/ locate the value in the dictionary (or return an error) - 3/ mogrify the value into something useful (quoting)... - 4/ ...and add it to the new dictionary to be used as argument - */ - case '(': - /* check if some crazy guy mixed formats */ - if (kind == 2) { - Py_XDECREF(n); - psyco_set_error(ProgrammingError, curs, - "argument formats can't be mixed"); - return -1; + /* let's have d point the end of the argument */ + for (d = c + 1; *d && *d != ')' && *d != '%'; d++); + + if (*d == ')') { + if (!(key = Text_FromUTF8AndSize(c+1, (Py_ssize_t)(d-c-1)))) { + Py_XDECREF(n); + return -1; + } + + /* if value is NULL we did not find the key (or this is not a + dictionary): let python raise a KeyError */ + if (!(value = PyObject_GetItem(var, key))) { + Py_DECREF(key); /* destroy key */ + Py_XDECREF(n); /* destroy n */ + return -1; + } + /* key has refcnt 1, value the original value + 1 */ + + Dprintf("_mogrify: value refcnt: " + FORMAT_CODE_PY_SSIZE_T " (+1)", Py_REFCNT(value)); + + if (n == NULL) { + if (!(n = PyDict_New())) { + Py_DECREF(key); + Py_DECREF(value); + return -1; + } + } + + if (0 == PyDict_Contains(n, key)) { + PyObject *t = NULL; + + /* None is always converted to NULL; this is an + optimization over the adapting code and can go away in + the future if somebody finds a None adapter useful. */ + if (value == Py_None) { + Py_INCREF(psyco_null); + t = psyco_null; + PyDict_SetItem(n, key, t); + /* t is a new object, refcnt = 1, key is at 2 */ + } + else { + t = microprotocol_getquoted(value, curs->conn); + if (t != NULL) { + PyDict_SetItem(n, key, t); + /* both key and t refcnt +1, key is at 2 now */ + } + else { + /* no adapter found, raise a BIG exception */ + Py_DECREF(key); + Py_DECREF(value); + Py_DECREF(n); + return -1; + } + } + + Py_XDECREF(t); /* t dies here */ + } + Py_DECREF(value); + Py_DECREF(key); /* key has the original refcnt now */ + Dprintf("_mogrify: after value refcnt: " + FORMAT_CODE_PY_SSIZE_T, Py_REFCNT(value)); + } + else { + /* we found %( but not a ) */ + Py_XDECREF(n); + psyco_set_error(ProgrammingError, curs, + "incomplete placeholder: '%(' without ')'"); + return -1; + } + c = d + 1; } - kind = 1; + else if (*c == 's') { + /* this is a format that expects a tuple; it is much easier, + because we don't need to check the old/new dictionary for + keys */ - /* let's have d point the end of the argument */ - for (d = c + 1; *d && *d != ')' && *d != '%'; d++); + /* check if some crazy guy mixed formats */ + if ((kind == 1) || (kind == 3)) { + Py_XDECREF(n); + psyco_set_error(ProgrammingError, curs, + "argument formats can't be mixed"); + return -1; + } + kind = 2; - if (*d == ')') { - if (!(key = Text_FromUTF8AndSize(c+1, (Py_ssize_t)(d-c-1)))) { + value = PySequence_GetItem(var, index); + /* value has refcnt inc'ed by 1 here */ + + /* if value is NULL this is not a sequence or the index is wrong; + anyway we let python set its own exception */ + if (value == NULL) { Py_XDECREF(n); return -1; } - /* if value is NULL we did not find the key (or this is not a - dictionary): let python raise a KeyError */ - if (!(value = PyObject_GetItem(var, key))) { - Py_DECREF(key); /* destroy key */ - Py_XDECREF(n); /* destroy n */ - return -1; - } - /* key has refcnt 1, value the original value + 1 */ - - Dprintf("_mogrify: value refcnt: " - FORMAT_CODE_PY_SSIZE_T " (+1)", Py_REFCNT(value)); - if (n == NULL) { - if (!(n = PyDict_New())) { - Py_DECREF(key); + if (!(n = PyTuple_New(PyObject_Length(var)))) { Py_DECREF(value); return -1; } } - if (0 == PyDict_Contains(n, key)) { - PyObject *t = NULL; - - /* None is always converted to NULL; this is an - optimization over the adapting code and can go away in - the future if somebody finds a None adapter useful. */ - if (value == Py_None) { - Py_INCREF(psyco_null); - t = psyco_null; - PyDict_SetItem(n, key, t); - /* t is a new object, refcnt = 1, key is at 2 */ - } - else { - t = microprotocol_getquoted(value, curs->conn); - if (t != NULL) { - PyDict_SetItem(n, key, t); - /* both key and t refcnt +1, key is at 2 now */ - } - else { - /* no adapter found, raise a BIG exception */ - Py_DECREF(key); - Py_DECREF(value); - Py_DECREF(n); - return -1; - } - } - - Py_XDECREF(t); /* t dies here */ - } - Py_DECREF(value); - Py_DECREF(key); /* key has the original refcnt now */ - Dprintf("_mogrify: after value refcnt: " - FORMAT_CODE_PY_SSIZE_T, Py_REFCNT(value)); - } - else { - /* we found %( but not a ) */ - Py_XDECREF(n); - psyco_set_error(ProgrammingError, curs, - "incomplete placeholder: '%(' without ')'"); - return -1; - } - c = d + 1; /* after the ) */ - break; - - default: - /* this is a format that expects a tuple; it is much easier, - because we don't need to check the old/new dictionary for - keys */ - - /* check if some crazy guy mixed formats */ - if (kind == 1) { - Py_XDECREF(n); - psyco_set_error(ProgrammingError, curs, - "argument formats can't be mixed"); - return -1; - } - kind = 2; - - value = PySequence_GetItem(var, index); - /* value has refcnt inc'ed by 1 here */ - - /* if value is NULL this is not a sequence or the index is wrong; - anyway we let python set its own exception */ - if (value == NULL) { - Py_XDECREF(n); - return -1; - } - - if (n == NULL) { - if (!(n = PyTuple_New(PyObject_Length(var)))) { - Py_DECREF(value); - return -1; - } - } - - /* let's have d point just after the '%' */ - if (value == Py_None) { - Py_INCREF(psyco_null); - PyTuple_SET_ITEM(n, index, psyco_null); - Py_DECREF(value); - } - else { - PyObject *t = microprotocol_getquoted(value, curs->conn); - - if (t != NULL) { - PyTuple_SET_ITEM(n, index, t); + /* let's have d point just after the '%' */ + if (value == Py_None) { + Py_INCREF(psyco_null); + PyTuple_SET_ITEM(n, index, psyco_null); Py_DECREF(value); } else { - Py_DECREF(n); - Py_DECREF(value); + PyObject *t = microprotocol_getquoted(value, curs->conn); + + if (t != NULL) { + PyTuple_SET_ITEM(n, index, t); + Py_DECREF(value); + } + else { + Py_DECREF(n); + Py_DECREF(value); + return -1; + } + } + index += 1; + } + } + + else if (*c == '$') { //new place holder $ + int tmp_index = 0; + if ((kind == 1) || (kind == 2)) { + Py_XDECREF(n); + psyco_set_error(ProgrammingError, curs, + "argument formats can't be mixed"); + return -1; + } + kind = 3; //kind = 3 means using + + ++c; + while (isdigit(*c)) { //calculate index + tmp_index = tmp_index * 10 + (*c) -'0'; + ++c; + } + --tmp_index; + + for (; max_index <= tmp_index; ++max_index) { + //to avoid index not cover all arguments, which may cause double free in bytes_format + int id = max_index; + value = PySequence_GetItem(var, id); + if (value == NULL) { + Py_XDECREF(n); return -1; } + if (n == NULL) { + if (!(n = PyTuple_New(PyObject_Length(var)))) { + Py_DECREF(value); + return -1; + } + } + if (value == Py_None) { + Py_INCREF(psyco_null); + PyTuple_SET_ITEM(n, id, psyco_null); + Py_DECREF(value); + } + else { + PyObject *t = microprotocol_getquoted(value, curs->conn); + if (t != NULL) { + PyTuple_SET_ITEM(n, id, t); + Py_DECREF(value); + } + else { + Py_DECREF(n); + Py_DECREF(value); + return -1; + } + } } - index += 1; + } } - if (force && n == NULL) - n = PyTuple_New(0); + if (force && n == NULL) n = PyTuple_New(0); *new = n; return 0; @@ -314,7 +350,7 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) */ static PyObject * _psyco_curs_merge_query_args(cursorObject *self, - PyObject *query, PyObject *args) + PyObject *query, PyObject *args, char place_holder) { PyObject *fquery; @@ -329,7 +365,7 @@ _psyco_curs_merge_query_args(cursorObject *self, the current exception (we will later restore it if the type or the strings do not match.) */ - if (!(fquery = Bytes_Format(query, args))) { + if (!(fquery = Bytes_Format(query, args, place_holder))) { PyObject *err, *arg, *trace; int pe = 0; @@ -376,7 +412,7 @@ _psyco_curs_merge_query_args(cursorObject *self, RAISES_NEG static int _psyco_curs_execute(cursorObject *self, PyObject *query, PyObject *vars, - long int async, int no_result) + char place_holder, long int async, int no_result) { int res = -1; int tmp; @@ -396,12 +432,12 @@ _psyco_curs_execute(cursorObject *self, the right thing (i.e., what the user expects) */ if (vars && vars != Py_None) { - if (0 > _mogrify(vars, query, self, &cvt)) { goto exit; } + if (0 > _mogrify(vars, query, self, &cvt, place_holder)) { goto exit; } } /* Merge the query to the arguments if needed */ if (cvt) { - if (!(fquery = _psyco_curs_merge_query_args(self, query, cvt))) { + if (!(fquery = _psyco_curs_merge_query_args(self, query, cvt, place_holder))) { goto exit; } } @@ -461,15 +497,28 @@ exit: static PyObject * curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs) { - PyObject *vars = NULL, *operation = NULL; + PyObject *vars = NULL, *operation = NULL, *Place_holder = NULL; + char place_holder = '%'; //default value: '%' - static char *kwlist[] = {"query", "vars", NULL}; + static char *kwlist[] = {"query", "vars", "place_holder", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, - &operation, &vars)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist, + &operation, &vars, &Place_holder)) { return NULL; } + if (Place_holder != NULL) { //if exists place holder argument, it will be checked and parse + if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) { + psyco_set_error(ProgrammingError, self, "can't parse place holder"); + return NULL; + } + if (Bytes_GET_SIZE(Place_holder) != 1) { + psyco_set_error(ProgrammingError, self, "place holder must be a character"); + return NULL; + } + place_holder = Bytes_AS_STRING(Place_holder)[0]; + } + if (self->name != NULL) { if (self->query) { psyco_set_error(ProgrammingError, self, @@ -488,7 +537,7 @@ curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs) EXC_IF_ASYNC_IN_PROGRESS(self, execute); EXC_IF_TPC_PREPARED(self->conn, execute); - if (0 > _psyco_curs_execute(self, operation, vars, self->conn->async, 0)) { + if (0 > _psyco_curs_execute(self, operation, vars, place_holder, self->conn->async, 0)) { return NULL; } @@ -502,20 +551,33 @@ curs_execute(cursorObject *self, PyObject *args, PyObject *kwargs) static PyObject * curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) { - PyObject *operation = NULL, *vars = NULL; + PyObject *operation = NULL, *vars = NULL, *Place_holder = NULL; PyObject *v, *iter = NULL; + char place_holder = '%'; long rowcount = 0; - static char *kwlist[] = {"query", "vars_list", NULL}; + static char *kwlist[] = {"query", "vars_list", "plae_holder", NULL}; /* reset rowcount to -1 to avoid setting it when an exception is raised */ self->rowcount = -1; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", kwlist, - &operation, &vars)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist, + &operation, &vars, &Place_holder)) { return NULL; } + if (Place_holder != NULL) { + if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) { + psyco_set_error(ProgrammingError, self, "can't parse place holder"); + return NULL; + } + if (Bytes_GET_SIZE(Place_holder) != 1) { + psyco_set_error(ProgrammingError, self, "place holder must be a character"); + return NULL; + } + place_holder = Bytes_AS_STRING(Place_holder)[0]; + } + EXC_IF_CURS_CLOSED(self); EXC_IF_CURS_ASYNC(self, executemany); EXC_IF_TPC_PREPARED(self->conn, executemany); @@ -532,7 +594,7 @@ curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) } while ((v = PyIter_Next(vars)) != NULL) { - if (0 > _psyco_curs_execute(self, operation, v, 0, 1)) { + if (0 > _psyco_curs_execute(self, operation, v, place_holder, 0, 1)) { Py_DECREF(v); Py_XDECREF(iter); return NULL; @@ -562,7 +624,7 @@ curs_executemany(cursorObject *self, PyObject *args, PyObject *kwargs) static PyObject * _psyco_curs_mogrify(cursorObject *self, - PyObject *operation, PyObject *vars) + PyObject *operation, PyObject *vars, char place_holder) { PyObject *fquery = NULL, *cvt = NULL; @@ -577,13 +639,13 @@ _psyco_curs_mogrify(cursorObject *self, if (vars && vars != Py_None) { - if (0 > _mogrify(vars, operation, self, &cvt)) { + if (0 > _mogrify(vars, operation, self, &cvt, place_holder)) { goto cleanup; } } if (vars && cvt) { - if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt))) { + if (!(fquery = _psyco_curs_merge_query_args(self, operation, cvt, place_holder))) { goto cleanup; } @@ -606,16 +668,29 @@ cleanup: static PyObject * curs_mogrify(cursorObject *self, PyObject *args, PyObject *kwargs) { - PyObject *vars = NULL, *operation = NULL; + PyObject *vars = NULL, *operation = NULL, *Place_holder = NULL; + char place_holder = '%'; - static char *kwlist[] = {"query", "vars", NULL}; + static char *kwlist[] = {"query", "vars", "place_holder", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", kwlist, - &operation, &vars)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OO", kwlist, + &operation, &vars, &Place_holder)) { return NULL; } - return _psyco_curs_mogrify(self, operation, vars); + if (Place_holder != NULL) { + if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) { + psyco_set_error(ProgrammingError, self, "can't parse place holder"); + return NULL; + } + if (Bytes_GET_SIZE(Place_holder) != 1) { + psyco_set_error(ProgrammingError, self, "place holder must be a character"); + return NULL; + } + place_holder = Bytes_AS_STRING(Place_holder)[0]; + } + + return _psyco_curs_mogrify(self, operation, vars, place_holder); } @@ -1016,11 +1091,25 @@ curs_callproc(cursorObject *self, PyObject *args) PyObject *pvals = NULL; char *cpname = NULL; char **scpnames = NULL; + PyObject *Place_holder = NULL; + char place_holder = '%'; - if (!PyArg_ParseTuple(args, "s#|O", &procname, &procname_len, - ¶meters)) { + if (!PyArg_ParseTuple(args, "s#|OO", &procname, &procname_len, + ¶meters, &Place_holder)) { goto exit; } + + if (Place_holder != NULL) { + if (!(Place_holder = curs_validate_sql_basic(self, Place_holder))) { + psyco_set_error(ProgrammingError, self, "can't parse place holder"); + return NULL; + } + if (Bytes_GET_SIZE(Place_holder) != 1) { + psyco_set_error(ProgrammingError, self, "place holder must be a character"); + return NULL; + } + place_holder = Bytes_AS_STRING(Place_holder)[0]; + } EXC_IF_CURS_CLOSED(self); EXC_IF_ASYNC_IN_PROGRESS(self, callproc); @@ -1114,7 +1203,7 @@ curs_callproc(cursorObject *self, PyObject *args) } if (0 <= _psyco_curs_execute( - self, operation, pvals, self->conn->async, 0)) { + self, operation, pvals, place_holder, self->conn->async, 0)) { /* The dict case is outside DBAPI scope anyway, so simply return None */ if (using_dict) { res = Py_None; @@ -2123,4 +2212,4 @@ PyTypeObject cursorType = { cursor_init, /*tp_init*/ 0, /*tp_alloc*/ cursor_new, /*tp_new*/ -}; +}; \ No newline at end of file diff --git a/psycopg/utils.h b/psycopg/utils.h index 5223d3a5..6a7b0585 100644 --- a/psycopg/utils.h +++ b/psycopg/utils.h @@ -59,7 +59,7 @@ HIDDEN RAISES BORROWED PyObject *psyco_set_error( HIDDEN PyObject *psyco_get_decimal_type(void); -HIDDEN PyObject *Bytes_Format(PyObject *format, PyObject *args); +HIDDEN PyObject *Bytes_Format(PyObject *format, PyObject *args, char place_holder); -#endif /* !defined(UTILS_H) */ +#endif /* !defined(UTILS_H) */ \ No newline at end of file