Rework psycopg2.connect() interface.

This commit is contained in:
Oleksandr Shulgin 2015-10-27 12:54:10 +01:00
parent fe4cb0d493
commit 7aba8b3ed0
7 changed files with 303 additions and 63 deletions

View File

@ -24,6 +24,28 @@ functionalities defined by the |DBAPI|_.
>>> psycopg2.extensions.parse_dsn('dbname=test user=postgres password=secret')
{'password': 'secret', 'user': 'postgres', 'dbname': 'test'}
.. function:: make_dsn(**kwargs)
Wrap keyword parameters into a connection string, applying necessary
quoting and escaping any special characters (namely, single quote and
backslash).
Example (note the order of parameters in the resulting string is
arbitrary)::
>>> psycopg2.extensions.make_dsn(dbname='test', user='postgres', password='secret')
'user=postgres dbname=test password=secret'
As a special case, the *database* keyword is translated to *dbname*::
>>> psycopg2.extensions.make_dsn(database='test')
'dbname=test'
An example of quoting (using `print()` for clarity)::
>>> print(psycopg2.extensions.make_dsn(database='test', password="some\\thing ''special"))
password='some\\thing \'\'special' dbname=test
.. class:: connection(dsn, async=False)
Is the class usually returned by the `~psycopg2.connect()` function.

View File

@ -56,7 +56,7 @@ from psycopg2._psycopg import Error, Warning, DataError, DatabaseError, Programm
from psycopg2._psycopg import IntegrityError, InterfaceError, InternalError
from psycopg2._psycopg import NotSupportedError, OperationalError
from psycopg2._psycopg import _connect, apilevel, threadsafety, paramstyle
from psycopg2._psycopg import _connect, parse_args, apilevel, threadsafety, paramstyle
from psycopg2._psycopg import __version__, __libpq_version__
from psycopg2 import tz
@ -80,27 +80,8 @@ else:
_ext.register_adapter(Decimal, Adapter)
del Decimal, Adapter
import re
def _param_escape(s,
re_escape=re.compile(r"([\\'])"),
re_space=re.compile(r'\s')):
"""
Apply the escaping rule required by PQconnectdb
"""
if not s: return "''"
s = re_escape.sub(r'\\\1', s)
if re_space.search(s):
s = "'" + s + "'"
return s
del re
def connect(dsn=None,
database=None, user=None, password=None, host=None, port=None,
connection_factory=None, cursor_factory=None, async=False, **kwargs):
"""
Create a new database connection.
@ -135,33 +116,7 @@ def connect(dsn=None,
library: the list of supported parameters depends on the library version.
"""
items = []
if database is not None:
items.append(('dbname', database))
if user is not None:
items.append(('user', user))
if password is not None:
items.append(('password', password))
if host is not None:
items.append(('host', host))
if port is not None:
items.append(('port', port))
items.extend([(k, v) for (k, v) in kwargs.iteritems() if v is not None])
if dsn is not None and items:
raise TypeError(
"'%s' is an invalid keyword argument when the dsn is specified"
% items[0][0])
if dsn is None:
if not items:
raise TypeError('missing dsn and no parameters')
else:
dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
for (k, v) in items])
conn = _connect(dsn, connection_factory=connection_factory, async=async)
conn = _connect(dsn, connection_factory, async, **kwargs)
if cursor_factory is not None:
conn.cursor_factory = cursor_factory

View File

@ -56,7 +56,8 @@ try:
except ImportError:
pass
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version, parse_dsn, quote_ident
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version
from psycopg2._psycopg import parse_dsn, make_dsn, quote_ident
from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type
from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column

View File

@ -119,11 +119,17 @@ typedef struct cursorObject cursorObject;
typedef struct connectionObject connectionObject;
/* some utility functions */
HIDDEN PyObject *psyco_parse_args(PyObject *self, PyObject *args, PyObject *kwargs);
HIDDEN PyObject *psyco_parse_dsn(PyObject *self, PyObject *args, PyObject *kwargs);
HIDDEN PyObject *psyco_make_dsn(PyObject *self, PyObject *args, PyObject *kwargs);
RAISES HIDDEN PyObject *psyco_set_error(PyObject *exc, cursorObject *curs, const char *msg);
HIDDEN char *psycopg_escape_string(connectionObject *conn,
const char *from, Py_ssize_t len, char *to, Py_ssize_t *tolen);
HIDDEN char *psycopg_escape_identifier_easy(const char *from, Py_ssize_t len);
HIDDEN char *psycopg_escape_conninfo(const char *from, Py_ssize_t len);
HIDDEN int psycopg_strdup(char **to, const char *from, Py_ssize_t len);
HIDDEN int psycopg_is_text_file(PyObject *f);

