Hstore adapter compatible with Python 3.

This commit is contained in:
Daniele Varrazzo 2010-12-29 03:45:24 +01:00
parent 89c492d3a4
commit c176de4bf8
2 changed files with 40 additions and 40 deletions

View File

@ -42,6 +42,7 @@ from psycopg2 import extensions as _ext
from psycopg2.extensions import cursor as _cursor from psycopg2.extensions import cursor as _cursor
from psycopg2.extensions import connection as _connection from psycopg2.extensions import connection as _connection
from psycopg2.extensions import adapt as _A from psycopg2.extensions import adapt as _A
from psycopg2.extensions import b
class DictCursorBase(_cursor): class DictCursorBase(_cursor):
@ -574,7 +575,7 @@ class HstoreAdapter(object):
def _getquoted_8(self): def _getquoted_8(self):
"""Use the operators available in PG pre-9.0.""" """Use the operators available in PG pre-9.0."""
if not self.wrapped: if not self.wrapped:
return "''::hstore" return b("''::hstore")
adapt = _ext.adapt adapt = _ext.adapt
rv = [] rv = []
@ -588,22 +589,23 @@ class HstoreAdapter(object):
v.prepare(self.conn) v.prepare(self.conn)
v = v.getquoted() v = v.getquoted()
else: else:
v = 'NULL' v = b('NULL')
rv.append("(%s => %s)" % (k, v)) # XXX this b'ing is painfully inefficient!
rv.append(b("(") + k + b(" => ") + v + b(")"))
return "(" + '||'.join(rv) + ")" return b("(") + b('||').join(rv) + b(")")
def _getquoted_9(self): def _getquoted_9(self):
"""Use the hstore(text[], text[]) function.""" """Use the hstore(text[], text[]) function."""
if not self.wrapped: if not self.wrapped:
return "''::hstore" return b("''::hstore")
k = _ext.adapt(self.wrapped.keys()) k = _ext.adapt(self.wrapped.keys())
k.prepare(self.conn) k.prepare(self.conn)
v = _ext.adapt(self.wrapped.values()) v = _ext.adapt(self.wrapped.values())
v.prepare(self.conn) v.prepare(self.conn)
return "hstore(%s, %s)" % (k.getquoted(), v.getquoted()) return b("hstore(") + k.getquoted() + b(", ") + v.getquoted() + b(")")
getquoted = _getquoted_9 getquoted = _getquoted_9
@ -620,13 +622,8 @@ class HstoreAdapter(object):
(?:\s*,\s*|$) # pairs separated by comma or end of string. (?:\s*,\s*|$) # pairs separated by comma or end of string.
""", regex.VERBOSE) """, regex.VERBOSE)
# backslash decoder @classmethod
if sys.version_info[0] < 3: def parse(self, s, cur, _bsdec=regex.compile(r"\\(.)")):
_bsdec = codecs.getdecoder("string_escape")
else:
_bsdec = codecs.getdecoder("unicode_escape")
def parse(self, s, cur, _decoder=_bsdec):
"""Parse an hstore representation in a Python string. """Parse an hstore representation in a Python string.
The hstore is represented as something like:: The hstore is represented as something like::
@ -644,10 +641,10 @@ class HstoreAdapter(object):
if m is None or m.start() != start: if m is None or m.start() != start:
raise psycopg2.InterfaceError( raise psycopg2.InterfaceError(
"error parsing hstore pair at char %d" % start) "error parsing hstore pair at char %d" % start)
k = _decoder(m.group(1))[0] k = _bsdec.sub(r'\1', m.group(1))
v = m.group(2) v = m.group(2)
if v is not None: if v is not None:
v = _decoder(v)[0] v = _bsdec.sub(r'\1', v)
rv[k] = v rv[k] = v
start = m.end() start = m.end()
@ -658,16 +655,14 @@ class HstoreAdapter(object):
return rv return rv
parse = classmethod(parse) @classmethod
def parse_unicode(self, s, cur): def parse_unicode(self, s, cur):
"""Parse an hstore returning unicode keys and values.""" """Parse an hstore returning unicode keys and values."""
codec = codecs.getdecoder(_ext.encodings[cur.connection.encoding]) if s is None:
bsdec = self._bsdec return None
decoder = lambda s: codec(bsdec(s)[0])
return self.parse(s, cur, _decoder=decoder)
parse_unicode = classmethod(parse_unicode) s = s.decode(_ext.encodings[cur.connection.encoding])
return self.parse(s, cur)
@classmethod @classmethod
def get_oids(self, conn_or_curs): def get_oids(self, conn_or_curs):
@ -713,11 +708,11 @@ def register_hstore(conn_or_curs, globally=False, unicode=False):
uses a single database you can pass *globally*\=True to have the typecaster uses a single database you can pass *globally*\=True to have the typecaster
registered on all the connections. registered on all the connections.
By default the returned dicts will have `str` objects as keys and values: On Python 2, by default the returned dicts will have `str` objects as keys and values:
use *unicode*\=True to return `unicode` objects instead. When adapting a use *unicode*\=True to return `unicode` objects instead. When adapting a
dictionary both `str` and `unicode` keys and values are handled (the dictionary both `str` and `unicode` keys and values are handled (the
`unicode` values will be converted according to the current `unicode` values will be converted according to the current
`~connection.encoding`). `~connection.encoding`). The option is not available on Python 3.
The |hstore| contrib module must be already installed in the database The |hstore| contrib module must be already installed in the database
(executing the ``hstore.sql`` script in your ``contrib`` directory). (executing the ``hstore.sql`` script in your ``contrib`` directory).
@ -730,7 +725,7 @@ def register_hstore(conn_or_curs, globally=False, unicode=False):
"please install it from your 'contrib/hstore.sql' file") "please install it from your 'contrib/hstore.sql' file")
# create and register the typecaster # create and register the typecaster
if unicode: if sys.version_info[0] < 3 and unicode:
cast = HstoreAdapter.parse_unicode cast = HstoreAdapter.parse_unicode
else: else:
cast = HstoreAdapter.parse cast = HstoreAdapter.parse

