diff --git a/src/clickhouse_orm/aio/database.py b/src/clickhouse_orm/aio/database.py
index 0b68669..652532e 100644
--- a/src/clickhouse_orm/aio/database.py
+++ b/src/clickhouse_orm/aio/database.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 import datetime
 import logging
 from io import BytesIO
@@ -42,7 +43,12 @@ class AioDatabase(Database):
     async def close(self):
         await self.request_session.aclose()
 
-    async def _send(self, data, settings=None, stream=False):
+    async def _send(
+        self,
+        data: str | bytes,
+        settings: dict | None = None,
+        stream: bool = False
+    ):
         r = await super()._send(data, settings, stream)
         if r.status_code != 200:
             raise ServerError(r.text)
@@ -50,7 +56,7 @@ class AioDatabase(Database):
 
     async def count(
         self,
-        model_class,
+        model_class: type[MODEL],
         conditions=None
     ) -> int:
         """
diff --git a/src/clickhouse_orm/database.py b/src/clickhouse_orm/database.py
index 8e058a1..4bfca8d 100644
--- a/src/clickhouse_orm/database.py
+++ b/src/clickhouse_orm/database.py
@@ -430,14 +430,19 @@ class Database:
         query = self._substitute(query, MigrationHistory)
         return set(obj.module_name for obj in self.select(query))
 
-    def _send(self, data, settings=None, stream=False):
+    def _send(
+        self,
+        data: str | bytes,
+        settings: dict | None = None,
+        stream: bool = False
+    ):
         if isinstance(data, str):
             data = data.encode('utf-8')
         if self.log_statements:
             logger.info(data)
         params = self._build_params(settings)
         request = self.request_session.build_request(
-            method='POST', url=self.db_url, data=data, params=params
+            method='POST', url=self.db_url, content=data, params=params
         )
         r = self.request_session.send(request, stream=stream)
         if isinstance(r, httpx.Response) and r.status_code != 200:
diff --git a/src/clickhouse_orm/models.py b/src/clickhouse_orm/models.py
index b64cd82..d73c646 100644
--- a/src/clickhouse_orm/models.py
+++ b/src/clickhouse_orm/models.py
@@ -313,7 +313,9 @@ class Model(metaclass=ModelBase):
         if field:
             setattr(self, name, value)
         else:
-            raise AttributeError('%s does not have a field called %s' % (self.__class__.__name__, name))
+            raise AttributeError(
+                '%s does not have a field called %s' % (self.__class__.__name__, name)
+            )
 
     def __setattr__(self, name, value):
         """
@@ -474,7 +476,7 @@ class Model(metaclass=ModelBase):
         return {name: data[name] for name in fields}
 
     @classmethod
-    def objects_in(cls, database: Database) -> QuerySet:
+    def objects_in(cls: type[MODEL], database: Database) -> QuerySet[MODEL]:
         """
         Returns a `QuerySet` for selecting instances of this model class.
         """
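The typed `objects_in` / `QuerySet[MODEL]` pair above is what lets a type checker infer row types end to end. A minimal sketch of the payoff (not part of the patch; the `CPUStats` model and the database name are hypothetical):

    from clickhouse_orm.database import Database
    from clickhouse_orm.models import Model
    from clickhouse_orm.fields import StringField, Float32Field
    from clickhouse_orm.engines import Memory

    class CPUStats(Model):
        # hypothetical table, for illustration only
        cpu_id = StringField()
        cpu_percent = Float32Field()
        engine = Memory()

    db = Database('demo')               # defaults to http://localhost:8123
    queryset = CPUStats.objects_in(db)  # inferred as QuerySet[CPUStats]
    for row in queryset:                # row is inferred as CPUStats
        print(row.cpu_id, row.cpu_percent)
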
""" - def __init__(self, model_cls: type[Model], database: Database): + def __init__(self, model_cls: type[MODEL], database: Database): """ Initializer. It is possible to create a queryset like this, but the standard way is to use `MyModel.objects_in(database)`. @@ -333,33 +345,34 @@ class QuerySet: self._distinct = False self._final = False - def __deepcopy__(self, memodict={}): - obj = type(self)(self._model_cls, self._database) - obj._order_by = deepcopy(self._order_by) - obj._where_q = deepcopy(self._where_q) - obj._prewhere_q = deepcopy(self._prewhere_q) - obj._grouping_fields = deepcopy(self._grouping_fields) - obj._grouping_with_totals = deepcopy(self._grouping_with_totals) - obj._fields = deepcopy(self._fields) - obj._limits = deepcopy(self._limits) - obj._limit_by = deepcopy(self._limit_by) - obj._limit_by_fields = deepcopy(self._limit_by_fields) - obj._distinct = deepcopy(self._distinct) - obj._final = deepcopy(self._final) - return obj + def _clone(self) -> "QuerySet[MODEL]": + queryset = type(self)(self._model_cls, self._database) + queryset._order_by = copy(self._order_by) + queryset._where_q = copy(self._where_q) + queryset._prewhere_q = copy(self._prewhere_q) + queryset._grouping_fields = copy(self._grouping_fields) + queryset._grouping_with_totals = self._grouping_with_totals + queryset._fields = self._fields + queryset._limits = copy(self._limits) + queryset._limit_by = copy(self._limit_by) + queryset._limit_by_fields = copy(self._limit_by_fields) + queryset._distinct = self._distinct + queryset._final = self._final + return queryset - def __iter__(self): + def __iter__(self) -> Iterator[MODEL]: """ Iterates over the model instances matching this queryset """ - return self._database.select(self.as_sql(), self._model_cls) + for val in self._database.select(self.as_sql(), self._model_cls): + yield val - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[MODEL]: from clickhouse_orm.aio.database import AioDatabase assert isinstance(self._database, AioDatabase), "only AioDatabase support 'async for'" - async for r in self._database.select(self.as_sql(), self._model_cls): - yield r + async for val in self._database.select(self.as_sql(), self._model_cls): + yield val def __bool__(self): """ @@ -378,27 +391,27 @@ class QuerySet: ... @overload - def __getitem__(self, s: slice) -> "QuerySet": + def __getitem__(self, s: slice) -> "QuerySet[MODEL]": ... def __getitem__(self, s): if isinstance(s, int): # Single index assert s >= 0, 'negative indexes are not supported' - qs = copy(self) - qs._limits = (s, 1) - return next(iter(qs)) + queryset = self._clone() + queryset._limits = (s, 1) + return next(iter(queryset)) # Slice assert s.step in (None, 1), 'step is not supported in slices' start = s.start or 0 stop = s.stop or 2 ** 63 - 1 assert start >= 0 and stop >= 0, 'negative indexes are not supported' assert start <= stop, 'start of slice cannot be smaller than its end' - qs = copy(self) - qs._limits = (start, stop - start) - return qs + queryset = self._clone() + queryset._limits = (start, stop - start) + return queryset - def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet": + def limit_by(self, offset_limit, *fields_or_expr) -> "QuerySet[MODEL]": """ Adds a LIMIT BY clause to the query. - `offset_limit`: either an integer specifying the limit, or a tuple of integers (offset, limit). 
@@ -410,10 +423,10 @@ class QuerySet:
             offset = offset_limit[0]
             limit = offset_limit[1]
         assert offset >= 0 and limit >= 0, 'negative limits are not supported'
-        qs = copy(self)
-        qs._limit_by = (offset, limit)
-        qs._limit_by_fields = fields_or_expr
-        return qs
+        queryset = self._clone()
+        queryset._limit_by = (offset, limit)
+        queryset._limit_by_fields = fields_or_expr
+        return queryset
 
     def select_fields_as_sql(self) -> str:
         """
@@ -490,31 +503,31 @@
         conditions = (self._where_q & self._prewhere_q).to_sql(self._model_cls)
         return self._database.count(self._model_cls, conditions)
 
-    def order_by(self, *field_names) -> "QuerySet":
+    def order_by(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset with the ordering changed.
         """
-        qs = copy(self)
-        qs._order_by = field_names
-        return qs
+        queryset = self._clone()
+        queryset._order_by = field_names
+        return queryset
 
-    def only(self, *field_names) -> "QuerySet":
+    def only(self, *field_names) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset limited to the specified field names.
         Useful when there are large fields that are not needed, or for creating
         a subquery to use with an IN operator.
         """
-        qs = copy(self)
-        qs._fields = field_names
-        return qs
+        queryset = self._clone()
+        queryset._fields = field_names
+        return queryset
 
-    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet":
+    def _filter_or_exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         from clickhouse_orm.funcs import F
 
         inverse = kwargs.pop('_inverse', False)
         prewhere = kwargs.pop('prewhere', False)
-        qs = copy(self)
+        queryset = self._clone()
 
         condition = Q()
         for arg in q:
@@ -533,20 +546,20 @@
             condition = copy(self._prewhere_q if prewhere else self._where_q) & condition
 
         if prewhere:
-            qs._prewhere_q = condition
+            queryset._prewhere_q = condition
         else:
-            qs._where_q = condition
+            queryset._where_q = condition
 
-        return qs
+        return queryset
 
-    def filter(self, *q, **kwargs) -> "QuerySet":
+    def filter(self, *q: Q, **kwargs: Any) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that includes only rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
         """
         return self._filter_or_exclude(*q, **kwargs)
 
-    def exclude(self, *q, **kwargs) -> "QuerySet":
+    def exclude(self, *q, **kwargs) -> "QuerySet[MODEL]":
         """
         Returns a copy of this queryset that excludes all rows matching the conditions.
         Pass `prewhere=True` to apply the conditions as PREWHERE instead of WHERE.
@@ -582,16 +595,16 @@
             page_size=page_size
         )
 
-    def distinct(self) -> "QuerySet":
+    def distinct(self) -> "QuerySet[MODEL]":
         """
         Adds a DISTINCT clause to the query, meaning that any duplicate rows in the results will be omitted.
         """
-        qs = copy(self)
-        qs._distinct = True
-        return qs
+        queryset = self._clone()
+        queryset._distinct = True
+        return queryset
 
-    def final(self) -> "QuerySet":
+    def final(self) -> "QuerySet[MODEL]":
         """
         Adds a FINAL modifier to table, meaning data will be collapsed to final version.
         Can be used with the `CollapsingMergeTree` and `ReplacingMergeTree` engines only.
         """
@@ -604,11 +617,11 @@
                 ' and ReplacingMergeTree engines'
             )
 
-        qs = copy(self)
-        qs._final = True
-        return qs
+        queryset = self._clone()
+        queryset._final = True
+        return queryset
 
-    def delete(self) -> "QuerySet":
+    def delete(self) -> "QuerySet[MODEL]":
         """
         Deletes all records matched by this queryset's conditions.
         Note that ClickHouse performs deletions in the background, so they are not immediate.
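Because every chaining method now goes through `_clone()` rather than `copy(self)`, the per-clause containers are copied explicitly and a derived queryset cannot mutate its parent. A short sketch (not part of the patch; reusing `db` and the hypothetical `CPUStats` from above):

    base = CPUStats.objects_in(db).filter(cpu_percent__gt=50)
    busiest = base.order_by('-cpu_percent')[:10]  # LIMIT applies to `busiest` only
    ids = base.only('cpu_id').distinct()          # `base` is left untouched
    print(base.as_sql())                          # still the plain filtered query
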
@@ -619,7 +632,7 @@
         self._database.raw(sql)
         return self
 
-    def update(self, **kwargs) -> "QuerySet":
+    def update(self, **kwargs) -> "QuerySet[MODEL]":
         """
         Updates all records matched by this queryset's conditions.
         Keyword arguments specify the field names and expressions to use for the update.
@@ -644,7 +657,7 @@
         assert not self._distinct, 'Mutations are not allowed after calling distinct()'
         assert not self._final, 'Mutations are not allowed after calling final()'
 
-    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet":
+    def aggregate(self, *args, **kwargs) -> "AggregateQuerySet[MODEL]":
         """
         Returns an `AggregateQuerySet` over this query, with `args` serving as
         grouping fields and `kwargs` serving as calculated fields. At least one
@@ -662,7 +675,7 @@
         return AggregateQuerySet(self, args, kwargs)
 
 
-class AggregateQuerySet(QuerySet):
+class AggregateQuerySet(QuerySet[MODEL]):
     """
     A queryset used for aggregation.
     """
@@ -699,7 +712,7 @@
         self._limits = base_queryset._limits
         self._distinct = base_queryset._distinct
 
-    def group_by(self, *args) -> "AggregateQuerySet":
+    def group_by(self, *args) -> "AggregateQuerySet[MODEL]":
         """
         This method lets you specify the grouping fields explicitly. The `args` must
         be names of grouping fields or calculated fields that this queryset was
@@ -708,9 +721,9 @@
         for name in args:
             assert name in self._fields or name in self._calculated_fields, \
                 'Cannot group by `%s` since it is not included in the query' % name
-        qs = copy(self)
-        qs._grouping_fields = args
-        return qs
+        queryset = copy(self)
+        queryset._grouping_fields = args
+        return queryset
 
     def only(self, *field_names):
         """
@@ -731,8 +744,19 @@
         return comma_join([str(f) for f in self._fields] +
                           ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()])
 
-    def __iter__(self):
-        return self._database.select(self.as_sql())  # using an ad-hoc model
+    def __iter__(self) -> Iterator[Model]:
+        """
+        Iterates over the results using an ad-hoc model.
+        """
+        for val in self._database.select(self.as_sql()):
+            yield val
+
+    async def __aiter__(self) -> AsyncIterator[Model]:
+        from clickhouse_orm.aio.database import AioDatabase
+
+        assert isinstance(self._database, AioDatabase), "only AioDatabase supports 'async for'"
+        async for val in self._database.select(self.as_sql()):
+            yield val
 
     def count(self) -> Union[int, Coroutine[int]]:
         """
@@ -744,15 +769,15 @@
             return raw
         return int(raw) if raw else 0
 
-    def with_totals(self) -> "AggregateQuerySet":
+    def with_totals(self) -> "AggregateQuerySet[MODEL]":
         """
         Adds WITH TOTALS modifier to GROUP BY, making query return extra row
         with aggregate function calculated across all the rows. More information:
         https://clickhouse.tech/docs/en/query_language/select/#with-totals-modifier
         """
-        qs = copy(self)
-        qs._grouping_with_totals = True
-        return qs
+        queryset = copy(self)
+        queryset._grouping_with_totals = True
+        return queryset
 
     def _verify_mutation_allowed(self):
         raise AssertionError('Cannot mutate an AggregateQuerySet')
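For completeness, the aggregation path with the new generics, sketched against the same hypothetical model (not part of the patch; `F.avg` comes from `clickhouse_orm.funcs`, which the patch already imports elsewhere):

    from clickhouse_orm.funcs import F

    queryset = (
        CPUStats.objects_in(db)
        .aggregate('cpu_id', average=F.avg(CPUStats.cpu_percent))
        .with_totals()
    )
    for row in queryset:
        # rows are ad-hoc model instances carrying `cpu_id` and `average`
        print(row.cpu_id, row.average)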