diff --git a/graphene/contrib/sqlalchemy/converter.py b/graphene/contrib/sqlalchemy/converter.py index 4e609dca..25ed6b35 100644 --- a/graphene/contrib/sqlalchemy/converter.py +++ b/graphene/contrib/sqlalchemy/converter.py @@ -19,16 +19,13 @@ def convert_sqlalchemy_relationship(relationship): def convert_sqlalchemy_column(column): - try: - return convert_sqlalchemy_type(column.type, column) - except Exception: - raise Exception( - "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) + return convert_sqlalchemy_type(getattr(column, 'type', None), column) @singledispatch def convert_sqlalchemy_type(type, column): - raise Exception() + raise Exception( + "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) @convert_sqlalchemy_type.register(types.Date) diff --git a/graphene/contrib/sqlalchemy/fields.py b/graphene/contrib/sqlalchemy/fields.py index c38ffcac..03209c38 100644 --- a/graphene/contrib/sqlalchemy/fields.py +++ b/graphene/contrib/sqlalchemy/fields.py @@ -4,29 +4,25 @@ from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node -from .utils import get_type_for_model +from .utils import get_type_for_model, maybe_query, get_query class SQLAlchemyConnectionField(ConnectionField): def __init__(self, *args, **kwargs): - self.session = kwargs.pop('session', None) return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) @property def model(self): return self.type._meta.model - def get_session(self, args, info): - return self.session - def get_query(self, resolved_query, args, info): - self.get_session(args, info) - return resolved_query + return resolved_query or get_query(self.model, info) def from_list(self, connection_type, resolved, args, info): - qs = self.get_query(resolved, args, info) - return super(SQLAlchemyConnectionField, self).from_list(connection_type, qs, args, info) + query = self.get_query(resolved, args, info) + query = maybe_query(query) + return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info) class ConnectionOrListField(Field): diff --git a/graphene/contrib/sqlalchemy/utils.py b/graphene/contrib/sqlalchemy/utils.py index 18093d0e..b3e47bdf 100644 --- a/graphene/contrib/sqlalchemy/utils.py +++ b/graphene/contrib/sqlalchemy/utils.py @@ -1,8 +1,7 @@ from sqlalchemy.ext.declarative.api import DeclarativeMeta +from sqlalchemy.orm.query import Query - -# from sqlalchemy.orm.base import object_mapper -# from sqlalchemy.orm.exc import UnmappedInstanceError +from graphene.utils import LazyList def get_type_for_model(schema, model): @@ -20,10 +19,27 @@ def get_session(info): return schema.options.get('session') +def get_query(model, info): + query = getattr(model, 'query') + if not query: + query = get_session(info).query(model) + return query + + +class WrappedQuery(LazyList): + + def __len__(self): + # Dont calculate the length using len(query), as this will + # evaluate the whole queryset and return it's length. + # Use .count() instead + return self._origin.count() + + +def maybe_query(value): + if isinstance(value, Query): + return WrappedQuery(value) + return value + + def is_mapped(obj): return isinstance(obj, DeclarativeMeta) - # try: - # object_mapper(obj) - # except UnmappedInstanceError: - # return False - # return True