Improved relay

This commit is contained in:
Syrus Akbary 2015-11-11 17:33:23 -08:00
parent cfba52e6f3
commit b0f2b4dd55
6 changed files with 28 additions and 19 deletions

View File

@ -40,9 +40,9 @@ class Schema(object):
if object_type not in self._types: if object_type not in self._types:
internal_type = object_type.internal_type(self) internal_type = object_type.internal_type(self)
self._types[object_type] = internal_type self._types[object_type] = internal_type
name = getattr(internal_type, 'name', None) is_objecttype = inspect.isclass(object_type) and issubclass(object_type, BaseObjectType)
if name: if is_objecttype:
self._types_names[name] = object_type self.register(object_type)
return self._types[object_type] return self._types[object_type]
else: else:
return object_type return object_type
@ -65,6 +65,10 @@ class Schema(object):
return GraphQLSchema(self, query=self.T(self.query), mutation=self.T(self.mutation)) return GraphQLSchema(self, query=self.T(self.query), mutation=self.T(self.mutation))
def register(self, object_type): def register(self, object_type):
type_name = object_type._meta.type_name
registered_object_type = self._types_names.get(type_name, None)
if registered_object_type:
assert registered_object_type == object_type, 'Type {} already registered with other object type'.format(type_name)
self._types_names[object_type._meta.type_name] = object_type self._types_names[object_type._meta.type_name] = object_type
return object_type return object_type

View File

@ -31,14 +31,12 @@ class Field(OrderedType):
if isinstance(type, six.string_types): if isinstance(type, six.string_types):
type = LazyType(type) type = LazyType(type)
self.required = required self.required = required
if self.required:
type = NonNull(type)
self.type = type self.type = type
self.description = description self.description = description
args = OrderedDict(args or {}, **kwargs) args = OrderedDict(args or {}, **kwargs)
self.arguments = ArgumentsGroup(*args_list, **args) self.arguments = ArgumentsGroup(*args_list, **args)
self.object_type = None self.object_type = None
self.resolver = resolver self.resolver_fn = resolver
self.default = default self.default = default
def contribute_to_class(self, cls, attname): def contribute_to_class(self, cls, attname):
@ -54,11 +52,11 @@ class Field(OrderedType):
@property @property
def resolver(self): def resolver(self):
return self._resolver or self.get_resolver_fn() return self.resolver_fn or self.get_resolver_fn()
@resolver.setter @resolver.setter
def resolver(self, value): def resolver(self, value):
self._resolver = value self.resolver_fn = value
def get_resolver_fn(self): def get_resolver_fn(self):
resolve_fn_name = 'resolve_%s' % self.attname resolve_fn_name = 'resolve_%s' % self.attname
@ -70,6 +68,8 @@ class Field(OrderedType):
return default_getter return default_getter
def get_type(self, schema): def get_type(self, schema):
if self.required:
return NonNull(self.type)
return self.type return self.type
def internal_type(self, schema): def internal_type(self, schema):

View File

@ -1,7 +1,7 @@
from collections import Iterable from collections import Iterable
from graphene.core.fields import Field, IDField from graphene.core.fields import Field, IDField
from graphene.core.types.scalars import String, ID from graphene.core.types.scalars import String, ID, Int
from graphql.core.type import GraphQLArgument, GraphQLID, GraphQLNonNull from graphql.core.type import GraphQLArgument, GraphQLID, GraphQLNonNull
from graphql_relay.connection.arrayconnection import connection_from_list from graphql_relay.connection.arrayconnection import connection_from_list
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id
@ -14,8 +14,8 @@ class ConnectionField(Field):
super(ConnectionField, self).__init__(field_type, resolver=resolver, super(ConnectionField, self).__init__(field_type, resolver=resolver,
before=String(), before=String(),
after=String(), after=String(),
first=String(), first=Int(),
last=String(), last=Int(),
description=description, **kwargs) description=description, **kwargs)
self.connection_type = connection_type self.connection_type = connection_type
self.edge_type = edge_type self.edge_type = edge_type
@ -24,17 +24,18 @@ class ConnectionField(Field):
return value return value
def resolve(self, instance, args, info): def resolver(self, instance, args, info):
from graphene.relay.types import PageInfo from graphene.relay.types import PageInfo
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
resolved = super(ConnectionField, self).resolve(instance, args, info) resolved = super(ConnectionField, self).resolver(instance, args, info)
if resolved: if resolved:
resolved = self.wrap_resolved(resolved, instance, args, info) resolved = self.wrap_resolved(resolved, instance, args, info)
assert isinstance( assert isinstance(
resolved, Iterable), 'Resolved value from the connection field have to be iterable' resolved, Iterable), 'Resolved value from the connection field have to be iterable'
node = self.get_object_type(schema) type = schema.T(self.type)
node = schema.objecttype(type)
connection_type = self.get_connection_type(node) connection_type = self.get_connection_type(node)
edge_type = self.get_edge_type(node) edge_type = self.get_edge_type(node)
@ -81,14 +82,16 @@ class NodeField(Field):
return object_type.get_node(_id) return object_type.get_node(_id)
def resolve(self, instance, args, info): def resolver(self, instance, args, info):
global_id = args.get('id') global_id = args.get('id')
return self.id_fetcher(global_id, info) return self.id_fetcher(global_id, info)
class GlobalIDField(IDField): class GlobalIDField(Field):
'''The ID of an object''' '''The ID of an object'''
required = True def __init__(self, *args, **kwargs):
super(GlobalIDField, self).__init__(ID(), *args, **kwargs)
self.required = True
def contribute_to_class(self, cls, name): def contribute_to_class(self, cls, name):
from graphene.relay.utils import is_node, is_node_type from graphene.relay.utils import is_node, is_node_type
@ -96,5 +99,5 @@ class GlobalIDField(IDField):
assert in_node, 'GlobalIDField could only be inside a Node, but got %r' % cls assert in_node, 'GlobalIDField could only be inside a Node, but got %r' % cls
super(GlobalIDField, self).contribute_to_class(cls, name) super(GlobalIDField, self).contribute_to_class(cls, name)
def resolve(self, instance, args, info): def resolver(self, instance, args, info):
return self.object_type.to_global_id(instance, args, info) return self.object_type.to_global_id(instance, args, info)

View File

@ -15,7 +15,7 @@ class MyNode(relay.Node):
@classmethod @classmethod
def get_node(cls, id): def get_node(cls, id):
return MyNode(name='mo') return MyNode(id=id, name='mo')
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
@ -35,6 +35,7 @@ def test_nodefield_query():
query = ''' query = '''
query RebelsShipsQuery { query RebelsShipsQuery {
myNode(id:"TXlOb2RlOjE=") { myNode(id:"TXlOb2RlOjE=") {
id
name name
}, },
allMyNodes (customArg:"1") { allMyNodes (customArg:"1") {
@ -52,6 +53,7 @@ def test_nodefield_query():
''' '''
expected = { expected = {
'myNode': { 'myNode': {
'id': 'TXlOb2RlOjE=',
'name': 'mo' 'name': 'mo'
}, },
'allMyNodes': { 'allMyNodes': {