Generic Model

This commit is contained in:
sw 2022-06-02 22:54:42 +08:00
parent 7002912300
commit de915bb00a
4 changed files with 115 additions and 77 deletions

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import datetime
import logging
from io import BytesIO
@ -42,7 +43,12 @@ class AioDatabase(Database):
async def close(self):
await self.request_session.aclose()
async def _send(self, data, settings=None, stream=False):
async def _send(
self,
data: str | bytes,
settings: dict = None,
stream: bool = False
):
r = await super()._send(data, settings, stream)
if r.status_code != 200:
raise ServerError(r.text)
@ -50,7 +56,7 @@ class AioDatabase(Database):
async def count(
self,
model_class,
model_class: type[MODEL],
conditions=None
) -> int:
"""

View File

@ -430,14 +430,19 @@ class Database:
query = self._substitute(query, MigrationHistory)
return set(obj.module_name for obj in self.select(query))
def _send(self, data, settings=None, stream=False):
def _send(
self,
data: str | bytes,
settings: dict = None,
stream: bool = False
):
if isinstance(data, str):
data = data.encode('utf-8')
if self.log_statements:
logger.info(data)
params = self._build_params(settings)
request = self.request_session.build_request(
method='POST', url=self.db_url, data=data, params=params
method='POST', url=self.db_url, content=data, params=params
)
r = self.request_session.send(request, stream=stream)
if isinstance(r, httpx.Response) and r.status_code != 200:

View File

@ -313,7 +313,9 @@ class Model(metaclass=ModelBase):
if field:
setattr(self, name, value)
else:
raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
raise AttributeError(
'%s does not have a field called %s' % (self.__class__.__name__, name)
)
def __setattr__(self, name, value):
"""
@ -474,7 +476,7 @@ class Model(metaclass=ModelBase):
return {name: data[name] for name in fields}
@classmethod
def objects_in(cls, database: Database) -> QuerySet:
def objects_in(cls: type[MODEL], database: Database) -> QuerySet[MODEL]:
"""
Returns a `QuerySet` for selecting instances of this model class.
"""

View File

