ConnectionField now only accepts a connection. Removed fastcache.

This commit is contained in:
Markus Padourek 2016-09-16 16:22:11 +01:00
parent 7137a59749
commit ca0dcfbd22
14 changed files with 128 additions and 66 deletions

View File

@ -16,6 +16,8 @@ class Ship(graphene.ObjectType):
def get_node(cls, id, context, info): def get_node(cls, id, context, info):
return get_ship(id) return get_ship(id)
ShipConnection = relay.Connection.for_type(Ship)
class Faction(graphene.ObjectType): class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
@ -24,7 +26,7 @@ class Faction(graphene.ObjectType):
interfaces = (relay.Node, ) interfaces = (relay.Node, )
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.') ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.')
@resolve_only_args @resolve_only_args
def resolve_ships(self, **args): def resolve_ships(self, **args):

View File

@ -1,7 +1,7 @@
from django.db import models from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull, Field, Dynamic from graphene import Enum, List, ID, Boolean, Float, Int, String, NonNull, Field, Dynamic
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.datetime import DateTime from graphene.types.datetime import DateTime
from graphene.utils.str_converters import to_const from graphene.utils.str_converters import to_const
@ -37,7 +37,7 @@ def convert_django_field_with_choices(field, registry=None):
name = '{}{}'.format(meta.object_name, field.name.capitalize()) name = '{}{}'.format(meta.object_name, field.name.capitalize())
choices = list(get_choices(choices)) choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices] named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]:c[2] for c in choices} named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object): class EnumWithDescriptionsType(object):
@property @property

View File

@ -30,8 +30,10 @@ class Role(SQLAlchemyObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = relay.Node.Field() node = relay.Node.Field()
all_employees = SQLAlchemyConnectionField(Employee) employee_connection = relay.Connection.for_type(Employee)
all_roles = SQLAlchemyConnectionField(Role) role_connection = relay.Connection.for_type(Role)
all_employees = SQLAlchemyConnectionField(employee_connection)
all_roles = SQLAlchemyConnectionField(role_connection)
role = graphene.Field(Role) role = graphene.Field(Role)

View File

@ -4,7 +4,7 @@ from sqlalchemy.orm import interfaces
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic
from graphene.relay import is_node from graphene.relay import is_node, Connection
from graphene.types.json import JSONString from graphene.types.json import JSONString
from .fields import SQLAlchemyConnectionField from .fields import SQLAlchemyConnectionField
@ -18,7 +18,7 @@ except ImportError:
pass pass
def convert_sqlalchemy_relationship(relationship, registry): def convert_sqlalchemy_relationship(relationship, registry, connections, type_name):
direction = relationship.direction direction = relationship.direction
model = relationship.mapper.entity model = relationship.mapper.entity
@ -31,7 +31,12 @@ def convert_sqlalchemy_relationship(relationship, registry):
elif (direction == interfaces.ONETOMANY or elif (direction == interfaces.ONETOMANY or
direction == interfaces.MANYTOMANY): direction == interfaces.MANYTOMANY):
if is_node(_type): if is_node(_type):
return SQLAlchemyConnectionField(_type) try:
connection_type = connections[relationship.key]
except KeyError:
raise KeyError("No Connection provided for relationship {} on type {}. Specify it in its Meta "
"class on the 'connections' dict.".format(relationship.key, type_name))
return SQLAlchemyConnectionField(connection_type)
return Field(List(_type)) return Field(List(_type))
return Dynamic(dynamic_type) return Dynamic(dynamic_type)

View File

@ -6,7 +6,7 @@ from sqlalchemy_utils import ChoiceType, ScalarListType
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node, Connection
from graphene.types.json import JSONString from graphene.types.json import JSONString
from ..converter import (convert_sqlalchemy_column, from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_composite, convert_sqlalchemy_composite,
@ -129,7 +129,7 @@ def test_should_scalar_list_convert_list():
def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist():
registry = Registry() registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry, {}, '')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type() assert not dynamic_field.get_type()
@ -139,7 +139,7 @@ def test_should_manytomany_convert_connectionorlist_list():
class Meta: class Meta:
model = Pet model = Pet
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -153,14 +153,33 @@ def test_should_manytomany_convert_connectionorlist_connection():
model = Pet model = Pet
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) connections = {'pets': Connection.for_type(A)}
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField)
assert issubclass(dynamic_field.get_type().type, Connection)
def test_should_rais_when_no_connections_is_provided_for_manyto_many():
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (Node, )
connections = {}
with raises(KeyError) as ctx:
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A')
dynamic_field.get_type()
assert str(ctx.value) == ('\"No Connection provided for relationship pets on type A. Specify it in its Meta '
'class on the \'connections\' dict.\"')
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
registry = Registry() registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry, {}, '')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type() assert not dynamic_field.get_type()
@ -170,7 +189,7 @@ def test_should_manytoone_convert_connectionorlist_list():
class Meta: class Meta:
model = Reporter model = Reporter
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -183,7 +202,7 @@ def test_should_manytoone_convert_connectionorlist_connection():
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -196,7 +215,7 @@ def test_should_onetoone_convert_field():
model = Article model = Article
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)

