create a new row type

This commit is contained in:
Changaco 2019-09-19 23:19:28 +02:00
parent f08019e356
commit e72f8d9732

View File

@ -29,6 +29,7 @@ import os as _os
import time as _time
import re as _re
from collections import namedtuple, OrderedDict
from operator import itemgetter
import logging as _logging
@ -61,6 +62,9 @@ from psycopg2._range import ( # noqa
from psycopg2._ipaddress import register_ipaddress # noqa
itemgetter0 = itemgetter(0)
class DictCursorBase(_cursor):
"""Base class for all dict-like cursors."""
@ -396,6 +400,108 @@ def _cached_make_nt(cls, key):
NamedTupleCursor._cached_make_nt = classmethod(_cached_make_nt)
class HybridRow(object):
"""A versatile row type.
This class implements both dict-style and attribute-style lookups and
assignments, in addition to index-based lookups. However, index-based
assigments aren't allowed.
Assignments aren't limited to the initial columns, extra attributes can be
added.
Beware that although hybrid rows support dict-style lookups and assigments,
they do not have the standard :class:`dict` methods (:meth:`~dict.get`,
:meth:`~dict.items`, etc.).
"""
__slots__ = ('_cols', '__dict__')
def __init__(self, cols, values):
self._cols = cols
self.__dict__.update(zip(map(itemgetter0, cols), values))
def __getitem__(self, key):
if isinstance(key, int):
return self.__dict__[self._cols[key][0]]
elif isinstance(key, slice):
return [self.__dict__[col[0]] for col in self._cols[key]]
else:
return self.__dict__[key]
def __setitem__(self, key, value):
if isinstance(key, (int, slice)):
raise TypeError('index-based assignments are not allowed')
self.__dict__[key] = value
def __contains__(self, key):
return key in self.__dict__
def __eq__(self, other):
if isinstance(other, HybridRow):
return other.__dict__ == self.__dict__
elif isinstance(other, dict):
return other == self.__dict__
elif isinstance(other, tuple):
return len(self.__dict__) == len(self._cols) and other == tuple(self)
return False
def __repr__(self):
col_indexes = {name: i for i, name in enumerate(self._cols)}
after = len(self._cols)
key = lambda t: (col_indexes.get(t[0], after), t[0])
return 'Row(%s)' % (
', '.join(map('%s=%r'.__mod__, sorted(self.__dict__.items(), key=key)))
)
def __getstate__(self):
# We only save the column names, not the other column attributes.
return tuple(map(itemgetter0, self._cols)), self.__dict__.copy()
def __setstate__(self, data):
self._cols = tuple((col_name,) for col_name in data[0])
self.__dict__.update(data[1])
def _asdict(self):
"""For compatibility with namedtuple classes."""
return self.__dict__.copy()
@property
def _fields(self):
"""For compatibility with namedtuple classes."""
return tuple(map(itemgetter0, self._cols))
class HybridRowCursor(_cursor):
"""A cursor subclass that generates :class:`HybridRow` objects.
"""
def fetchone(self):
t = _cursor.fetchone(self)
if t is not None:
return HybridRow(self.description, t)
def fetchmany(self, size=None):
ts = _cursor.fetchmany(self, size)
cols = self.description
return [HybridRow(cols, t) for t in ts]
def fetchall(self):
ts = _cursor.fetchall(self)
cols = self.description
return [HybridRow(cols, t) for t in ts]
def __iter__(self):
it = _cursor.__iter__(self)
while True:
try:
t = next(it)
except StopIteration:
return
yield HybridRow(self.description, t)
class LoggingConnection(_connection):
"""A connection that logs all queries to a file or logger__ object.