@ -2,7 +2,17 @@ from __future__ import unicode_literals, annotations
from math import ceil
from copy import copy, deepcopy
from types import CoroutineType
from typing import TYPE_CHECKING, overload, Any, Union, Coroutine, Generic
from typing import (
TYPE_CHECKING,
overload,
Any,
Union,
Coroutine,
Generic,
TypeVar,
AsyncIterator,
Iterator
)
import pytz
@ -14,6 +24,8 @@ if TYPE_CHECKING:
from clickhouse_orm.models import Model
from clickhouse_orm.database import Database, Page
MODEL = TypeVar('MODEL', bound='Model')
class Operator:
"""
@ -307,14 +319,14 @@ class Q:
return q
class QuerySet:
class QuerySet(Generic[MODEL]):
"""
A queryset is an object that represents a database query using a specific `Model`.
It is lazy, meaning that it does not hit the database until you iterate over its
matching rows (model instances).
"""
def __init__(self, model_cls: type[Model], database: Database):
def __init__(self, model_cls: type[MODEL], database: Database):
"""
Initializer. It is possible to create a queryset like this, but the standard
way is to use `MyModel.objects_in(database)`.
@ -333,33 +345,34 @@ class QuerySet:
self._distinct = False
self._final = False
def __deepcopy__(self, memodict={}):
obj = type(self)(self._model_cls, self._database)
obj._order_by = deepcopy(self._order_by)
obj._where_q = deepcopy(self._where_q)
obj._prewhere_q = deepcopy(self._prewhere_q)
obj._grouping_fields = deepcopy(self._grouping_fields)
obj._grouping_with_totals = deepcopy(self._grouping_with_totals)
obj._fields = deepcopy(self._fields)
obj._limits = deepcopy(self._limits)
obj._limit_by = deepcopy(self._limit_by)
obj._limit_by_fields = deepcopy(self._limit_by_fields)
obj._distinct = deepcopy(self._distinct)
obj._final = deepcopy(self._final)
return obj
def _clone(self) -> "QuerySet[MODEL]":
queryset = type(self)(self._model_cls, self._database)
queryset._order_by = copy(self._order_by)
queryset._where_q = copy(self._where_q)
queryset._prewhere_q = copy(self._prewhere_q)
queryset._grouping_fields = copy(self._grouping_fields)
queryset._grouping_with_totals = self._grouping_with_totals
queryset._fields = self._fields
queryset._limits = copy(self._limits)
queryset._limit_by = copy(self._limit_by)
queryset._limit_by_fields = copy(self._limit_by_fields)
queryset._distinct = self._distinct
queryset._final = self._final
return queryset
def __iter__(self):
def __iter__(self) -> Iterator[MODEL]:
"""
Iterates over the model instances matching this queryset
"""
return self._database.select(self.as_sql(), self._model_cls)
for val in self._database.select(self.as_sql(), self._model_cls):
yield val
async def __aiter__(self):
async def __aiter__(self) -> AsyncIterator[MODEL]:
from clickhouse_orm.aio.database import AioDatabase
assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
async for r in self._database.select(self.as_sql(), self._model_cls):
yield r
async for val in self._database.select(self.as_sql(), self._model_cls):
yield val
def __bool__(self):
"""
@ -378,27 +391,27 @@ class QuerySet:
...
@overload
def __getitem__(self, s: slice) -> "QuerySet":
def __getitem__(self, s: slice) -> "QuerySet[MODEL]":
...
def __getitem__(self, s):
if isinstance(s, int):
# Single index
assert s >= 0, 'negative indexes are not supported'
qs = copy(self)
qs._limits = (s, 1)
return next(iter(qs))
queryset = self._clone()
queryset._limits = (s, 1)
return next(iter(queryset))
# Slice
assert s.step in (None, 1), 'step is not supported in slices'
start = s.start or 0
stop = s.stop or 2 ** 63 - 1
assert start >= 0 and stop >= 0, 'negative indexes are not supported'
assert start <= stop, 'start of slice cannot be smaller than its end'
qs = copy(self)
qs._limits = (start, stop - start)
return qs
queryset = self._clone()
queryset._limits = (start, stop - start)
return queryset
def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet":
def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet[MODEL]":
"""
Adds a LIMIT BY clause to the query.
- `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
@ -410,10 +423,10 @@ class QuerySet:
offset = offset_limit[0]
limit = offset_limit[1]
assert offset >= 0 and limit >= 0, 'negative limits are not supported'
qs = copy(self)
qs._limit_by = (offset, limit)
qs._limit_by_fields = fields_or_expr
return qs
queryset = self._clone()
queryset._limit_by = (offset, limit)
queryset._limit_by_fields = fields_or_expr
return queryset
def select_fields_as_sql(self) -> str:
"""
@ -490,31 +503,31 @@ class QuerySet:
conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
return self._database.count(self._model_cls, conditions)
def order_by(self, *field_names) -> "QuerySet":
def order_by(self, *field_names) -> "QuerySet[MODEL]":
"""
Returns a copy of this queryset with the ordering changed.
"""
qs = copy(self)
qs._order_by = field_names
return qs
queryset = self._clone()
queryset._order_by = field_names
return queryset
def only(self, *field_names) -> "QuerySet":
def only(self, *field_names) -> "QuerySet[MODEL]":
"""
Returns a copy of this queryset limited to the specified field names.
Useful when there are large fields that are not needed,
or for creating a subquery to use with an IN operator.
"""
qs = copy(self)
qs._fields = field_names
return qs
queryset = self._clone()
queryset._fields = field_names
return queryset
def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet":
def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
from clickhouse_orm.funcs import F
inverse = kwargs.pop('_inverse', False)
prewhere = kwargs.pop('prewhere', False)
qs = copy(self)
queryset = self._clone()
condition = Q()
for arg in q:
@ -533,20 +546,20 @@ class QuerySet:
condition = copy(self._prewhere_q if prewhere else self._where_q) & condition
if prewhere:
qs._prewhere_q = condition
queryset._prewhere_q = condition
else:
qs._where_q = condition
queryset._where_q = condition
return qs
return queryset
def filter(self, *q, **kwargs) -> "QuerySet":
def filter(self, *q: Q, **kwargs: Any) -> "QuerySet[MODEL]":
"""
Returns a copy of this queryset that includes only rows matching the conditions.
Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
"""
return self._filter_or_exclude(*q, **kwargs)
def exclude(self, *q, **kwargs) -> "QuerySet":
def exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
"""
Returns a copy of this queryset that excludes all rows matching the conditions.
Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
@ -582,16 +595,16 @@ class QuerySet:
page_size=page_size
)
def distinct(self) -> "QuerySet":
def distinct(self) -> "QuerySet[MODEL]":
"""
Adds a DISTINCT clause to the query, meaning that any duplicate rows
in the results will be omitted.
"""
qs = copy(self)
qs._distinct = True
return qs
queryset = self._clone()
queryset._distinct = True
return queryset
def final(self) -> "QuerySet":
def final(self) -> "QuerySet[MODEL]":
"""
Adds a FINAL modifier to table, meaning data will be collapsed to final version.
Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
@ -604,11 +617,11 @@ class QuerySet:
' and ReplacingMergeTree engines'
)
qs = copy(self)
qs._final = True
return qs
queryset = self._clone()
queryset._final = True
return queryset
def delete(self) -> "QuerySet":
def delete(self) -> "QuerySet[MODEL]":
"""
Deletes all records matched by this queryset's conditions.
Note that ClickHouse performs deletions in the background, so they are not immediate.
@ -619,7 +632,7 @@ class QuerySet:
self._database.raw(sql)
return self
def update(self, **kwargs) -> "QuerySet":
def update(self, **kwargs) -> "QuerySet[MODEL]":
"""
Updates all records matched by this queryset's conditions.
Keyword arguments specify the field names and expressions to use for the update.
@ -644,7 +657,7 @@ class QuerySet:
assert not self._distinct, 'Mutations are not allowed after calling distinct()'
assert not self._final, 'Mutations are not allowed after calling final()'
def aggregate(self, *args, **kwargs) -> "AggregateQuerySet":
def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
"""
Returns an `AggregateQuerySet` over this query, with `args` serving as
grouping fields and `kwargs` serving as calculated fields. At least one
@ -662,7 +675,7 @@ class QuerySet:
return AggregateQuerySet(self, args, kwargs)
class AggregateQuerySet(QuerySet):
class AggregateQuerySet(QuerySet[MODEL]):
"""
A queryset used for aggregation.
"""
@ -699,7 +712,7 @@ class AggregateQuerySet(QuerySet):
self._limits = base_queryset._limits
self._distinct = base_queryset._distinct
def group_by(self, *args) -> "AggregateQuerySet":
def group_by(self, *args) -> "AggregateQuerySet[MODEL]":
"""
This method lets you specify the grouping fields explicitly. The `args` must
be names of grouping fields or calculated fields that this queryset was
@ -708,9 +721,9 @@ class AggregateQuerySet(QuerySet):
for name in args:
assert name in self._fields or name in self._calculated_fields, \
'Cannot group by `%s` since it is not included in the query' % name
qs = copy(self)
qs._grouping_fields = args
return qs
queryset = copy(self)
queryset._grouping_fields = args
return queryset
def only(self, *field_names):
"""
@ -731,8 +744,20 @@ class AggregateQuerySet(QuerySet):
return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
self._calculated_fields.items()])
def __iter__(self):
return self._database.select(self.as_sql()) # using an ad-hoc model
def __iter__(self) -> Iterator[Model]:
"""
using an ad-hoc model
"""
for val in self._database.select(self.as_sql()):
yield val
return self._database.select(self.as_sql())
async def __aiter__(self) -> AsyncIterator[Model]:
from clickhouse_orm.aio.database import AioDatabase
assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
async for val in self._database.select(self.as_sql()):
yield val
def count(self) -> Union[int, Coroutine[int]]:
"""
@ -744,15 +769,15 @@ class AggregateQuerySet(QuerySet):
return raw
return int(raw) if raw else 0
def with_totals(self) -> "AggregateQuerySet":
def with_totals(self) -> "AggregateQuerySet[MODEL]":
"""
Adds WITH TOTALS modifier ot GROUP BY, making query return extra row
with aggregate function calculated across all the rows. More information:
https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
"""
qs = copy(self)
qs._grouping_with_totals = True
return qs
queryset = copy(self)
queryset._grouping_with_totals = True
return queryset
def _verify_mutation_allowed(self):
raise AssertionError('Cannot mutate an AggregateQuerySet')