mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 20:54:16 +03:00
Improved connection resolver from list
This commit is contained in:
parent
d4ecd504e4
commit
bc6240d378
|
@ -19,16 +19,13 @@ def convert_sqlalchemy_relationship(relationship):
|
||||||
|
|
||||||
|
|
||||||
def convert_sqlalchemy_column(column):
|
def convert_sqlalchemy_column(column):
|
||||||
try:
|
return convert_sqlalchemy_type(getattr(column, 'type', None), column)
|
||||||
return convert_sqlalchemy_type(column.type, column)
|
|
||||||
except Exception:
|
|
||||||
raise Exception(
|
|
||||||
"Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__))
|
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def convert_sqlalchemy_type(type, column):
|
def convert_sqlalchemy_type(type, column):
|
||||||
raise Exception()
|
raise Exception(
|
||||||
|
"Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__))
|
||||||
|
|
||||||
|
|
||||||
@convert_sqlalchemy_type.register(types.Date)
|
@convert_sqlalchemy_type.register(types.Date)
|
||||||
|
|
|
@ -4,29 +4,25 @@ from ...core.types.base import FieldType
|
||||||
from ...core.types.definitions import List
|
from ...core.types.definitions import List
|
||||||
from ...relay import ConnectionField
|
from ...relay import ConnectionField
|
||||||
from ...relay.utils import is_node
|
from ...relay.utils import is_node
|
||||||
from .utils import get_type_for_model
|
from .utils import get_type_for_model, maybe_query, get_query
|
||||||
|
|
||||||
|
|
||||||
class SQLAlchemyConnectionField(ConnectionField):
|
class SQLAlchemyConnectionField(ConnectionField):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.session = kwargs.pop('session', None)
|
|
||||||
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
|
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
return self.type._meta.model
|
return self.type._meta.model
|
||||||
|
|
||||||
def get_session(self, args, info):
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
def get_query(self, resolved_query, args, info):
|
def get_query(self, resolved_query, args, info):
|
||||||
self.get_session(args, info)
|
return resolved_query or get_query(self.model, info)
|
||||||
return resolved_query
|
|
||||||
|
|
||||||
def from_list(self, connection_type, resolved, args, info):
|
def from_list(self, connection_type, resolved, args, info):
|
||||||
qs = self.get_query(resolved, args, info)
|
query = self.get_query(resolved, args, info)
|
||||||
return super(SQLAlchemyConnectionField, self).from_list(connection_type, qs, args, info)
|
query = maybe_query(query)
|
||||||
|
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionOrListField(Field):
|
class ConnectionOrListField(Field):
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
|
||||||
|
from graphene.utils import LazyList
|
||||||
# from sqlalchemy.orm.base import object_mapper
|
|
||||||
# from sqlalchemy.orm.exc import UnmappedInstanceError
|
|
||||||
|
|
||||||
|
|
||||||
def get_type_for_model(schema, model):
|
def get_type_for_model(schema, model):
|
||||||
|
@ -20,10 +19,27 @@ def get_session(info):
|
||||||
return schema.options.get('session')
|
return schema.options.get('session')
|
||||||
|
|
||||||
|
|
||||||
|
def get_query(model, info):
|
||||||
|
query = getattr(model, 'query')
|
||||||
|
if not query:
|
||||||
|
query = get_session(info).query(model)
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedQuery(LazyList):
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Dont calculate the length using len(query), as this will
|
||||||
|
# evaluate the whole queryset and return it's length.
|
||||||
|
# Use .count() instead
|
||||||
|
return self._origin.count()
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_query(value):
|
||||||
|
if isinstance(value, Query):
|
||||||
|
return WrappedQuery(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def is_mapped(obj):
|
def is_mapped(obj):
|
||||||
return isinstance(obj, DeclarativeMeta)
|
return isinstance(obj, DeclarativeMeta)
|
||||||
# try:
|
|
||||||
# object_mapper(obj)
|
|
||||||
# except UnmappedInstanceError:
|
|
||||||
# return False
|
|
||||||
# return True
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user