Mirror of https://github.com/Infinidat/infi.clickhouse_orm.git (synced 2025-08-02 11:10:11 +03:00)

Commit de915bb00a ("Generic Model")
Parent: 7002912300
@@ -1,3 +1,4 @@
+from __future__ import annotations
 import datetime
 import logging
 from io import BytesIO
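A quick illustration of why the added import matters (a sketch, not part of the commit): the PEP 604 union syntax `str | bytes` used in the new `_send` signature below is only valid on Python versions before 3.10 when PEP 563 postponed evaluation is enabled, which is exactly what `from __future__ import annotations` provides.

from __future__ import annotations  # annotations become lazy strings (PEP 563)

def send(data: str | bytes) -> bytes:  # accepted on 3.7+ thanks to the import
    # At runtime the annotation is never evaluated, so `str | bytes`
    # needs no interpreter support for union types.
    return data.encode('utf-8') if isinstance(data, str) else data

assert send('ping') == b'ping'
assert send(b'pong') == b'pong'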
@@ -42,7 +43,12 @@ class AioDatabase(Database):
     async def close(self):
         await self.request_session.aclose()
 
-    async def _send(self, data, settings=None, stream=False):
+    async def _send(
+        self,
+        data: str | bytes,
+        settings: dict = None,
+        stream: bool = False
+    ):
         r = await super()._send(data, settings, stream)
         if r.status_code != 200:
             raise ServerError(r.text)
@@ -50,7 +56,7 @@ class AioDatabase(Database):
 
     async def count(
         self,
-        model_class,
+        model_class: type[MODEL],
         conditions=None
     ) -> int:
         """
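A usage sketch for the typed async API. Assumptions not in the commit: the fork keeps upstream's module layout for models, fields, and engines; a ClickHouse server listens on localhost:8123; `Event` is a hypothetical model; the exact `AioDatabase` construction may differ.

import asyncio

from clickhouse_orm.aio.database import AioDatabase
from clickhouse_orm.database import ServerError
from clickhouse_orm.models import Model
from clickhouse_orm.fields import StringField
from clickhouse_orm.engines import Memory

class Event(Model):          # hypothetical model for the example
    name = StringField()
    engine = Memory()

async def main():
    db = AioDatabase('default', db_url='http://localhost:8123/')
    try:
        # count() is now annotated as (model_class: type[MODEL], conditions) -> int
        print(await db.count(Event))
    except ServerError as exc:
        # any non-200 response from _send() now surfaces here
        print('server rejected the query:', exc)
    finally:
        await db.close()

asyncio.run(main())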
@@ -430,14 +430,19 @@ class Database:
         query = self._substitute(query, MigrationHistory)
         return set(obj.module_name for obj in self.select(query))
 
-    def _send(self, data, settings=None, stream=False):
+    def _send(
+        self,
+        data: str | bytes,
+        settings: dict = None,
+        stream: bool = False
+    ):
         if isinstance(data, str):
             data = data.encode('utf-8')
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
         request = self.request_session.build_request(
-            method='POST', url=self.db_url, data=data, params=params
+            method='POST', url=self.db_url, content=data, params=params
         )
         r = self.request_session.send(request, stream=stream)
         if isinstance(r, httpx.Response) and r.status_code != 200:
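The `data=` to `content=` rename is the substantive fix in this hunk: in httpx, `data=` is meant for form-encoded payloads and passing raw bytes through it is deprecated, while `content=` is the documented parameter for a raw request body. A standalone sketch (the URL and query are assumptions):

import httpx

client = httpx.Client()
request = client.build_request(
    method='POST',
    url='http://localhost:8123/',
    content=b'SELECT 1',            # raw body: the correct parameter
    params={'database': 'default'},
)
# response = client.send(request)  # needs a running ClickHouse server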
@@ -313,7 +313,9 @@ class Model(metaclass=ModelBase):
         if field:
             setattr(self, name, value)
         else:
-            raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
+            raise AttributeError(
+                '%s does not have a field called %s' % (self.__class__.__name__, name)
+            )
 
     def __setattr__(self, name, value):
         """
@@ -474,7 +476,7 @@ class Model(metaclass=ModelBase):
         return {name: data[name] for name in fields}
 
     @classmethod
-    def objects_in(cls, database: Database) -> QuerySet:
+    def objects_in(cls: type[MODEL], database: Database) -> QuerySet[MODEL]:
         """
         Returns a `QuerySet` for selecting instances of this model class.
         """
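What the `cls: type[MODEL]` annotation buys (a sketch; the `Person` model and the connection are hypothetical): a type checker can now infer the element type of a queryset from the model class it was created from.

from clickhouse_orm.database import Database
from clickhouse_orm.models import Model
from clickhouse_orm.fields import StringField, Float32Field
from clickhouse_orm.engines import Memory

class Person(Model):              # hypothetical model for the example
    first_name = StringField()
    height = Float32Field()
    engine = Memory()

db = Database('default')          # assumes ClickHouse on localhost:8123
queryset = Person.objects_in(db)  # inferred as QuerySet[Person]
for person in queryset:           # each item is inferred as Person
    print(person.first_name, person.height)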
@@ -2,7 +2,17 @@ from __future__ import unicode_literals, annotations
 from math import ceil
 from copy import copy, deepcopy
 from types import CoroutineType
-from typing import TYPE_CHECKING, overload, Any, Union, Coroutine, Generic
+from typing import (
+    TYPE_CHECKING,
+    overload,
+    Any,
+    Union,
+    Coroutine,
+    Generic,
+    TypeVar,
+    AsyncIterator,
+    Iterator
+)
 
 import pytz
 
@@ -14,6 +24,8 @@ if TYPE_CHECKING:
     from clickhouse_orm.models import Model
     from clickhouse_orm.database import Database, Page
 
+MODEL = TypeVar('MODEL', bound='Model')
+
 
 class Operator:
     """
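The `bound='Model'` forward reference is what makes the generics sound: `MODEL` can only ever be substituted with `Model` subclasses, so generic code may rely on the `Model` API. A minimal, library-independent illustration:

from __future__ import annotations
from typing import TypeVar

class Base:
    @classmethod
    def table_name(cls) -> str:
        return cls.__name__.lower()

B = TypeVar('B', bound=Base)   # B must be Base or a subclass of it

def describe(model_cls: type[B]) -> str:
    # Safe: whatever B is, it has at least Base's API.
    return model_cls.table_name()

class Person(Base):
    pass

assert describe(Person) == 'person'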
@@ -307,14 +319,14 @@ class Q:
         return q
 
 
-class QuerySet:
+class QuerySet(Generic[MODEL]):
     """
     A queryset is an object that represents a database query using a specific `Model`.
     It is lazy, meaning that it does not hit the database until you iterate over its
     matching rows (model instances).
     """
 
-    def __init__(self, model_cls: type[Model], database: Database):
+    def __init__(self, model_cls: type[MODEL], database: Database):
         """
         Initializer. It is possible to create a queryset like this, but the standard
         way is to use `MyModel.objects_in(database)`.
@@ -333,33 +345,34 @@ class QuerySet:
         self._distinct = False
         self._final = False
 
-    def __deepcopy__(self, memodict={}):
-        obj = type(self)(self._model_cls, self._database)
-        obj._order_by = deepcopy(self._order_by)
-        obj._where_q = deepcopy(self._where_q)
-        obj._prewhere_q = deepcopy(self._prewhere_q)
-        obj._grouping_fields = deepcopy(self._grouping_fields)
-        obj._grouping_with_totals = deepcopy(self._grouping_with_totals)
-        obj._fields = deepcopy(self._fields)
-        obj._limits = deepcopy(self._limits)
-        obj._limit_by = deepcopy(self._limit_by)
-        obj._limit_by_fields = deepcopy(self._limit_by_fields)
-        obj._distinct = deepcopy(self._distinct)
-        obj._final = deepcopy(self._final)
-        return obj
+    def _clone(self) -> "QuerySet[MODEL]":
+        queryset = type(self)(self._model_cls, self._database)
+        queryset._order_by = copy(self._order_by)
+        queryset._where_q = copy(self._where_q)
+        queryset._prewhere_q = copy(self._prewhere_q)
+        queryset._grouping_fields = copy(self._grouping_fields)
+        queryset._grouping_with_totals = self._grouping_with_totals
+        queryset._fields = self._fields
+        queryset._limits = copy(self._limits)
+        queryset._limit_by = copy(self._limit_by)
+        queryset._limit_by_fields = copy(self._limit_by_fields)
+        queryset._distinct = self._distinct
+        queryset._final = self._final
+        return queryset
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[MODEL]:
         """
         Iterates over the model instances matching this queryset
         """
-        return self._database.select(self.as_sql(), self._model_cls)
+        for val in self._database.select(self.as_sql(), self._model_cls):
+            yield val
 
-    async def __aiter__(self):
+    async def __aiter__(self) -> AsyncIterator[MODEL]:
         from clickhouse_orm.aio.database import AioDatabase
 
         assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
-        async for r in self._database.select(self.as_sql(), self._model_cls):
-            yield r
+        async for val in self._database.select(self.as_sql(), self._model_cls):
+            yield val
 
     def __bool__(self):
         """
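Replacing `__deepcopy__` (with its mutable `memodict={}` default argument) by an explicit `_clone()` is deliberate: every chaining method returns a new queryset, so only the per-queryset containers need copying, while immutable flags can be shared. A minimal, library-independent sketch of the pattern:

from copy import copy

class Chainable:
    def __init__(self):
        self._order_by = []
        self._final = False           # immutable flag

    def _clone(self) -> 'Chainable':
        clone = type(self)()
        clone._order_by = copy(self._order_by)  # shallow-copy the container
        clone._final = self._final              # share the immutable value
        return clone

    def order_by(self, *fields: str) -> 'Chainable':
        clone = self._clone()
        clone._order_by = list(fields)
        return clone

qs = Chainable()
qs2 = qs.order_by('name')
assert qs._order_by == []   # the original queryset is unchanged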
@@ -378,27 +391,27 @@ class QuerySet:
         ...
 
     @overload
-    def __getitem__(self, s: slice) -> "QuerySet":
+    def __getitem__(self, s: slice) -> "QuerySet[MODEL]":
         ...
 
     def __getitem__(self, s):
         if isinstance(s, int):
             # Single index
             assert s >= 0, 'negative indexes are not supported'
-            qs = copy(self)
-            qs._limits = (s, 1)
-            return next(iter(qs))
+            queryset = self._clone()
+            queryset._limits = (s, 1)
+            return next(iter(queryset))
         # Slice
         assert s.step in (None, 1), 'step is not supported in slices'
         start = s.start or 0
         stop = s.stop or 2 ** 63 - 1
         assert start >= 0 and stop >= 0, 'negative indexes are not supported'
         assert start <= stop, 'start of slice cannot be smaller than its end'
-        qs = copy(self)
-        qs._limits = (start, stop - start)
-        return qs
+        queryset = self._clone()
+        queryset._limits = (start, stop - start)
+        return queryset
 
-    def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet":
+    def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet[MODEL]":
         """
         Adds a LIMIT BY clause to the query.
         - `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit).
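The slice handling is unchanged in substance; it maps Python slices onto ClickHouse's `LIMIT offset, count`. A standalone sketch of the arithmetic (note, as in the original, that `s.stop or 2 ** 63 - 1` treats an explicit `stop=0` the same as "no limit"):

def slice_to_limits(s: slice) -> tuple:
    """Translate a Python slice into ClickHouse LIMIT (offset, count)."""
    assert s.step in (None, 1), 'step is not supported in slices'
    start = s.start or 0
    stop = s.stop or 2 ** 63 - 1   # no stop -> effectively unlimited
    assert start >= 0 and stop >= 0, 'negative indexes are not supported'
    assert start <= stop, 'start of slice cannot be smaller than its end'
    return (start, stop - start)

assert slice_to_limits(slice(10, 20)) == (10, 10)   # LIMIT 10, 10
assert slice_to_limits(slice(None, 5)) == (0, 5)    # LIMIT 0, 5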
@@ -410,10 +423,10 @@ class QuerySet:
             offset = offset_limit[0]
             limit = offset_limit[1]
         assert offset >= 0 and limit >= 0, 'negative limits are not supported'
-        qs = copy(self)
-        qs._limit_by = (offset, limit)
-        qs._limit_by_fields = fields_or_expr
-        return qs
+        queryset = self._clone()
+        queryset._limit_by = (offset, limit)
+        queryset._limit_by_fields = fields_or_expr
+        return queryset
 
     def select_fields_as_sql(self) -> str:
         """
@@ -490,31 +503,31 @@ class QuerySet:
         conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
         return self._database.count(self._model_cls, conditions)
 
-    def order_by(self, *field_names) -> "QuerySet":
+    def order_by(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset with the ordering changed.
         """
-        qs = copy(self)
-        qs._order_by = field_names
-        return qs
+        queryset = self._clone()
+        queryset._order_by = field_names
+        return queryset
 
-    def only(self, *field_names) -> "QuerySet":
+    def only(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset limited to the specified field names.
         Useful when there are large fields that are not needed,
         or for creating a subquery to use with an IN operator.
         """
-        qs = copy(self)
-        qs._fields = field_names
-        return qs
+        queryset = self._clone()
+        queryset._fields = field_names
+        return queryset
 
-    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet":
+    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         from clickhouse_orm.funcs import F
 
         inverse = kwargs.pop('_inverse', False)
         prewhere = kwargs.pop('prewhere', False)
 
-        qs = copy(self)
+        queryset = self._clone()
 
         condition = Q()
         for arg in q:
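The `only()` docstring mentions subqueries for the IN operator; reusing the hypothetical `Person` model and `db` from the `objects_in` sketch above, that pattern looks like this:

# Reuses the hypothetical Person model and db from the objects_in sketch.
tall = Person.objects_in(db).filter(height__gt=1.8).only('first_name')
same_names = Person.objects_in(db).filter(first_name__in=tall)
print(same_names.as_sql())   # ... WHERE first_name IN (SELECT first_name ...)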
@@ -533,20 +546,20 @@ class QuerySet:
 
         condition = copy(self._prewhere_q if prewhere else self._where_q) & condition
         if prewhere:
-            qs._prewhere_q = condition
+            queryset._prewhere_q = condition
         else:
-            qs._where_q = condition
+            queryset._where_q = condition
 
-        return qs
+        return queryset
 
-    def filter(self, *q, **kwargs) -> "QuerySet":
+    def filter(self, *q: Q, **kwargs: Any) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that includes only rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
         """
         return self._filter_or_exclude(*q, **kwargs)
 
-    def exclude(self, *q, **kwargs) -> "QuerySet":
+    def exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that excludes all rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
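The tightened signature `filter(self, *q: Q, **kwargs: Any)` documents what was already true: positional arguments are `Q` objects, keyword arguments are field lookups. A sketch, again reusing the hypothetical `Person` model:

from clickhouse_orm.query import Q

# Reuses the hypothetical Person model and db from the objects_in sketch.
extreme = Q(height__gt=1.9) | Q(height__lt=1.5)
queryset = Person.objects_in(db).filter(extreme, prewhere=True)  # -> PREWHERE
queryset = queryset.exclude(first_name='Ada')                    # -> WHERE NOT
print(queryset.as_sql())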
@@ -582,16 +595,16 @@ class QuerySet:
             page_size=page_size
         )
 
-    def distinct(self) -> "QuerySet":
+    def distinct(self) -> "QuerySet[MODEL]":
         """
         Adds a DISTINCT clause to the query, meaning that any duplicate rows
         in the results will be omitted.
         """
-        qs = copy(self)
-        qs._distinct = True
-        return qs
+        queryset = self._clone()
+        queryset._distinct = True
+        return queryset
 
-    def final(self) -> "QuerySet":
+    def final(self) -> "QuerySet[MODEL]":
         """
         Adds a FINAL modifier to table, meaning data will be collapsed to final version.
         Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
@@ -604,11 +617,11 @@ class QuerySet:
                 ' and ReplacingMergeTree engines'
             )
 
-        qs = copy(self)
-        qs._final = True
-        return qs
+        queryset = self._clone()
+        queryset._final = True
+        return queryset
 
-    def delete(self) -> "QuerySet":
+    def delete(self) -> "QuerySet[MODEL]":
         """
         Deletes all records matched by this queryset's conditions.
         Note that ClickHouse performs deletions in the background, so they are not immediate.
@@ -619,7 +632,7 @@ class QuerySet:
         self._database.raw(sql)
         return self
 
-    def update(self, **kwargs) -> "QuerySet":
+    def update(self, **kwargs) -> "QuerySet[MODEL]":
         """
         Updates all records matched by this queryset's conditions.
         Keyword arguments specify the field names and expressions to use for the update.
@@ -644,7 +657,7 @@ class QuerySet:
         assert not self._distinct, 'Mutations are not allowed after calling distinct()'
         assert not self._final, 'Mutations are not allowed after calling final()'
 
-    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet":
+    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
         """
         Returns an `AggregateQuerySet` over this query, with `args` serving as
         grouping fields and `kwargs` serving as calculated fields. At least one
@@ -662,7 +675,7 @@ class QuerySet:
         return AggregateQuerySet(self, args, kwargs)
 
 
-class AggregateQuerySet(QuerySet):
+class AggregateQuerySet(QuerySet[MODEL]):
     """
     A queryset used for aggregation.
     """
@@ -699,7 +712,7 @@ class AggregateQuerySet(QuerySet):
         self._limits = base_queryset._limits
         self._distinct = base_queryset._distinct
 
-    def group_by(self, *args) -> "AggregateQuerySet":
+    def group_by(self, *args) -> "AggregateQuerySet[MODEL]":
         """
         This method lets you specify the grouping fields explicitly. The `args` must
         be names of grouping fields or calculated fields that this queryset was
@@ -708,9 +721,9 @@ class AggregateQuerySet(QuerySet):
         for name in args:
             assert name in self._fields or name in self._calculated_fields, \
                 'Cannot group by `%s` since it is not included in the query' % name
-        qs = copy(self)
-        qs._grouping_fields = args
-        return qs
+        queryset = copy(self)
+        queryset._grouping_fields = args
+        return queryset
 
     def only(self, *field_names):
         """
@@ -731,8 +744,20 @@ class AggregateQuerySet(QuerySet):
         return comma_join([str(f) for f in self._fields] + ['%s AS %s' % (v, k) for k, v in
                                                             self._calculated_fields.items()])
 
-    def __iter__(self):
-        return self._database.select(self.as_sql())  # using an ad-hoc model
+    def __iter__(self) -> Iterator[Model]:
+        """
+        Iterates over the results using an ad-hoc model.
+        """
+        for val in self._database.select(self.as_sql()):
+            yield val
 
+    async def __aiter__(self) -> AsyncIterator[Model]:
+        from clickhouse_orm.aio.database import AioDatabase
+
+        assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'"
+        async for val in self._database.select(self.as_sql()):
+            yield val
 
     def count(self) -> Union[int, Coroutine[int]]:
         """
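With `__aiter__` added here, `AggregateQuerySet` supports `async for` just like `QuerySet`. A sketch, reusing the hypothetical `Event` model and the assumptions from the `AioDatabase` sketch above:

import asyncio

from clickhouse_orm.aio.database import AioDatabase

async def main():
    db = AioDatabase('default', db_url='http://localhost:8123/')
    aggregates = Event.objects_in(db).aggregate('name', total='count()')
    async for row in aggregates:   # each row is an ad-hoc model instance
        print(row.name, row.total)
    await db.close()

asyncio.run(main())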
@@ -744,15 +769,15 @@ class AggregateQuerySet(QuerySet):
             return raw
         return int(raw) if raw else 0
 
-    def with_totals(self) -> "AggregateQuerySet":
+    def with_totals(self) -> "AggregateQuerySet[MODEL]":
         """
         Adds WITH TOTALS modifier ot GROUP BY, making query return extra row
         with aggregate function calculated across all the rows. More information:
         https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
         """
-        qs = copy(self)
-        qs._grouping_with_totals = True
-        return qs
+        queryset = copy(self)
+        queryset._grouping_with_totals = True
+        return queryset
 
     def _verify_mutation_allowed(self):
         raise AssertionError('Cannot mutate an AggregateQuerySet')
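For completeness, a `with_totals()` usage sketch (same hypothetical `Person` model and `db` as above): ClickHouse appends one extra row holding the aggregates computed across all groups.

# Reuses the hypothetical Person model and db from the objects_in sketch.
queryset = Person.objects_in(db).aggregate('first_name', n='count()').with_totals()
print(queryset.as_sql())   # ... GROUP BY `first_name` WITH TOTALS
for row in queryset:
    print(row.first_name, row.n)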