View File

@ -24,6 +24,8 @@ from testutils import unittest
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
from psycopg2.extensions import b
from testconfig import dsn from testconfig import dsn
@ -164,18 +166,17 @@ class HstoreTestCase(unittest.TestCase):
a.prepare(self.conn) a.prepare(self.conn)
q = a.getquoted() q = a.getquoted()
self.assert_(q.startswith("(("), q) self.assert_(q.startswith(b("((")), q)
self.assert_(q.endswith("))"), q) ii = q[1:-1].split(b("||"))
ii = q[1:-1].split("||")
ii.sort() ii.sort()
self.assertEqual(len(ii), len(o)) self.assertEqual(len(ii), len(o))
self.assertEqual(ii[0], filter_scs(self.conn, "(E'a' => E'1')")) self.assertEqual(ii[0], filter_scs(self.conn, b("(E'a' => E'1')")))
self.assertEqual(ii[1], filter_scs(self.conn, "(E'b' => E'''')")) self.assertEqual(ii[1], filter_scs(self.conn, b("(E'b' => E'''')")))
self.assertEqual(ii[2], filter_scs(self.conn, "(E'c' => NULL)")) self.assertEqual(ii[2], filter_scs(self.conn, b("(E'c' => NULL)")))
if 'd' in o: if 'd' in o:
encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding]) encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
self.assertEqual(ii[3], filter_scs(self.conn, "(E'd' => E'%s')" % encc)) self.assertEqual(ii[3], filter_scs(self.conn, b("(E'd' => E'") + encc + b("')")))
def test_adapt_9(self): def test_adapt_9(self):
if self.conn.server_version < 90000: if self.conn.server_version < 90000:
@ -191,21 +192,21 @@ class HstoreTestCase(unittest.TestCase):
a.prepare(self.conn) a.prepare(self.conn)
q = a.getquoted() q = a.getquoted()
m = re.match(r'hstore\(ARRAY\[([^\]]+)\], ARRAY\[([^\]]+)\]\)', q) m = re.match(b(r'hstore\(ARRAY\[([^\]]+)\], ARRAY\[([^\]]+)\]\)'), q)
self.assert_(m, repr(q)) self.assert_(m, repr(q))
kk = m.group(1).split(", ") kk = m.group(1).split(b(", "))
vv = m.group(2).split(", ") vv = m.group(2).split(b(", "))
ii = zip(kk, vv) ii = zip(kk, vv)
ii.sort() ii.sort()
self.assertEqual(len(ii), len(o)) self.assertEqual(len(ii), len(o))
self.assertEqual(ii[0], ("E'a'", "E'1'")) self.assertEqual(ii[0], (b("E'a'"), b("E'1'")))
self.assertEqual(ii[1], ("E'b'", "E''''")) self.assertEqual(ii[1], (b("E'b'"), b("E''''")))
self.assertEqual(ii[2], ("E'c'", "NULL")) self.assertEqual(ii[2], (b("E'c'"), b("NULL")))
if 'd' in o: if 'd' in o:
encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding]) encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
self.assertEqual(ii[3], ("E'd'", "E'%s'" % encc)) self.assertEqual(ii[3], (b("E'd'"), b("E'") + encc + b("'")))
def test_parse(self): def test_parse(self):
from psycopg2.extras import HstoreAdapter from psycopg2.extras import HstoreAdapter
@ -321,7 +322,11 @@ class HstoreTestCase(unittest.TestCase):
ok({''.join(ab): ''.join(ab)}) ok({''.join(ab): ''.join(ab)})
self.conn.set_client_encoding('latin1') self.conn.set_client_encoding('latin1')
ab = map(chr, range(1, 256)) if sys.version_info[0] < 3:
ab = map(chr, range(32, 127) + range(160, 255))
else:
ab = bytes(range(32, 127) + range(160, 255)).decode('latin1')
ok({''.join(ab): ''.join(ab)}) ok({''.join(ab): ''.join(ab)})
ok(dict(zip(ab, ab))) ok(dict(zip(ab, ab)))