Make CompositeCaster easier to subclass

This commit is contained in:
Daniele Varrazzo 2012-09-22 01:46:53 +01:00
parent 7de8611607
commit 1b2c2c34b6
2 changed files with 45 additions and 8 deletions

View File

@ -854,8 +854,13 @@ class CompositeCaster(object):
"expecting %d components for the type %s, %d found instead" % "expecting %d components for the type %s, %d found instead" %
(len(self.atttypes), self.name, len(tokens))) (len(self.atttypes), self.name, len(tokens)))
return self._ctor(curs.cast(oid, token) values = [ curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens)) for oid, token in zip(self.atttypes, tokens) ]
return self.make(values)
def make(self, values):
return self._ctor(values)
_re_tokenize = regex.compile(r""" _re_tokenize = regex.compile(r"""
\(? ([,)]) # an empty token, representing NULL \(? ([,)]) # an empty token, representing NULL
@ -937,10 +942,10 @@ ORDER BY attnum;
array_oid = recs[0][1] array_oid = recs[0][1]
type_attrs = [ (r[2], r[3]) for r in recs ] type_attrs = [ (r[2], r[3]) for r in recs ]
return CompositeCaster(tname, type_oid, type_attrs, return self(tname, type_oid, type_attrs,
array_oid=array_oid) array_oid=array_oid)
def register_composite(name, conn_or_curs, globally=False): def register_composite(name, conn_or_curs, globally=False, factory=None):
"""Register a typecaster to convert a composite type into a tuple. """Register a typecaster to convert a composite type into a tuple.
:param name: the name of a PostgreSQL composite type, e.g. created using :param name: the name of a PostgreSQL composite type, e.g. created using
@ -950,14 +955,21 @@ def register_composite(name, conn_or_curs, globally=False):
object, unless *globally* is set to `!True` object, unless *globally* is set to `!True`
:param globally: if `!False` (default) register the typecaster only on :param globally: if `!False` (default) register the typecaster only on
*conn_or_curs*, otherwise register it globally *conn_or_curs*, otherwise register it globally
:return: the registered `CompositeCaster` instance responsible for the :param factory: if specified it should be a `CompositeCaster` subclass: use
conversion it to :ref:`customize how to cast composite types <custom-composite>`
:return: the registered `CompositeCaster` or *factory* instance
responsible for the conversion
.. versionchanged:: 2.4.3 .. versionchanged:: 2.4.3
added support for array of composite types added support for array of composite types
.. versionchanged:: 2.4.6
added the *factory* parameter
""" """
caster = CompositeCaster._from_db(name, conn_or_curs) if factory is None:
factory = CompositeCaster
caster = factory._from_db(name, conn_or_curs)
_ext.register_type(caster.typecaster, not globally and conn_or_curs or None) _ext.register_type(caster.typecaster, not globally and conn_or_curs or None)
if caster.array_typecaster is not None: if caster.array_typecaster is not None:

View File

@ -736,7 +736,7 @@ class AdaptTypeTestCase(unittest.TestCase):
self.assertEqual(r[0], (2, 'test2')) self.assertEqual(r[0], (2, 'test2'))
self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)]) self.assertEqual(r[1], [(3, 'testc', 2), (4, 'testd', 2)])
@skip_if_no_hstore @skip_if_no_composite
def test_non_dbapi_connection(self): def test_non_dbapi_connection(self):
from psycopg2.extras import RealDictConnection from psycopg2.extras import RealDictConnection
from psycopg2.extras import register_composite from psycopg2.extras import register_composite
@ -760,6 +760,31 @@ class AdaptTypeTestCase(unittest.TestCase):
finally: finally:
conn.close() conn.close()
@skip_if_no_composite
def test_subclass(self):
oid = self._create_type("type_isd",
[('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])
from psycopg2.extras import register_composite, CompositeCaster
class DictComposite(CompositeCaster):
def make(self, values):
return dict(zip(self.attnames, values))
t = register_composite('type_isd', self.conn, factory=DictComposite)
self.assertEqual(t.name, 'type_isd')
self.assertEqual(t.oid, oid)
curs = self.conn.cursor()
r = (10, 'hello', date(2011,1,2))
curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0]
self.assert_(isinstance(v, dict))
self.assertEqual(v['anint'], 10)
self.assertEqual(v['astring'], "hello")
self.assertEqual(v['adate'], date(2011,1,2))
def _create_type(self, name, fields): def _create_type(self, name, fields):
curs = self.conn.cursor() curs = self.conn.cursor()
try: try: