From 639867bb32586d0b439547f29d16948adbc71c5c Mon Sep 17 00:00:00 2001 From: Itai Shirav Date: Fri, 11 Aug 2017 17:26:46 +0300 Subject: [PATCH] - Added `QuerySet.paginate()` - Support for basic aggregation in querysets --- CHANGELOG.md | 5 + docs/class_reference.md | 138 ++++++++++++++++++++- docs/models_and_databases.md | 3 +- docs/querysets.md | 58 ++++++++- docs/toc.md | 3 + scripts/generate_ref.py | 4 +- src/infi/clickhouse_orm/engines.py | 14 ++- src/infi/clickhouse_orm/fields.py | 4 +- src/infi/clickhouse_orm/migrations.py | 2 +- src/infi/clickhouse_orm/query.py | 149 +++++++++++++++++++++-- src/infi/clickhouse_orm/system_models.py | 3 +- src/infi/clickhouse_orm/utils.py | 13 +- tests/test_querysets.py | 148 ++++++++++++++++++++++ 13 files changed, 512 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f875d12..a60ecda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Change Log ========== +Unreleased +---------- +- Added `QuerySet.paginate()` +- Support for basic aggregation in querysets + v0.9.4 ------ - Migrations: when creating a table for a `BufferModel`, create the underlying table too if necessary diff --git a/docs/class_reference.md b/docs/class_reference.md index d526e20..7e4bc74 100644 --- a/docs/class_reference.md +++ b/docs/class_reference.md @@ -550,6 +550,23 @@ Initializer. It is possible to create a queryset like this, but the standard way is to use `MyModel.objects_in(database)`. +#### aggregate(*args, **kwargs) + + +Returns an `AggregateQuerySet` over this query, with `args` serving as +grouping fields and `kwargs` serving as calculated fields. At least one +calculated field is required. For example: +``` + Event.objects_in(database).filter(date__gt='2017-08-01').aggregate('event_type', count='count()') +``` +is equivalent to: +``` + SELECT event_type, count() AS count FROM event + WHERE data > '2017-08-01' + GROUP BY event_type +``` + + #### as_sql() @@ -571,19 +588,19 @@ Returns the number of matching model instances. #### exclude(**kwargs) -Returns a new `QuerySet` instance that excludes all rows matching the conditions. +Returns a copy of this queryset that excludes all rows matching the conditions. #### filter(**kwargs) -Returns a new `QuerySet` instance that includes only rows matching the conditions. +Returns a copy of this queryset that includes only rows matching the conditions. #### only(*field_names) -Returns a new `QuerySet` instance limited to the specified field names. +Returns a copy of this queryset limited to the specified field names. Useful when there are large fields that are not needed, or for creating a subquery to use with an IN operator. @@ -591,7 +608,7 @@ or for creating a subquery to use with an IN operator. #### order_by(*field_names) -Returns a new `QuerySet` instance with the ordering changed. +Returns a copy of this queryset with the ordering changed. #### order_by_as_sql() @@ -600,3 +617,116 @@ Returns a new `QuerySet` instance with the ordering changed. Returns the contents of the query's `ORDER BY` clause as a string. +#### paginate(page_num=1, page_size=100) + + +Returns a single page of model instances that match the queryset. +Note that `order_by` should be used first, to ensure a correct +partitioning of records into pages. + +- `page_num`: the page number (1-based), or -1 to get the last page. +- `page_size`: number of records to return per page. + +The result is a namedtuple containing `objects` (list), `number_of_objects`, +`pages_total`, `number` (of the current page), and `page_size`. + + +### AggregateQuerySet + +Extends QuerySet + + +A queryset used for aggregation. + +#### AggregateQuerySet(base_qs, grouping_fields, calculated_fields) + + +Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`. + +The grouping fields should be a list/tuple of field names from the model. For example: +``` + ('event_type', 'event_subtype') +``` +The calculated fields should be a mapping from name to a ClickHouse aggregation function. For example: +``` + {'weekday': 'toDayOfWeek(event_date)', 'number_of_events': 'count()'} +``` +At least one calculated field is required. + + +#### aggregate(*args, **kwargs) + + +This method is not supported on `AggregateQuerySet`. + + +#### as_sql() + + +Returns the whole query as a SQL string. + + +#### conditions_as_sql() + + +Returns the contents of the query's `WHERE` clause as a string. + + +#### count() + + +Returns the number of rows after aggregation. + + +#### exclude(**kwargs) + + +Returns a copy of this queryset that excludes all rows matching the conditions. + + +#### filter(**kwargs) + + +Returns a copy of this queryset that includes only rows matching the conditions. + + +#### group_by(*args) + + +This method lets you specify the grouping fields explicitly. The `args` must +be names of grouping fields or calculated fields that this queryset was +created with. + + +#### only(*field_names) + + +This method is not supported on `AggregateQuerySet`. + + +#### order_by(*field_names) + + +Returns a copy of this queryset with the ordering changed. + + +#### order_by_as_sql() + + +Returns the contents of the query's `ORDER BY` clause as a string. + + +#### paginate(page_num=1, page_size=100) + + +Returns a single page of model instances that match the queryset. +Note that `order_by` should be used first, to ensure a correct +partitioning of records into pages. + +- `page_num`: the page number (1-based), or -1 to get the last page. +- `page_size`: number of records to return per page. + +The result is a namedtuple containing `objects` (list), `number_of_objects`, +`pages_total`, `number` (of the current page), and `page_size`. + + diff --git a/docs/models_and_databases.md b/docs/models_and_databases.md index ffeb54e..2b84f99 100644 --- a/docs/models_and_databases.md +++ b/docs/models_and_databases.md @@ -158,8 +158,7 @@ The `paginate` method returns a `namedtuple` containing the following fields: - `objects` - the list of objects in this page - `number_of_objects` - total number of objects in all pages - `pages_total` - total number of pages -- `number` - the page number, starting from 1; the special value -1 - may be used to retrieve the last page +- `number` - the page number, starting from 1; the special value -1 may be used to retrieve the last page - `page_size` - the number of objects per page You can optionally pass conditions to the query: diff --git a/docs/querysets.md b/docs/querysets.md index bb10332..2bbefd9 100644 --- a/docs/querysets.md +++ b/docs/querysets.md @@ -99,11 +99,10 @@ When some of the model fields aren't needed, it is more efficient to omit them f qs = Person.objects_in(database).only('first_name', 'birthday') - Slicing ------- -It is possible to get a specific item from the queryset by index. +It is possible to get a specific item from the queryset by index: qs = Person.objects_in(database).order_by('last_name', 'first_name') first = qs[0] @@ -119,6 +118,61 @@ You should use `order_by` to ensure a consistent ordering of the results. Trying to use negative indexes or a slice with a step (e.g. [0:100:2]) is not supported and will raise an `AssertionError`. +Pagination +---------- + +Similar to `Database.paginate`, you can go over the queryset results one page at a time: + + >>> qs = Person.objects_in(database).order_by('last_name', 'first_name') + >>> page = qs.paginate(page_num=1, page_size=10) + >>> print page.number_of_objects + 2507 + >>> print page.pages_total + 251 + >>> for person in page.objects: + >>> # do something + +The `paginate` method returns a `namedtuple` containing the following fields: + +- `objects` - the list of objects in this page +- `number_of_objects` - total number of objects in all pages +- `pages_total` - total number of pages +- `number` - the page number, starting from 1; the special value -1 may be used to retrieve the last page +- `page_size` - the number of objects per page + +Note that you should use `QuerySet.order_by` so that the ordering is unique, otherwise there might be inconsistencies in the pagination (such as an instance that appears on two different pages). + +Aggregation +----------- + +It is possible to use aggregation functions over querysets using the `aggregate` method. The simplest form of aggregation works over all rows in the queryset: + + >>> qs = Person.objects_in(database).aggregate(average_height='avg(height)') + >>> print qs.count() + 1 + >>> for row in qs: print row.average_height + 1.71 + +The returned row or rows are no longer instances of the base model (`Person` in this example), but rather instances of an ad-hoc model that includes only the fields specified in the call to `aggregate`. + +You can pass names of fields from the model that will be included in the query. By default, they will be also used in the GROUP BY clause. For example to count the number of people per last name you could do this: + + qs = Person.objects_in(database).aggregate('last_name', num='count()') + +The underlying SQL query would be something like this: + + SELECT last_name, count() AS num FROM person GROUP BY last_name + +If you would like to control the GROUP BY explicitly, use the `group_by` method. This is useful when you need to group by a calculated field, instead of a field that exists in the model. For example, to count the number of people born on each weekday: + + qs = Person.objects_in(database).aggregate(weekday='toDayOfWeek(birthday)', num='count()').group_by('weekday') + +This queryset is translated to: + + SELECT toDayOfWeek(birthday) AS weekday, count() AS num FROM person GROUP BY weekday + +After calling `aggregate` you can still use most of the regular queryset methods, such as `count`, `order_by` and `paginate`. It is not possible, however, to call `only` or `aggregate`. It is also not possible to filter the queryset on calculated fields, only on fields that exist in the model. + --- [<< Models and Databases](models_and_databases.md) | [Table of Contents](toc.md) | [Field Types >>](field_types.md) \ No newline at end of file diff --git a/docs/toc.md b/docs/toc.md index da69986..aa5bb3b 100644 --- a/docs/toc.md +++ b/docs/toc.md @@ -21,6 +21,8 @@ * [Ordering](querysets.md#ordering) * [Omitting Fields](querysets.md#omitting-fields) * [Slicing](querysets.md#slicing) + * [Pagination](querysets.md#pagination) + * [Aggregation](querysets.md#aggregation) * [Field Types](field_types.md#field-types) * [DateTimeField and Time Zones](field_types.md#datetimefield-and-time-zones) @@ -88,4 +90,5 @@ * [ReplacingMergeTree](class_reference.md#replacingmergetree) * [infi.clickhouse_orm.query](class_reference.md#infi.clickhouse_orm.query) * [QuerySet](class_reference.md#queryset) + * [AggregateQuerySet](class_reference.md#aggregatequeryset) diff --git a/scripts/generate_ref.py b/scripts/generate_ref.py index 8d11249..c35e881 100644 --- a/scripts/generate_ref.py +++ b/scripts/generate_ref.py @@ -110,7 +110,7 @@ def module_doc(classes, list_methods=True): print '-' * len(mdl) print for cls in classes: - class_doc(cls, list_methods) + class_doc(cls, list_methods) def all_subclasses(cls): @@ -132,4 +132,4 @@ if __name__ == '__main__': module_doc([models.Model, models.BufferModel]) module_doc([fields.Field] + all_subclasses(fields.Field), False) module_doc([engines.Engine] + all_subclasses(engines.Engine), False) - module_doc([query.QuerySet]) + module_doc([query.QuerySet, query.AggregateQuerySet]) diff --git a/src/infi/clickhouse_orm/engines.py b/src/infi/clickhouse_orm/engines.py index 9db37da..b28fedd 100644 --- a/src/infi/clickhouse_orm/engines.py +++ b/src/infi/clickhouse_orm/engines.py @@ -1,8 +1,10 @@ +from .utils import comma_join + class Engine(object): def create_table_sql(self): - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover class TinyLog(Engine): @@ -41,7 +43,7 @@ class MergeTree(Engine): if self.replica_name: name = 'Replicated' + name params = self._build_sql_params() - return '%s(%s)' % (name, ', '.join(params)) + return '%s(%s)' % (name, comma_join(params)) def _build_sql_params(self): params = [] @@ -50,7 +52,7 @@ class MergeTree(Engine): params.append(self.date_col) if self.sampling_expr: params.append(self.sampling_expr) - params.append('(%s)' % ', '.join(self.key_cols)) + params.append('(%s)' % comma_join(self.key_cols)) params.append(str(self.index_granularity)) return params @@ -79,7 +81,7 @@ class SummingMergeTree(MergeTree): def _build_sql_params(self): params = super(SummingMergeTree, self)._build_sql_params() if self.summing_cols: - params.append('(%s)' % ', '.join(self.summing_cols)) + params.append('(%s)' % comma_join(self.summing_cols)) return params @@ -103,7 +105,7 @@ class Buffer(Engine): Must be used in conjuction with a `BufferModel`. Read more [here](https://clickhouse.yandex/reference_en.html#Buffer). """ - + #Buffer(database, table, num_layers, min_time, max_time, min_rows, max_rows, min_bytes, max_bytes) def __init__(self, main_model, num_layers=16, min_time=10, max_time=100, min_rows=10000, max_rows=1000000, min_bytes=10000000, max_bytes=100000000): self.main_model = main_model @@ -117,7 +119,7 @@ class Buffer(Engine): def create_table_sql(self, db_name): - # Overriden create_table_sql example: + # Overriden create_table_sql example: #sql = 'ENGINE = Buffer(merge, hits, 16, 10, 100, 10000, 1000000, 10000000, 100000000)' sql = 'ENGINE = Buffer(`%s`, `%s`, %d, %d, %d, %d, %d, %d, %d)' % ( db_name, self.main_model.table_name(), self.num_layers, diff --git a/src/infi/clickhouse_orm/fields.py b/src/infi/clickhouse_orm/fields.py index 988ace7..e6e100a 100644 --- a/src/infi/clickhouse_orm/fields.py +++ b/src/infi/clickhouse_orm/fields.py @@ -4,7 +4,7 @@ import pytz import time from calendar import timegm -from .utils import escape, parse_array +from .utils import escape, parse_array, comma_join class Field(object): @@ -356,7 +356,7 @@ class ArrayField(Field): def to_db_string(self, value, quote=True): array = [self.inner_field.to_db_string(v, quote=True) for v in value] - return '[' + ', '.join(array) + ']' + return '[' + comma_join(array) + ']' def get_sql(self, with_default=True): from .utils import escape diff --git a/src/infi/clickhouse_orm/migrations.py b/src/infi/clickhouse_orm/migrations.py index ebcbacc..1a82a5b 100644 --- a/src/infi/clickhouse_orm/migrations.py +++ b/src/infi/clickhouse_orm/migrations.py @@ -15,7 +15,7 @@ class Operation(object): ''' def apply(self, database): - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover class CreateTable(Operation): diff --git a/src/infi/clickhouse_orm/query.py b/src/infi/clickhouse_orm/query.py index cecee19..217fb68 100644 --- a/src/infi/clickhouse_orm/query.py +++ b/src/infi/clickhouse_orm/query.py @@ -1,12 +1,13 @@ import six import pytz from copy import copy +from math import ceil +from .utils import comma_join # TODO # - and/or between Q objects # - check that field names are valid -# - qs slicing # - operators for arrays: length, has, empty class Operator(object): @@ -19,7 +20,7 @@ class Operator(object): Subclasses should implement this method. It returns an SQL string that applies this operator on the given field and value. """ - raise NotImplementedError + raise NotImplementedError # pragma: no cover class SimpleOperator(Operator): @@ -52,7 +53,7 @@ class InOperator(Operator): elif isinstance(value, six.string_types): pass else: - value = ', '.join([field.to_db_string(field.to_python(v, pytz.utc)) for v in value]) + value = comma_join([field.to_db_string(field.to_python(v, pytz.utc)) for v in value]) return '%s IN (%s)' % (field_name, value) @@ -189,6 +190,7 @@ class QuerySet(object): """ Iterates over the model instances matching this queryset """ + print self.as_sql() return self._database.select(self.as_sql(), self._model_cls) def __bool__(self): @@ -227,7 +229,7 @@ class QuerySet(object): """ fields = '*' if self._fields: - fields = ', '.join('`%s`' % field for field in self._fields) + fields = comma_join('`%s`' % field for field in self._fields) ordering = '\nORDER BY ' + self.order_by_as_sql() if self._order_by else '' limit = '\nLIMIT %d, %d' % self._limits if self._limits else '' params = (fields, self._model_cls.table_name(), @@ -238,7 +240,7 @@ class QuerySet(object): """ Returns the contents of the query's `ORDER BY` clause as a string. """ - return u', '.join([ + return comma_join([ '%s DESC' % field[1:] if field[0] == '-' else field for field in self._order_by ]) @@ -260,7 +262,7 @@ class QuerySet(object): def order_by(self, *field_names): """ - Returns a new `QuerySet` instance with the ordering changed. + Returns a copy of this queryset with the ordering changed. """ qs = copy(self) qs._order_by = field_names @@ -268,7 +270,7 @@ class QuerySet(object): def only(self, *field_names): """ - Returns a new `QuerySet` instance limited to the specified field names. + Returns a copy of this queryset limited to the specified field names. Useful when there are large fields that are not needed, or for creating a subquery to use with an IN operator. """ @@ -278,7 +280,7 @@ class QuerySet(object): def filter(self, **kwargs): """ - Returns a new `QuerySet` instance that includes only rows matching the conditions. + Returns a copy of this queryset that includes only rows matching the conditions. """ qs = copy(self) qs._q = list(self._q) + [Q(**kwargs)] @@ -286,8 +288,137 @@ class QuerySet(object): def exclude(self, **kwargs): """ - Returns a new `QuerySet` instance that excludes all rows matching the conditions. + Returns a copy of this queryset that excludes all rows matching the conditions. """ qs = copy(self) qs._q = list(self._q) + [~Q(**kwargs)] return qs + + def paginate(self, page_num=1, page_size=100): + ''' + Returns a single page of model instances that match the queryset. + Note that `order_by` should be used first, to ensure a correct + partitioning of records into pages. + + - `page_num`: the page number (1-based), or -1 to get the last page. + - `page_size`: number of records to return per page. + + The result is a namedtuple containing `objects` (list), `number_of_objects`, + `pages_total`, `number` (of the current page), and `page_size`. + ''' + from .database import Page + count = self.count() + pages_total = int(ceil(count / float(page_size))) + if page_num == -1: + page_num = pages_total + elif page_num < 1: + raise ValueError('Invalid page number: %d' % page_num) + offset = (page_num - 1) * page_size + return Page( + objects=list(self[offset : offset + page_size]), + number_of_objects=count, + pages_total=pages_total, + number=page_num, + page_size=page_size + ) + + def aggregate(self, *args, **kwargs): + ''' + Returns an `AggregateQuerySet` over this query, with `args` serving as + grouping fields and `kwargs` serving as calculated fields. At least one + calculated field is required. For example: + ``` + Event.objects_in(database).filter(date__gt='2017-08-01').aggregate('event_type', count='count()') + ``` + is equivalent to: + ``` + SELECT event_type, count() AS count FROM event + WHERE data > '2017-08-01' + GROUP BY event_type + ``` + ''' + return AggregateQuerySet(self, args, kwargs) + + +class AggregateQuerySet(QuerySet): + """ + A queryset used for aggregation. + """ + + def __init__(self, base_qs, grouping_fields, calculated_fields): + """ + Initializer. Normally you should not call this but rather use `QuerySet.aggregate()`. + + The grouping fields should be a list/tuple of field names from the model. For example: + ``` + ('event_type', 'event_subtype') + ``` + The calculated fields should be a mapping from name to a ClickHouse aggregation function. For example: + ``` + {'weekday': 'toDayOfWeek(event_date)', 'number_of_events': 'count()'} + ``` + At least one calculated field is required. + """ + super(AggregateQuerySet, self).__init__(base_qs._model_cls, base_qs._database) + assert calculated_fields, 'No calculated fields specified for aggregation' + self._fields = grouping_fields + self._grouping_fields = grouping_fields + self._calculated_fields = calculated_fields + self._order_by = list(base_qs._order_by) + self._q = list(base_qs._q) + self._limits = base_qs._limits + + def group_by(self, *args): + """ + This method lets you specify the grouping fields explicitly. The `args` must + be names of grouping fields or calculated fields that this queryset was + created with. + """ + for name in args: + assert name in self._fields or name in self._calculated_fields, \ + 'Cannot group by `%s` since it is not included in the query' % name + qs = copy(self) + qs._grouping_fields = args + return qs + + def only(self, *field_names): + """ + This method is not supported on `AggregateQuerySet`. + """ + raise NotImplementedError('Cannot use "only" with AggregateQuerySet') + + def aggregate(self, *args, **kwargs): + """ + This method is not supported on `AggregateQuerySet`. + """ + raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') + + def as_sql(self): + """ + Returns the whole query as a SQL string. + """ + grouping = comma_join('`%s`' % field for field in self._grouping_fields) + fields = comma_join(list(self._fields) + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) + params = dict( + grouping=grouping or "''", + fields=fields, + table=self._model_cls.table_name(), + conds=self.conditions_as_sql() + ) + sql = u'SELECT %(fields)s\nFROM `%(table)s`\nWHERE %(conds)s\nGROUP BY %(grouping)s' % params + if self._order_by: + sql += '\nORDER BY ' + self.order_by_as_sql() + if self._limits: + sql += '\nLIMIT %d, %d' % self._limits + return sql + + def __iter__(self): + return self._database.select(self.as_sql()) # using an ad-hoc model + + def count(self): + """ + Returns the number of rows after aggregation. + """ + sql = u'SELECT count() FROM (%s)' % self.as_sql() + raw = self._database.raw(sql) + return int(raw) if raw else 0 diff --git a/src/infi/clickhouse_orm/system_models.py b/src/infi/clickhouse_orm/system_models.py index c151302..b1e1bb7 100644 --- a/src/infi/clickhouse_orm/system_models.py +++ b/src/infi/clickhouse_orm/system_models.py @@ -7,6 +7,7 @@ from six import string_types from .database import Database from .fields import * from .models import Model +from .utils import comma_join class SystemPart(Model): @@ -61,7 +62,7 @@ class SystemPart(Model): :return: Operation execution result """ operation = operation.upper() - assert operation in self.OPERATIONS, "operation must be in [%s]" % ', '.join(self.OPERATIONS) + assert operation in self.OPERATIONS, "operation must be in [%s]" % comma_join(self.OPERATIONS) sql = "ALTER TABLE `%s`.`%s` %s PARTITION '%s'" % (self._database.db_name, self.table, operation, self.partition) if from_part is not None: sql += " FROM %s" % from_part diff --git a/src/infi/clickhouse_orm/utils.py b/src/infi/clickhouse_orm/utils.py index 83d11e0..5a8a17a 100644 --- a/src/infi/clickhouse_orm/utils.py +++ b/src/infi/clickhouse_orm/utils.py @@ -21,7 +21,7 @@ SPECIAL_CHARS_REGEX = re.compile("[" + ''.join(SPECIAL_CHARS.values()) + "]") def escape(value, quote=True): ''' If the value is a string, escapes any special characters and optionally - surrounds it with single quotes. If the value is not a string (e.g. a number), + surrounds it with single quotes. If the value is not a string (e.g. a number), converts it to one. ''' def escape_one(match): @@ -56,7 +56,7 @@ def parse_array(array_string): if len(array_string) < 2 or array_string[0] != '[' or array_string[-1] != ']': raise ValueError('Invalid array string: "%s"' % array_string) # Drop opening brace - array_string = array_string[1:] + array_string = array_string[1:] # Go over the string, lopping off each value at the beginning until nothing is left values = [] while True: @@ -65,7 +65,7 @@ def parse_array(array_string): return values elif array_string[0] in ', ': # In between values - array_string = array_string[1:] + array_string = array_string[1:] elif array_string[0] == "'": # Start of quoted value, find its end match = re.search(r"[^\\]'", array_string) @@ -90,3 +90,10 @@ def import_submodules(package_name): name: importlib.import_module(package_name + '.' + name) for _, name, _ in pkgutil.iter_modules(package.__path__) } + + +def comma_join(items): + """ + Joins an iterable of strings with commas. + """ + return ', '.join(items) diff --git a/tests/test_querysets.py b/tests/test_querysets.py index c26e3c4..ded1d84 100644 --- a/tests/test_querysets.py +++ b/tests/test_querysets.py @@ -162,6 +162,154 @@ class QuerySetTestCase(TestCaseWithData): with self.assertRaises(AssertionError): qs[50:1] + def test_pagination(self): + qs = Person.objects_in(self.database).order_by('first_name', 'last_name') + # Try different page sizes + for page_size in (1, 2, 7, 10, 30, 100, 150): + # Iterate over pages and collect all intances + page_num = 1 + instances = set() + while True: + page = qs.paginate(page_num, page_size) + self.assertEquals(page.number_of_objects, len(data)) + self.assertGreater(page.pages_total, 0) + [instances.add(obj.to_tsv()) for obj in page.objects] + if page.pages_total == page_num: + break + page_num += 1 + # Verify that all instances were returned + self.assertEquals(len(instances), len(data)) + + def test_pagination_last_page(self): + qs = Person.objects_in(self.database).order_by('first_name', 'last_name') + # Try different page sizes + for page_size in (1, 2, 7, 10, 30, 100, 150): + # Ask for the last page in two different ways and verify equality + page_a = qs.paginate(-1, page_size) + page_b = qs.paginate(page_a.pages_total, page_size) + self.assertEquals(page_a[1:], page_b[1:]) + self.assertEquals([obj.to_tsv() for obj in page_a.objects], + [obj.to_tsv() for obj in page_b.objects]) + + def test_pagination_invalid_page(self): + qs = Person.objects_in(self.database).order_by('first_name', 'last_name') + for page_num in (0, -2, -100): + with self.assertRaises(ValueError): + qs.paginate(page_num, 100) + + def test_pagination_with_conditions(self): + qs = Person.objects_in(self.database).order_by('first_name', 'last_name').filter(first_name__lt='Ava') + page = qs.paginate(1, 100) + self.assertEquals(page.number_of_objects, 10) + + +class AggregateTestCase(TestCaseWithData): + + def setUp(self): + super(AggregateTestCase, self).setUp() + self.database.insert(self._sample_data()) + + def test_aggregate_no_grouping(self): + qs = Person.objects_in(self.database).aggregate(average_height='avg(height)', count='count()') + print qs.as_sql() + self.assertEquals(qs.count(), 1) + for row in qs: + self.assertAlmostEqual(row.average_height, 1.6923, places=4) + self.assertEquals(row.count, 100) + + def test_aggregate_with_filter(self): + # When filter comes before aggregate + qs = Person.objects_in(self.database).filter(first_name='Warren').aggregate(average_height='avg(height)', count='count()') + print qs.as_sql() + self.assertEquals(qs.count(), 1) + for row in qs: + self.assertAlmostEqual(row.average_height, 1.675, places=4) + self.assertEquals(row.count, 2) + # When filter comes after aggregate + qs = Person.objects_in(self.database).aggregate(average_height='avg(height)', count='count()').filter(first_name='Warren') + print qs.as_sql() + self.assertEquals(qs.count(), 1) + for row in qs: + self.assertAlmostEqual(row.average_height, 1.675, places=4) + self.assertEquals(row.count, 2) + + def test_aggregate_with_implicit_grouping(self): + qs = Person.objects_in(self.database).aggregate('first_name', average_height='avg(height)', count='count()') + print qs.as_sql() + self.assertEquals(qs.count(), 94) + total = 0 + for row in qs: + self.assertTrue(1.5 < row.average_height < 2) + self.assertTrue(0 < row.count < 3) + total += row.count + self.assertEquals(total, 100) + + def test_aggregate_with_explicit_grouping(self): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + print qs.as_sql() + self.assertEquals(qs.count(), 7) + total = 0 + for row in qs: + total += row.count + self.assertEquals(total, 100) + + def test_aggregate_with_order_by(self): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + days = [row.weekday for row in qs.order_by('weekday')] + self.assertEquals(days, range(1, 8)) + + def test_aggregate_with_indexing(self): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + total = 0 + for i in range(7): + total += qs[i].count + self.assertEquals(total, 100) + + def test_aggregate_with_slicing(self): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + total = sum(row.count for row in qs[:3]) + sum(row.count for row in qs[3:]) + self.assertEquals(total, 100) + + def test_aggregate_with_pagination(self): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + total = 0 + page_num = 1 + while True: + page = qs.paginate(page_num, page_size=3) + self.assertEquals(page.number_of_objects, 7) + total += sum(row.count for row in page.objects) + if page.pages_total == page_num: + break + page_num += 1 + self.assertEquals(total, 100) + + def test_aggregate_with_wrong_grouping(self): + with self.assertRaises(AssertionError): + Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('first_name') + + def test_aggregate_with_no_calculated_fields(self): + with self.assertRaises(AssertionError): + Person.objects_in(self.database).aggregate() + + def test_aggregate_with_only(self): + # Cannot put only() after aggregate() + with self.assertRaises(NotImplementedError): + Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').only('weekday') + # When only() comes before aggregate(), it gets overridden + qs = Person.objects_in(self.database).only('last_name').aggregate(average_height='avg(height)', count='count()') + self.assertTrue('last_name' not in qs.as_sql()) + + def test_aggregate_on_aggregate(self): + with self.assertRaises(NotImplementedError): + Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').aggregate(s='sum(height)') + + def test_filter_on_calculated_field(self): + # This is currently not supported, so we expect it to fail + with self.assertRaises(AttributeError): + qs = Person.objects_in(self.database).aggregate(weekday='toDayOfWeek(birthday)', count='count()').group_by('weekday') + qs = qs.filter(weekday=1) + self.assertEquals(qs.count(), 1) + Color = Enum('Color', u'red blue green yellow brown white black')