View File

@ -3,7 +3,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node, Connection
from ..types import SQLAlchemyObjectType from ..types import SQLAlchemyObjectType
from ..fields import SQLAlchemyConnectionField from ..fields import SQLAlchemyConnectionField
@ -94,16 +94,6 @@ def test_should_query_well(session):
def test_should_node(session): def test_should_node(session):
setup_fixtures(session) setup_fixtures(session)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class ArticleNode(SQLAlchemyObjectType): class ArticleNode(SQLAlchemyObjectType):
class Meta: class Meta:
@ -114,11 +104,26 @@ def test_should_node(session):
# def get_node(cls, id, info): # def get_node(cls, id, info):
# return Article(id=1, headline='Article node') # return Article(id=1, headline='Article node')
ArticleNodeConnection = Connection.for_type(ArticleNode)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
connections = {
'articles': ArticleNodeConnection
}
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = Node.Field() node = Node.Field()
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNodeConnection)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return session.query(Reporter).first() return session.query(Reporter).first()
@ -202,7 +207,8 @@ def test_should_custom_identifier(session):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = Node.Field() node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode) EditorNodeConnection = Connection.for_type(EditorNode)
all_editors = SQLAlchemyConnectionField(EditorNodeConnection)
query = ''' query = '''
query EditorQuery { query EditorQuery {
@ -250,23 +256,27 @@ def test_should_mutate_well(session):
model = Editor model = Editor
interfaces = (Node, ) interfaces = (Node, )
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class ArticleNode(SQLAlchemyObjectType): class ArticleNode(SQLAlchemyObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node, )
ArticleNodeConnection = Connection.for_type(ArticleNode)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
connections = {
'articles': ArticleNodeConnection
}
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class CreateArticle(graphene.Mutation): class CreateArticle(graphene.Mutation):
class Input: class Input:
headline = graphene.String() headline = graphene.String()
@ -279,7 +289,7 @@ def test_should_mutate_well(session):
def mutate(cls, instance, args, context, info): def mutate(cls, instance, args, context, info):
new_article = Article( new_article = Article(
headline=args.get('headline'), headline=args.get('headline'),
reporter_id = args.get('reporter_id'), reporter_id=args.get('reporter_id'),
) )
session.add(new_article) session.add(new_article)

View File

@ -18,7 +18,7 @@ from graphene.types.utils import yank_fields_from_attrs, merge
from .utils import get_query from .utils import get_query
def construct_fields(options): def construct_fields(options, type_name):
only_fields = options.only_fields only_fields = options.only_fields
exclude_fields = options.exclude_fields exclude_fields = options.exclude_fields
inspected_model = sqlalchemyinspect(options.model) inspected_model = sqlalchemyinspect(options.model)
@ -56,7 +56,7 @@ def construct_fields(options):
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields # in there. Or when we excldue this field in exclude_fields
continue continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry, options.connections, type_name)
name = relationship.key name = relationship.key
fields[name] = converted_relationship fields[name] = converted_relationship
@ -82,7 +82,8 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
exclude_fields=(), exclude_fields=(),
id='id', id='id',
interfaces=(), interfaces=(),
registry=None registry=None,
connections={},
) )
if not options.registry: if not options.registry:
@ -96,13 +97,12 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
'{}.Meta, received "{}".' '{}.Meta, received "{}".'
).format(name, options.model) ).format(name, options.model)
cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options))
options.registry.register(cls) options.registry.register(cls)
options.sqlalchemy_fields = yank_fields_from_attrs( options.sqlalchemy_fields = yank_fields_from_attrs(
construct_fields(options), construct_fields(options, name),
_as=Field, _as=Field,
) )
options.fields = merge( options.fields = merge(

View File

@ -5,7 +5,6 @@ from functools import partial
from promise import Promise from promise import Promise
import six import six
from fastcache import clru_cache
from graphql_relay import connection_from_list from graphql_relay import connection_from_list
from ..types import Boolean, Int, List, String, AbstractType from ..types import Boolean, Int, List, String, AbstractType
@ -102,7 +101,6 @@ class ConnectionMeta(ObjectTypeMeta):
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
@classmethod @classmethod
@clru_cache(maxsize=None)
def for_type(cls, gql_type): def for_type(cls, gql_type):
connection_name = '{}Connection'.format(gql_type._meta.name) connection_name = '{}Connection'.format(gql_type._meta.name)
@ -125,6 +123,10 @@ def is_connection(gql_type):
class IterableConnectionField(Field): class IterableConnectionField(Field):
def __init__(self, gql_type, *args, **kwargs): def __init__(self, gql_type, *args, **kwargs):
assert is_connection(gql_type), (
'The provided type "{}" for this ConnectionField has to be a Connection as defined by the Relay'
' spec.'.format(gql_type)
)
super(IterableConnectionField, self).__init__( super(IterableConnectionField, self).__init__(
gql_type, gql_type,
*args, *args,
@ -134,19 +136,17 @@ class IterableConnectionField(Field):
last=Int(), last=Int(),
**kwargs **kwargs
) )
self._gql_type = gql_type
@property
def type(self):
return self._gql_type if is_connection(self._gql_type) else Connection.for_type(self._gql_type)
@staticmethod @staticmethod
def connection_resolver(resolver, connection, root, args, context, info): def connection_resolver(resolver, connection, root, args, context, info):
resolved = Promise.resolve(resolver(root, args, context, info)) resolved = Promise.resolve(resolver(root, args, context, info))
def handle_connection_and_list(result): def handle_connection_and_list(result):
if is_connection(result): if isinstance(result, connection):
return result return result
elif is_connection(result):
raise AssertionError('Resolved value from the connection field has to be a {}. '
'Received {}.'.format(connection, type(result)))
else: else:
assert isinstance(result, Iterable), ( assert isinstance(result, Iterable), (
'Resolved value from the connection field have to be iterable. ' 'Resolved value from the connection field have to be iterable. '

View File

@ -100,8 +100,8 @@ def test_defaul_connection_for_type():
assert list(fields.keys()) == ['edges', 'page_info'] assert list(fields.keys()) == ['edges', 'page_info']
def test_defaul_connection_for_type_returns_same_Connection(): def test_default_connection_for_type_does_not_returns_same_Connection():
assert Connection.for_type(MyObject) == Connection.for_type(MyObject) assert Connection.for_type(MyObject) != Connection.for_type(MyObject)
def test_edge(): def test_edge():
@ -179,4 +179,5 @@ def test_edge_for_object_type():
def test_edge_for_type_returns_same_edge(): def test_edge_for_type_returns_same_edge():
assert Connection.for_type(MyObject).Edge == Connection.for_type(MyObject).Edge MyObjectConnection = Connection.for_type(MyObject)
assert MyObjectConnection.Edge == MyObjectConnection.Edge

View File

@ -26,15 +26,21 @@ class MyLetterObjectConnection(Connection):
class Edge: class Edge:
other = String() other = String()
LetterConnection = Connection.for_type(Letter)
class Query(ObjectType): class Query(ObjectType):
letters = ConnectionField(Letter) letters = ConnectionField(LetterConnection)
letters_promise = ConnectionField(Letter) letters_wrong_connection = ConnectionField(LetterConnection)
letters_promise = ConnectionField(LetterConnection)
letters_connection = ConnectionField(MyLetterObjectConnection) letters_connection = ConnectionField(MyLetterObjectConnection)
def resolve_letters(self, *_): def resolve_letters(self, *_):
return list(letters.values()) return list(letters.values())
def resolve_letters_wrong_connection(self, *_):
return MyLetterObjectConnection()
def resolve_letters_connection(self, *_): def resolve_letters_connection(self, *_):
return MyLetterObjectConnection( return MyLetterObjectConnection(
extra='1', extra='1',
@ -121,6 +127,22 @@ def check(args, letters, has_previous_page=False, has_next_page=False):
assert result.data == create_expexted_result(letters, has_previous_page, has_next_page) assert result.data == create_expexted_result(letters, has_previous_page, has_next_page)
def test_resolver_throws_error_on_returning_wrong_connection_type():
result = schema.execute('''
{
lettersWrongConnection {
edges {
node {
id
}
}
}
}
''')
assert result.errors[0].message == ('Resolved value from the connection field has to be a LetterConnection. '
'Received MyLetterObjectConnection.')
def test_resolver_handles_returned_connection_field_correctly(): def test_resolver_handles_returned_connection_field_correctly():
result = schema.execute(''' result = schema.execute('''
{ {

View File

@ -20,6 +20,9 @@ class MyNode(ObjectType):
name = String() name = String()
MyNodeConnection = Connection.for_type(MyNode)
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
@ -38,13 +41,13 @@ class OtherMutation(ClientIDMutation):
additional_field = String() additional_field = String()
name = String() name = String()
my_node_edge = Field(Connection.for_type(MyNode).Edge) my_node_edge = Field(MyNodeConnection.Edge)
@classmethod @classmethod
def mutate_and_get_payload(cls, args, context, info): def mutate_and_get_payload(cls, args, context, info):
shared = args.get('shared', '') shared = args.get('shared', '')
additionalField = args.get('additionalField', '') additionalField = args.get('additionalField', '')
edge_type = Connection.for_type(MyNode).Edge edge_type = MyNodeConnection.Edge
return OtherMutation(name=shared + additionalField, return OtherMutation(name=shared + additionalField,
my_node_edge=edge_type( my_node_edge=edge_type(
cursor='1', node=MyNode(name='name'))) cursor='1', node=MyNode(name='name')))

View File

@ -16,7 +16,7 @@ def source_resolver(source, root, args, context, info):
class Field(OrderedType): class Field(OrderedType):
def __init__(self, type, args=None, resolver=None, source=None, def __init__(self, gql_type, args=None, resolver=None, source=None,
deprecation_reason=None, name=None, description=None, deprecation_reason=None, name=None, description=None,
required=False, _creation_counter=None, **extra_args): required=False, _creation_counter=None, **extra_args):
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(_creation_counter=_creation_counter)
@ -28,10 +28,10 @@ class Field(OrderedType):
) )
if required: if required:
type = NonNull(type) gql_type = NonNull(gql_type)
self.name = name self.name = name
self._type = type self._type = gql_type
self.args = to_arguments(args or OrderedDict(), extra_args) self.args = to_arguments(args or OrderedDict(), extra_args)
if source: if source:
resolver = partial(source_resolver, source) resolver = partial(source_resolver, source)

View File

@ -72,7 +72,6 @@ setup(
'six>=1.10.0', 'six>=1.10.0',
'graphql-core>=1.0.dev', 'graphql-core>=1.0.dev',
'graphql-relay>=0.4.4', 'graphql-relay>=0.4.4',
'fastcache>=1.0.2',
'promise', 'promise',
], ],
tests_require=[ tests_require=[

View File

@ -7,7 +7,6 @@ deps=
pytest>=2.7.2 pytest>=2.7.2
graphql-core>=1.0.dev graphql-core>=1.0.dev
graphql-relay>=0.4.4 graphql-relay>=0.4.4
fastcache>=1.0.2
six six
blinker blinker
singledispatch singledispatch