View File

@ -70,24 +70,104 @@ HIDDEN PyObject *psyco_null = NULL;
/* The type of the cursor.description items */
HIDDEN PyObject *psyco_DescriptionType = NULL;
/* finds a keyword or positional arg (pops it from kwargs if found there) */
static PyObject *
parse_arg(int pos, char *name, PyObject *defval, PyObject *args, PyObject *kwargs)
{
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
PyObject *val = NULL;
if (kwargs && PyMapping_HasKeyString(kwargs, name)) {
val = PyMapping_GetItemString(kwargs, name);
Py_XINCREF(val);
PyMapping_DelItemString(kwargs, name); /* pop from the kwargs dict! */
}
if (nargs > pos) {
if (!val) {
val = PyTuple_GET_ITEM(args, pos);
Py_XINCREF(val);
} else {
PyErr_Format(PyExc_TypeError,
"parse_args() got multiple values for keyword argument '%s'", name);
return NULL;
}
}
if (!val) {
val = defval;
Py_XINCREF(val);
}
return val;
}
#define psyco_parse_args_doc \
"parse_args(...) -- parse connection parameters.\n\n" \
"Return a tuple of (dsn, connection_factory, async)"
PyObject *
psyco_parse_args(PyObject *self, PyObject *args, PyObject *kwargs)
{
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
PyObject *dsn = NULL;
PyObject *factory = NULL;
PyObject *async = NULL;
PyObject *res = NULL;
if (nargs > 3) {
PyErr_Format(PyExc_TypeError,
"parse_args() takes at most 3 arguments (%d given)", (int)nargs);
goto exit;
}
/* parse and remove all keywords we know, so they are not interpreted as part of DSN */
if (!(dsn = parse_arg(0, "dsn", Py_None, args, kwargs))) { goto exit; }
if (!(factory = parse_arg(1, "connection_factory", Py_None,
args, kwargs))) { goto exit; }
if (!(async = parse_arg(2, "async", Py_False, args, kwargs))) { goto exit; }
if (kwargs && PyMapping_Size(kwargs) > 0) {
if (dsn == Py_None) {
Py_DECREF(dsn);
if (!(dsn = psyco_make_dsn(NULL, NULL, kwargs))) { goto exit; }
} else {
PyErr_SetString(PyExc_TypeError, "both dsn and parameters given");
goto exit;
}
} else {
if (dsn == Py_None) {
PyErr_SetString(PyExc_TypeError, "missing dsn and no parameters");
goto exit;
}
}
res = PyTuple_Pack(3, dsn, factory, async);
exit:
Py_XDECREF(dsn);
Py_XDECREF(factory);
Py_XDECREF(async);
return res;
}
/** connect module-level function **/
#define psyco_connect_doc \
"_connect(dsn, [connection_factory], [async]) -- New database connection.\n\n"
"_connect(dsn, [connection_factory], [async], **kwargs) -- New database connection.\n\n"
static PyObject *
psyco_connect(PyObject *self, PyObject *args, PyObject *keywds)
{
PyObject *conn = NULL;
PyObject *tuple = NULL;
PyObject *factory = NULL;
const char *dsn = NULL;
int async = 0;
static char *kwlist[] = {"dsn", "connection_factory", "async", NULL};
if (!(tuple = psyco_parse_args(self, args, keywds))) { goto exit; }
if (!PyArg_ParseTupleAndKeywords(args, keywds, "s|Oi", kwlist,
&dsn, &factory, &async)) {
return NULL;
}
if (!PyArg_ParseTuple(tuple, "s|Oi", &dsn, &factory, &async)) { goto exit; }
Dprintf("psyco_connect: dsn = '%s', async = %d", dsn, async);
@ -109,12 +189,16 @@ psyco_connect(PyObject *self, PyObject *args, PyObject *keywds)
conn = PyObject_CallFunction(factory, "si", dsn, async);
}
exit:
Py_XDECREF(tuple);
return conn;
}
#define psyco_parse_dsn_doc "parse_dsn(dsn) -> dict"
static PyObject *
PyObject *
psyco_parse_dsn(PyObject *self, PyObject *args, PyObject *kwargs)
{
char *err = NULL;
@ -166,6 +250,114 @@ exit:
}
#define psyco_make_dsn_doc "make_dsn(**kwargs) -> str"
PyObject *
psyco_make_dsn(PyObject *self, PyObject *args, PyObject *kwargs)
{
Py_ssize_t len, pos;
PyObject *res = NULL;
PyObject *key = NULL, *value = NULL;
PyObject *newkey, *newval;
PyObject *dict = NULL;
char *str = NULL, *p, *q;
if (args && (len = PyTuple_Size(args)) > 0) {
PyErr_Format(PyExc_TypeError, "make_dsn() takes no arguments (%d given)", (int)len);
goto exit;
}
if (kwargs == NULL) {
return Text_FromUTF8("");
}
/* iterate through kwargs, calculating the total resulting string
length and saving prepared key/values to a temp. dict */
if (!(dict = PyDict_New())) { goto exit; }
len = 0;
pos = 0;
while (PyDict_Next(kwargs, &pos, &key, &value)) {
if (value == NULL || value == Py_None) { continue; }
Py_INCREF(key); /* for ensure_bytes */
if (!(newkey = psycopg_ensure_bytes(key))) { goto exit; }
/* special handling of 'database' keyword */
if (strcmp(Bytes_AsString(newkey), "database") == 0) {
key = Bytes_FromString("dbname");
Py_DECREF(newkey);
} else {
key = newkey;
}
/* now transform the value */
if (Bytes_CheckExact(value)) {
Py_INCREF(value);
} else if (PyUnicode_CheckExact(value)) {
if (!(value = PyUnicode_AsUTF8String(value))) { goto exit; }
} else {
/* this could be port=5432, so we need to get the text representation */
if (!(value = PyObject_Str(value))) { goto exit; }
/* and still ensure it's bytes() (but no need to incref here) */
if (!(value = psycopg_ensure_bytes(value))) { goto exit; }
}
/* passing NULL for plen checks for NIL bytes in content and errors out */
if (Bytes_AsStringAndSize(value, &str, NULL) < 0) { goto exit; }
/* escape any special chars */
if (!(str = psycopg_escape_conninfo(str, 0))) { goto exit; }
if (!(newval = Bytes_FromString(str))) {
goto exit;
}
PyMem_Free(str);
str = NULL;
Py_DECREF(value);
value = newval;
/* finally put into the temp. dict */
if (PyDict_SetItem(dict, key, value) < 0) { goto exit; }
len += Bytes_GET_SIZE(key) + Bytes_GET_SIZE(value) + 2; /* =, space or NIL */
Py_DECREF(key);
Py_DECREF(value);
}
key = NULL;
value = NULL;
if (!(str = PyMem_Malloc(len))) {
PyErr_NoMemory();
goto exit;
}
p = str;
pos = 0;
while (PyDict_Next(dict, &pos, &newkey, &newval)) {
if (p != str) {
*(p++) = ' ';
}
if (Bytes_AsStringAndSize(newkey, &q, &len) < 0) { goto exit; }
strncpy(p, q, len);
p += len;
*(p++) = '=';
if (Bytes_AsStringAndSize(newval, &q, &len) < 0) { goto exit; }
strncpy(p, q, len);
p += len;
}
*p = '\0';
res = Text_FromUTF8AndSize(str, p - str);
exit:
PyMem_Free(str);
Py_XDECREF(key);
Py_XDECREF(value);
Py_XDECREF(dict);
return res;
}
#define psyco_quote_ident_doc \
"quote_ident(str, conn_or_curs) -> str -- wrapper around PQescapeIdentifier\n\n" \
":Parameters:\n" \
@ -820,8 +1012,12 @@ error:
static PyMethodDef psycopgMethods[] = {
{"_connect", (PyCFunction)psyco_connect,
METH_VARARGS|METH_KEYWORDS, psyco_connect_doc},
{"parse_args", (PyCFunction)psyco_parse_args,
METH_VARARGS|METH_KEYWORDS, psyco_parse_args_doc},
{"parse_dsn", (PyCFunction)psyco_parse_dsn,
METH_VARARGS|METH_KEYWORDS, psyco_parse_dsn_doc},
{"make_dsn", (PyCFunction)psyco_make_dsn,
METH_VARARGS|METH_KEYWORDS, psyco_make_dsn_doc},
{"quote_ident", (PyCFunction)psyco_quote_ident,
METH_VARARGS|METH_KEYWORDS, psyco_quote_ident_doc},
{"adapt", (PyCFunction)psyco_microprotocols_adapt,

View File

@ -124,6 +124,50 @@ psycopg_escape_identifier_easy(const char *from, Py_ssize_t len)
return rv;
}
char *
psycopg_escape_conninfo(const char *from, Py_ssize_t len)
{
char *rv = NULL;
const char *src;
const char *end;
char *dst;
int space = 0;
if (!len) { len = strlen(from); }
end = from + len;
if (!(rv = PyMem_Malloc(3 + 2 * len))) {
PyErr_NoMemory();
return NULL;
}
/* check for any whitespace or empty string */
if (from < end && *from) {
for (src = from; src < end && *src; ++src) {
if (isspace(*src)) {
space = 1;
break;
}
}
} else {
/* empty string: we should produce '' */
space = 1;
}
dst = rv;
if (space) { *(dst++) = '\''; }
/* scan and copy */
for (src = from; src < end && *src; ++src, ++dst) {
if (*src == '\'' || *src == '\\')
*(dst++) = '\\';
*dst = *src;
}
if (space) { *(dst++) = '\''; }
*dst = '\0';
return rv;
}
/* Duplicate a string.
*
* Allocate a new buffer on the Python heap containing the new string.

View File

@ -34,11 +34,11 @@ import psycopg2
class ConnectTestCase(unittest.TestCase):
def setUp(self):
self.args = None
def conect_stub(dsn, connection_factory=None, async=False):
self.args = (dsn, connection_factory, async)
def connect_stub(*args, **kwargs):
self.args = psycopg2.parse_args(*args, **kwargs)
self._connect_orig = psycopg2._connect
psycopg2._connect = conect_stub
psycopg2._connect = connect_stub
def tearDown(self):
psycopg2._connect = self._connect_orig
@ -91,29 +91,45 @@ class ConnectTestCase(unittest.TestCase):
pass
psycopg2.connect(database='foo', bar='baz', connection_factory=f)
self.assertEqual(self.args[0], 'dbname=foo bar=baz')
dsn = " %s " % self.args[0]
self.assertIn(" dbname=foo ", dsn)
self.assertIn(" bar=baz ", dsn)
self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False)
psycopg2.connect("dbname=foo bar=baz", connection_factory=f)
self.assertEqual(self.args[0], 'dbname=foo bar=baz')
dsn = " %s " % self.args[0]
self.assertIn(" dbname=foo ", dsn)
self.assertIn(" bar=baz ", dsn)
self.assertEqual(self.args[1], f)
self.assertEqual(self.args[2], False)
def test_async(self):
psycopg2.connect(database='foo', bar='baz', async=1)
self.assertEqual(self.args[0], 'dbname=foo bar=baz')
dsn = " %s " % self.args[0]
self.assertIn(" dbname=foo ", dsn)
self.assertIn(" bar=baz ", dsn)
self.assertEqual(self.args[1], None)
self.assert_(self.args[2])
psycopg2.connect("dbname=foo bar=baz", async=True)
self.assertEqual(self.args[0], 'dbname=foo bar=baz')
dsn = " %s " % self.args[0]
self.assertIn(" dbname=foo ", dsn)
self.assertIn(" bar=baz ", dsn)
self.assertEqual(self.args[1], None)
self.assert_(self.args[2])
def test_int_port_param(self):
psycopg2.connect(database='sony', port=6543)
dsn = " %s " % self.args[0]
self.assertIn(" dbname=sony ", dsn)
self.assertIn(" port=6543 ", dsn)
def test_empty_param(self):
psycopg2.connect(database='sony', password='')
self.assertEqual(self.args[0], "dbname=sony password=''")
dsn = " %s " % self.args[0]
self.assertIn(" dbname=sony ", dsn)
self.assertIn(" password='' ", dsn)
def test_escape(self):
psycopg2.connect(database='hello world')