diff --git a/src/infi/clickhouse_orm/query.py b/src/infi/clickhouse_orm/query.py index 4a77226..b6a8baf 100644 --- a/src/infi/clickhouse_orm/query.py +++ b/src/infi/clickhouse_orm/query.py @@ -252,6 +252,7 @@ class QuerySet(object): self._order_by = [] self._where_q = Q() self._prewhere_q = Q() + self._grouping_fields = [] self._fields = model_cls.fields().keys() self._limits = None self._distinct = False @@ -292,19 +293,34 @@ class QuerySet(object): qs._limits = (start, stop - start) return qs + def select_fields_as_sql(self): + return comma_join('`%s`' % field for field in self._fields) if self._fields else '*' + def as_sql(self): """ Returns the whole query as a SQL string. """ distinct = 'DISTINCT ' if self._distinct else '' - fields = '*' - if self._fields: - fields = comma_join('`%s`' % field for field in self._fields) - ordering = '\nORDER BY ' + self.order_by_as_sql() if self._order_by else '' - limit = '\nLIMIT %d, %d' % self._limits if self._limits else '' - params = (distinct, fields, self._model_cls.table_name(), - self.conditions_as_sql(), ordering, limit) - return u'SELECT %s%s\nFROM `%s`\nWHERE %s%s%s' % params + + params = (distinct, self.select_fields_as_sql(), self._model_cls.table_name()) + sql = u'SELECT %s%s\nFROM `%s`\n' % params + + if self._prewhere_q: + sql += '\nPREWHERE ' + self.conditions_as_sql(self._prewhere_q) + + if self._where_q: + sql += '\nWHERE ' + self.conditions_as_sql(self._where_q) + + if self._grouping_fields: + sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) + + if self._order_by: + sql += '\nORDER BY ' + self.order_by_as_sql() + + if self._limits: + sql += '\nLIMIT %d, %d' % self._limits + + return def order_by_as_sql(self): """ @@ -333,8 +349,10 @@ class QuerySet(object): sql = u'SELECT count() FROM (%s)' % self.as_sql() raw = self._database.raw(sql) return int(raw) if raw else 0 + # Simple case - return self._database.count(self._model_cls, self.conditions_as_sql()) + conditions = self.conditions_as_sql(self._where_q & self._prewhere_q) + return self._database.count(self._model_cls, conditions) def order_by(self, *field_names): """ @@ -498,36 +516,8 @@ class AggregateQuerySet(QuerySet): """ raise NotImplementedError('Cannot re-aggregate an AggregateQuerySet') - def as_sql(self): - """ - Returns the whole query as a SQL string. - """ - distinct = 'DISTINCT ' if self._distinct else '' - fields = comma_join(list(self._fields) + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) - - params = dict( - distinct=distinct, - fields=fields, - table=self._model_cls.table_name(), - ) - sql = u'SELECT %(distinct)s%(fields)s\nFROM `%(table)s`' % params - - if self._prewhere_q: - sql += '\nPREWHERE ' + self.conditions_as_sql(self._prewhere_q) - - if self._where_q: - sql += '\nWHERE ' + self.conditions_as_sql(self._where_q) - - if self._grouping_fields: - sql += '\nGROUP BY %s' % comma_join('`%s`' % field for field in self._grouping_fields) - - if self._order_by: - sql += '\nORDER BY ' + self.order_by_as_sql() - - if self._limits: - sql += '\nLIMIT %d, %d' % self._limits - - return sql + def select_fields_as_sql(self): + return comma_join(list(self._fields) + ['%s AS %s' % (v, k) for k, v in self._calculated_fields.items()]) def __iter__(self): return self._database.select(self.as_sql()) # using an ad-hoc model diff --git a/tests/test_querysets.py b/tests/test_querysets.py index a4fef14..9dd1ca4 100644 --- a/tests/test_querysets.py +++ b/tests/test_querysets.py @@ -369,13 +369,13 @@ class AggregateTestCase(TestCaseWithData): the__next__number = Int32Field() engine = Memory() qs = Mdl.objects_in(self.database).filter(the__number=1) - self.assertEqual(qs.conditions_as_sql(), 'the__number = 1') + self.assertEqual(qs.conditions_as_sql(qs._where_q), 'the__number = 1') qs = Mdl.objects_in(self.database).filter(the__number__gt=1) - self.assertEqual(qs.conditions_as_sql(), 'the__number > 1') + self.assertEqual(qs.conditions_as_sql(qs._where_q), 'the__number > 1') qs = Mdl.objects_in(self.database).filter(the__next__number=1) - self.assertEqual(qs.conditions_as_sql(), 'the__next__number = 1') + self.assertEqual(qs.conditions_as_sql(qs._where_q), 'the__next__number = 1') qs = Mdl.objects_in(self.database).filter(the__next__number__gt=1) - self.assertEqual(qs.conditions_as_sql(), 'the__next__number > 1') + self.assertEqual(qs.conditions_as_sql(qs._where_q), 'the__next__number > 1') Color = Enum('Color', u'red blue green yellow brown white black')