From ed623776f366bf6cbee9dffad8e959356ce26faf Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 27 Sep 2010 00:49:31 +0100 Subject: [PATCH] Hstore can return unicode keys and values. --- lib/extras.py | 31 ++++++++++++++++++++++++++----- tests/types_extras.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/lib/extras.py b/lib/extras.py index d9872c53..e7fbfb4e 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -27,6 +27,7 @@ and classes untill a better place in the distribution is found. import os import time +import codecs import warnings import re as regex @@ -534,7 +535,10 @@ class HstoreAdapter(object): (?:\s*,\s*|$) # pairs separated by comma or end of string. """, regex.VERBOSE) - def parse(self, s, cur): + # backslash decoder + _bsdec = codecs.getdecoder("string_escape") + + def parse(self, s, cur, _decoder=_bsdec): """Parse an hstore representation in a Python string. The hstore is represented as something like:: @@ -552,10 +556,10 @@ class HstoreAdapter(object): if m is None or m.start() != start: raise psycopg2.InterfaceError( "error parsing hstore pair at char %d" % start) - k = m.group(1).decode("string_escape") + k = _decoder(m.group(1))[0] v = m.group(2) if v is not None: - v = v.decode("string_escape") + v = _decoder(v)[0] rv[k] = v start = m.end() @@ -568,7 +572,16 @@ class HstoreAdapter(object): parse = classmethod(parse) -def register_hstore(conn_or_curs): + def parse_unicode(self, s, cur): + """Parse an hstore returning unicode keys and values.""" + codec = codecs.getdecoder(_ext.encodings[cur.connection.encoding]) + bsdec = self._bsdec + decoder = lambda s: codec(bsdec(s)[0]) + return self.parse(s, cur, _decoder=decoder) + + parse_unicode = classmethod(parse_unicode) + +def register_hstore(conn_or_curs, unicode=False): """Register adapter/typecaster for dict/hstore reading/writing. The adapter must be registered on a connection or cursor as the hstore @@ -576,6 +589,9 @@ def register_hstore(conn_or_curs): Raise `~psycopg2.ProgrammingError` if hstore is not installed in the target database. + + By default the returned dicts have string keys and values: use + *unicode*=True to return `unicode` objects instead. """ if hasattr(conn_or_curs, 'execute'): conn = conn_or_curs.connection @@ -607,7 +623,12 @@ WHERE typname = 'hstore' and nspname = 'public'; "please install it from your 'contrib/hstore.sql' file") # create and register the typecaster - HSTORE = _ext.new_type((oids[0],), "HSTORE", HstoreAdapter.parse) + if unicode: + cast = HstoreAdapter.parse_unicode + else: + cast = HstoreAdapter.parse + + HSTORE = _ext.new_type((oids[0],), "HSTORE", cast) _ext.register_type(HSTORE, conn_or_curs) _ext.register_adapter(dict, HstoreAdapter) diff --git a/tests/types_extras.py b/tests/types_extras.py index 4171fd75..4c0a9dad 100644 --- a/tests/types_extras.py +++ b/tests/types_extras.py @@ -209,6 +209,19 @@ class HstoreTestCase(unittest.TestCase): self.assertEqual(t[1], {}) self.assertEqual(t[2], {'a': 'b'}) + def test_register_unicode(self): + from psycopg2.extras import register_hstore + + register_hstore(self.conn, unicode=True) + cur = self.conn.cursor() + cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore") + t = cur.fetchone() + self.assert_(t[0] is None) + self.assertEqual(t[1], {}) + self.assertEqual(t[2], {u'a': u'b'}) + self.assert_(isinstance(t[2].keys()[0], unicode)) + self.assert_(isinstance(t[2].values()[0], unicode)) + def test_roundtrip(self): from psycopg2.extras import register_hstore register_hstore(self.conn) @@ -234,6 +247,29 @@ class HstoreTestCase(unittest.TestCase): ok({''.join(ab): ''.join(ab)}) ok(dict(zip(ab, ab))) + def test_roundtrip_unicode(self): + from psycopg2.extras import register_hstore + register_hstore(self.conn, unicode=True) + cur = self.conn.cursor() + + def ok(d): + cur.execute("select %s", (d,)) + d1 = cur.fetchone()[0] + self.assertEqual(len(d), len(d1)) + for k, v in d1.iteritems(): + self.assert_(k in d, k) + self.assertEqual(d[k], v) + self.assert_(isinstance(k, unicode)) + self.assert_(v is None or isinstance(v, unicode)) + + ok({}) + ok({'a': 'b', 'c': None, 'd': u'\u20ac', u'\u2603': 'e'}) + + ab = map(unichr, range(1, 1024)) + ok({u''.join(ab): u''.join(ab)}) + ok(dict(zip(ab, ab))) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)