From e72f8d9732fcd52338e6d486f0efed589a50c7b6 Mon Sep 17 00:00:00 2001 From: Changaco Date: Thu, 19 Sep 2019 23:19:28 +0200 Subject: [PATCH] create a new row type --- lib/extras.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/lib/extras.py b/lib/extras.py index afc31c32..36a1ea19 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -29,6 +29,7 @@ import os as _os import time as _time import re as _re from collections import namedtuple, OrderedDict +from operator import itemgetter import logging as _logging @@ -61,6 +62,9 @@ from psycopg2._range import ( # noqa from psycopg2._ipaddress import register_ipaddress # noqa +itemgetter0 = itemgetter(0) + + class DictCursorBase(_cursor): """Base class for all dict-like cursors.""" @@ -396,6 +400,108 @@ def _cached_make_nt(cls, key): NamedTupleCursor._cached_make_nt = classmethod(_cached_make_nt) +class HybridRow(object): + """A versatile row type. + + This class implements both dict-style and attribute-style lookups and + assignments, in addition to index-based lookups. However, index-based + assigments aren't allowed. + + Assignments aren't limited to the initial columns, extra attributes can be + added. + + Beware that although hybrid rows support dict-style lookups and assigments, + they do not have the standard :class:`dict` methods (:meth:`~dict.get`, + :meth:`~dict.items`, etc.). + + """ + + __slots__ = ('_cols', '__dict__') + + def __init__(self, cols, values): + self._cols = cols + self.__dict__.update(zip(map(itemgetter0, cols), values)) + + def __getitem__(self, key): + if isinstance(key, int): + return self.__dict__[self._cols[key][0]] + elif isinstance(key, slice): + return [self.__dict__[col[0]] for col in self._cols[key]] + else: + return self.__dict__[key] + + def __setitem__(self, key, value): + if isinstance(key, (int, slice)): + raise TypeError('index-based assignments are not allowed') + self.__dict__[key] = value + + def __contains__(self, key): + return key in self.__dict__ + + def __eq__(self, other): + if isinstance(other, HybridRow): + return other.__dict__ == self.__dict__ + elif isinstance(other, dict): + return other == self.__dict__ + elif isinstance(other, tuple): + return len(self.__dict__) == len(self._cols) and other == tuple(self) + return False + + def __repr__(self): + col_indexes = {name: i for i, name in enumerate(self._cols)} + after = len(self._cols) + key = lambda t: (col_indexes.get(t[0], after), t[0]) + return 'Row(%s)' % ( + ', '.join(map('%s=%r'.__mod__, sorted(self.__dict__.items(), key=key))) + ) + + def __getstate__(self): + # We only save the column names, not the other column attributes. + return tuple(map(itemgetter0, self._cols)), self.__dict__.copy() + + def __setstate__(self, data): + self._cols = tuple((col_name,) for col_name in data[0]) + self.__dict__.update(data[1]) + + def _asdict(self): + """For compatibility with namedtuple classes.""" + return self.__dict__.copy() + + @property + def _fields(self): + """For compatibility with namedtuple classes.""" + return tuple(map(itemgetter0, self._cols)) + + +class HybridRowCursor(_cursor): + """A cursor subclass that generates :class:`HybridRow` objects. + """ + + def fetchone(self): + t = _cursor.fetchone(self) + if t is not None: + return HybridRow(self.description, t) + + def fetchmany(self, size=None): + ts = _cursor.fetchmany(self, size) + cols = self.description + return [HybridRow(cols, t) for t in ts] + + def fetchall(self): + ts = _cursor.fetchall(self) + cols = self.description + return [HybridRow(cols, t) for t in ts] + + def __iter__(self): + it = _cursor.__iter__(self) + while True: + try: + t = next(it) + except StopIteration: + return + yield HybridRow(self.description, t) + + class LoggingConnection(_connection): """A connection that logs all queries to a file or logger__ object.