Generic Model

sw 2022-06-02 22:54:42 +08:00
parent 7002912300
commit de915bb00a
4 changed files with 115 additions and 77 deletions

clickhouse_orm/aio/database.py

@@ -1,3 +1,4 @@
+from __future__ import annotations
 import datetime
 import logging
 from io import BytesIO
@@ -42,7 +43,12 @@ class AioDatabase(Database):
     async def close(self):
         await self.request_session.aclose()

-    async def _send(self, data, settings=None, stream=False):
+    async def _send(
+        self,
+        data: str | bytes,
+        settings: dict = None,
+        stream: bool = False
+    ):
         r = await super()._send(data, settings, stream)
         if r.status_code != 200:
             raise ServerError(r.text)
@@ -50,7 +56,7 @@ class AioDatabase(Database):
     async def count(
         self,
-        model_class,
+        model_class: type[MODEL],
         conditions=None
     ) -> int:
         """

clickhouse_orm/database.py

@@ -430,14 +430,19 @@ class Database:
         query = self._substitute(query, MigrationHistory)
         return set(obj.module_name for obj in self.select(query))

-    def _send(self, data, settings=None, stream=False):
+    def _send(
+        self,
+        data: str | bytes,
+        settings: dict = None,
+        stream: bool = False
+    ):
         if isinstance(data, str):
             data = data.encode('utf-8')
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
         request = self.request_session.build_request(
-            method='POST', url=self.db_url, data=data, params=params
+            method='POST', url=self.db_url, content=data, params=params
         )
         r = self.request_session.send(request, stream=stream)
         if isinstance(r, httpx.Response) and r.status_code != 200:
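The `data=` → `content=` switch in `build_request` matches httpx's API: `content=` carries a raw `str`/`bytes` request body, while `data=` is meant for form-encoded fields (httpx deprecated passing raw bytes through `data=`). A standalone sketch of the distinction, with a hypothetical local ClickHouse URL:

```python
import httpx

client = httpx.Client()
request = client.build_request(
    method='POST',
    url='http://localhost:8123',   # hypothetical ClickHouse HTTP endpoint
    content=b'SELECT 1',           # raw body bytes: use content=
    # data={'key': 'value'}        # data= would form-encode a dict instead
)
response = client.send(request)
print(response.text)
```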

clickhouse_orm/models.py

@@ -313,7 +313,9 @@ class Model(metaclass=ModelBase):
             if field:
                 setattr(self, name, value)
             else:
-                raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
+                raise AttributeError(
+                    '%s does not have a field called %s' % (self.__class__.__name__, name)
+                )

     def __setattr__(self, name, value):
         """
@@ -474,7 +476,7 @@ class Model(metaclass=ModelBase):
         return {name: data[name] for name in fields}

     @classmethod
-    def objects_in(cls, database: Database) -> QuerySet:
+    def objects_in(cls: type[MODEL], database: Database) -> QuerySet[MODEL]:
         """
         Returns a `QuerySet` for selecting instances of this model class.
         """

clickhouse_orm/query.py

@@ -2,7 +2,17 @@ from __future__ import unicode_literals, annotations
 from math import ceil
 from copy import copy, deepcopy
 from types import CoroutineType
-from typing import TYPE_CHECKING, overload, Any, Union, Coroutine, Generic
+from typing import (
+    TYPE_CHECKING,
+    overload,
+    Any,
+    Union,
+    Coroutine,
+    Generic,
+    TypeVar,
+    AsyncIterator,
+    Iterator
+)

 import pytz
@@ -14,6 +24,8 @@ if TYPE_CHECKING:
     from clickhouse_orm.models import Model
     from clickhouse_orm.database import Database, Page

+MODEL = TypeVar('MODEL', bound='Model')
+

 class Operator:
     """
@@ -307,14 +319,14 @@ class Q:
         return q


-class QuerySet:
+class QuerySet(Generic[MODEL]):
     """
     A queryset is an object that represents a database query using a specific `Model`.
     It is lazy, meaning that it does not hit the database until you iterate over its
     matching rows (model instances).
     """

-    def __init__(self, model_cls: type[Model], database: Database):
+    def __init__(self, model_cls: type[MODEL], database: Database):
         """
         Initializer. It is possible to create a queryset like this, but the standard
         way is to use `MyModel.objects_in(database)`.
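For readers less familiar with the pattern, here is a minimal, self-contained sketch of the `TypeVar`/`Generic` machinery the commit introduces (all names here are made up for illustration):

```python
from typing import Generic, Iterator, TypeVar


class Base:
    pass


T = TypeVar('T', bound=Base)   # T must be Base or a subclass of it


class TypedSet(Generic[T]):
    def __init__(self, cls: type[T]) -> None:
        self._cls = cls
        self._items: list[T] = []

    def add(self) -> T:
        item = self._cls()
        self._items.append(item)
        return item

    def __iter__(self) -> Iterator[T]:
        return iter(self._items)


class Child(Base):
    pass


s = TypedSet(Child)   # inferred as TypedSet[Child]
child = s.add()       # inferred as Child, not Base
```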
@@ -333,33 +345,34 @@ class QuerySet:
         self._distinct = False
         self._final = False

-    def __deepcopy__(self, memodict={}):
-        obj = type(self)(self._model_cls, self._database)
-        obj._order_by = deepcopy(self._order_by)
-        obj._where_q = deepcopy(self._where_q)
-        obj._prewhere_q = deepcopy(self._prewhere_q)
-        obj._grouping_fields = deepcopy(self._grouping_fields)
-        obj._grouping_with_totals = deepcopy(self._grouping_with_totals)
-        obj._fields = deepcopy(self._fields)
-        obj._limits = deepcopy(self._limits)
-        obj._limit_by = deepcopy(self._limit_by)
-        obj._limit_by_fields = deepcopy(self._limit_by_fields)
-        obj._distinct = deepcopy(self._distinct)
-        obj._final = deepcopy(self._final)
-        return obj
+    def _clone(self) -> "QuerySet[MODEL]":
+        queryset = type(self)(self._model_cls, self._database)
+        queryset._order_by = copy(self._order_by)
+        queryset._where_q = copy(self._where_q)
+        queryset._prewhere_q = copy(self._prewhere_q)
+        queryset._grouping_fields = copy(self._grouping_fields)
+        queryset._grouping_with_totals = self._grouping_with_totals
+        queryset._fields = self._fields
+        queryset._limits = copy(self._limits)
+        queryset._limit_by = copy(self._limit_by)
+        queryset._limit_by_fields = copy(self._limit_by_fields)
+        queryset._distinct = self._distinct
+        queryset._final = self._final
+        return queryset

-    def __iter__(self):
+    def __iter__(self) -> Iterator[MODEL]:
         """
         Iterates over the model instances matching this queryset
         """
-        return self._database.select(self.as_sql(), self._model_cls)
+        for val in self._database.select(self.as_sql(), self._model_cls):
+            yield val

-    async def __aiter__(self):
+    async def __aiter__(self) -> AsyncIterator[MODEL]:
         from clickhouse_orm.aio.database import AioDatabase
         assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
-        async for r in self._database.select(self.as_sql(), self._model_cls):
-            yield r
+        async for val in self._database.select(self.as_sql(), self._model_cls):
+            yield val

     def __bool__(self):
         """
@@ -378,27 +391,27 @@ class QuerySet:
         ...

     @overload
-    def __getitem__(self, s: slice) -> "QuerySet":
+    def __getitem__(self, s: slice) -> "QuerySet[MODEL]":
         ...

     def __getitem__(self, s):
         if isinstance(s, int):
             # Single index
             assert s >= 0, 'negative indexes are not supported'
-            qs = copy(self)
-            qs._limits = (s, 1)
-            return next(iter(qs))
+            queryset = self._clone()
+            queryset._limits = (s, 1)
+            return next(iter(queryset))
         # Slice
         assert s.step in (None, 1), 'step is not supported in slices'
         start = s.start or 0
         stop = s.stop or 2 ** 63 - 1
         assert start >= 0 and stop >= 0, 'negative indexes are not supported'
         assert start <= stop, 'start of slice cannot be smaller than its end'
-        qs = copy(self)
-        qs._limits = (start, stop - start)
-        return qs
+        queryset = self._clone()
+        queryset._limits = (start, stop - start)
+        return queryset

-    def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet":
+    def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet[MODEL]":
         """
         Adds a LIMIT BY clause to the query.
         - `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
@@ -410,10 +423,10 @@ class QuerySet:
             offset = offset_limit[0]
             limit = offset_limit[1]
         assert offset >= 0 and limit >= 0, 'negative limits are not supported'
-        qs = copy(self)
-        qs._limit_by = (offset, limit)
-        qs._limit_by_fields = fields_or_expr
-        return qs
+        queryset = self._clone()
+        queryset._limit_by = (offset, limit)
+        queryset._limit_by_fields = fields_or_expr
+        return queryset

     def select_fields_as_sql(self) -> str:
         """
@@ -490,31 +503,31 @@ class QuerySet:
             conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
         return self._database.count(self._model_cls, conditions)

-    def order_by(self, *field_names) -> "QuerySet":
+    def order_by(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset with the ordering changed.
         """
-        qs = copy(self)
-        qs._order_by = field_names
-        return qs
+        queryset = self._clone()
+        queryset._order_by = field_names
+        return queryset

-    def only(self, *field_names) -> "QuerySet":
+    def only(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset limited to the specified field names.
         Useful when there are large fields that are not needed,
         or for creating a subquery to use with an IN operator.
         """
-        qs = copy(self)
-        qs._fields = field_names
-        return qs
+        queryset = self._clone()
+        queryset._fields = field_names
+        return queryset

-    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet":
+    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         from clickhouse_orm.funcs import F

         inverse = kwargs.pop('_inverse', False)
         prewhere = kwargs.pop('prewhere', False)
-        qs = copy(self)
+        queryset = self._clone()

         condition = Q()
         for arg in q:
@@ -533,20 +546,20 @@ class QuerySet:
             condition = copy(self._prewhere_q if prewhere else self._where_q) & condition

         if prewhere:
-            qs._prewhere_q = condition
+            queryset._prewhere_q = condition
         else:
-            qs._where_q = condition
-        return qs
+            queryset._where_q = condition
+        return queryset

-    def filter(self, *q, **kwargs) -> "QuerySet":
+    def filter(self, *q: Q, **kwargs: Any) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that includes only rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
         """
         return self._filter_or_exclude(*q, **kwargs)

-    def exclude(self, *q, **kwargs) -> "QuerySet":
+    def exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that excludes all rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
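Typical use of the now-annotated `filter`/`exclude`, assuming `Q` is importable from `clickhouse_orm.query` as in this file:

```python
from clickhouse_orm.query import Q

queryset = Person.objects_in(db)
queryset.filter(name='Alice')                     # WHERE name = 'Alice'
queryset.filter(Q(name='Alice') | Q(name='Bob'))  # OR-combined Q conditions
queryset.filter(name='Alice', prewhere=True)      # PREWHERE instead of WHERE
queryset.exclude(name='Eve')                      # negated condition
```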
@@ -582,16 +595,16 @@ class QuerySet:
             page_size=page_size
         )

-    def distinct(self) -> "QuerySet":
+    def distinct(self) -> "QuerySet[MODEL]":
         """
         Adds a DISTINCT clause to the query, meaning that any duplicate rows
         in the results will be omitted.
         """
-        qs = copy(self)
-        qs._distinct = True
-        return qs
+        queryset = self._clone()
+        queryset._distinct = True
+        return queryset

-    def final(self) -> "QuerySet":
+    def final(self) -> "QuerySet[MODEL]":
         """
         Adds a FINAL modifier to table, meaning data will be collapsed to final version.
         Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
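Usage sketch for the two modifiers (the engine assertion behind `final()` is visible in the next hunk):

```python
queryset.distinct()   # SELECT DISTINCT ...
queryset.final()      # adds FINAL; asserts a CollapsingMergeTree/ReplacingMergeTree engine
```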
@@ -604,11 +617,11 @@ class QuerySet:
                 ' and ReplacingMergeTree engines'
             )

-        qs = copy(self)
-        qs._final = True
-        return qs
+        queryset = self._clone()
+        queryset._final = True
+        return queryset

-    def delete(self) -> "QuerySet":
+    def delete(self) -> "QuerySet[MODEL]":
         """
         Deletes all records matched by this queryset's conditions.
         Note that ClickHouse performs deletions in the background, so they are not immediate.
@@ -619,7 +632,7 @@ class QuerySet:
         self._database.raw(sql)
         return self

-    def update(self, **kwargs) -> "QuerySet":
+    def update(self, **kwargs) -> "QuerySet[MODEL]":
         """
         Updates all records matched by this queryset's conditions.
         Keyword arguments specify the field names and expressions to use for the update.
@@ -644,7 +657,7 @@ class QuerySet:
         assert not self._distinct, 'Mutations are not allowed after calling distinct()'
         assert not self._final, 'Mutations are not allowed after calling final()'

-    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet":
+    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
         """
         Returns an `AggregateQuerySet` over this query, with `args` serving as
         grouping fields and `kwargs` serving as calculated fields. At least one
@@ -662,7 +675,7 @@ class QuerySet:
         return AggregateQuerySet(self, args, kwargs)


-class AggregateQuerySet(QuerySet):
+class AggregateQuerySet(QuerySet[MODEL]):
     """
     A queryset used for aggregation.
     """
@@ -699,7 +712,7 @@ class AggregateQuerySet(QuerySet):
         self._limits = base_queryset._limits
         self._distinct = base_queryset._distinct

-    def group_by(self, *args) -> "AggregateQuerySet":
+    def group_by(self, *args) -> "AggregateQuerySet[MODEL]":
         """
         This method lets you specify the grouping fields explicitly. The `args` must
         be names of grouping fields or calculated fields that this queryset was
@@ -708,9 +721,9 @@ class AggregateQuerySet(QuerySet):
         for name in args:
             assert name in self._fields or name in self._calculated_fields, \
                 'Cannot group by `%s` since it is not included in the query' % name
-        qs = copy(self)
-        qs._grouping_fields = args
-        return qs
+        queryset = copy(self)
+        queryset._grouping_fields = args
+        return queryset

     def only(self, *field_names):
         """
@@ -731,8 +744,20 @@ class AggregateQuerySet(QuerySet):
         return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
                            self._calculated_fields.items()])

-    def __iter__(self):
-        return self._database.select(self.as_sql())  # using an ad-hoc model
+    def __iter__(self) -> Iterator[Model]:
+        """
+        using an ad-hoc model
+        """
+        for val in self._database.select(self.as_sql()):
+            yield val
+
+    async def __aiter__(self) -> AsyncIterator[Model]:
+        from clickhouse_orm.aio.database import AioDatabase
+        assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
+        async for val in self._database.select(self.as_sql()):
+            yield val

     def count(self) -> Union[int, Coroutine[int]]:
         """
@@ -744,15 +769,15 @@ class AggregateQuerySet(QuerySet):
             return raw
         return int(raw) if raw else 0

-    def with_totals(self) -> "AggregateQuerySet":
+    def with_totals(self) -> "AggregateQuerySet[MODEL]":
         """
         Adds a WITH TOTALS modifier to GROUP BY, making the query return an extra row
         with aggregate function calculated across all the rows. More information:
         https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
         """
-        qs = copy(self)
-        qs._grouping_with_totals = True
-        return qs
+        queryset = copy(self)
+        queryset._grouping_with_totals = True
+        return queryset

     def _verify_mutation_allowed(self):
         raise AssertionError('Cannot mutate an AggregateQuerySet')
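Finally, `with_totals` and the mutation guard in combination (same hypothetical setup as above):

```python
totals = Person.objects_in(db).aggregate('name', count='count()').with_totals()
# SELECT name, count() AS count FROM ... GROUP BY name WITH TOTALS
totals.delete()   # raises AssertionError: Cannot mutate an AggregateQuerySet
```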