Improved lazy type resolvers

This commit is contained in:
Syrus Akbary 2015-11-11 02:12:17 -08:00
parent bdcd533bf9
commit cfba52e6f3
6 changed files with 67 additions and 16 deletions

View File

@ -1,3 +1,4 @@
import six
from functools import total_ordering from functools import total_ordering
@ -7,20 +8,36 @@ class BaseType(object):
return getattr(cls, 'T', None) return getattr(cls, 'T', None)
class LazyType(BaseType): class MountType(BaseType):
def __init__(self, type_str): parent = None
self.type_str = type_str
def mount(self, cls):
self.parent = cls
class LazyType(MountType):
def __init__(self, type):
self.type = type
@property
def is_self(self): def is_self(self):
return self.type_str == 'self' return self.type == 'self'
def internal_type(self, schema): def internal_type(self, schema):
type = schema.get_type(self.type_str) type = None
if callable(self.type):
type = self.type(self.parent)
elif isinstance(self.type, six.string_types):
if self.is_self:
type = self.parent
else:
type = schema.get_type(self.type)
assert type, 'Type in %s %r cannot be none' % (self.type, self.parent)
return schema.T(type) return schema.T(type)
@total_ordering @total_ordering
class OrderedType(BaseType): class OrderedType(MountType):
creation_counter = 0 creation_counter = 0
def __init__(self, _creation_counter=None): def __init__(self, _creation_counter=None):
@ -44,6 +61,12 @@ class OrderedType(BaseType):
return self.creation_counter < other.creation_counter return self.creation_counter < other.creation_counter
return NotImplemented return NotImplemented
def __gt__(self, other):
# This is needed because bisect does not take a comparison function.
if type(self) == type(other):
return self.creation_counter > other.creation_counter
return NotImplemented
def __hash__(self): def __hash__(self):
return hash((self.creation_counter)) return hash((self.creation_counter))

View File

@ -1,7 +1,7 @@
import six import six
from graphql.core.type import (GraphQLList, GraphQLNonNull) from graphql.core.type import (GraphQLList, GraphQLNonNull)
from .base import MountedType, LazyType from .base import MountType, MountedType, LazyType
class OfType(MountedType): class OfType(MountedType):
@ -14,6 +14,11 @@ class OfType(MountedType):
def internal_type(self, schema): def internal_type(self, schema):
return self.T(schema.T(self.of_type)) return self.T(schema.T(self.of_type))
def mount(self, cls):
self.parent = cls
if isinstance(self.of_type, MountType):
self.of_type.mount(cls)
class List(OfType): class List(OfType):
T = GraphQLList T = GraphQLList

View File

@ -4,7 +4,7 @@ from functools import wraps
from graphql.core.type import GraphQLField, GraphQLInputObjectField from graphql.core.type import GraphQLField, GraphQLInputObjectField
from .base import LazyType, OrderedType from .base import MountType, LazyType, OrderedType
from .argument import ArgumentsGroup from .argument import ArgumentsGroup
from .definitions import NonNull from .definitions import NonNull
from ...utils import to_camel_case, ProxySnakeDict from ...utils import to_camel_case, ProxySnakeDict
@ -47,8 +47,9 @@ class Field(OrderedType):
self.name = to_camel_case(attname) self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.object_type = cls self.object_type = cls
if isinstance(self.type, LazyType) and self.type.is_self(): self.mount(cls)
self.type = cls if isinstance(self.type, MountType):
self.type.mount(cls)
cls._meta.add_field(self) cls._meta.add_field(self)
@property @property
@ -68,13 +69,16 @@ class Field(OrderedType):
return getattr(instance, self.attname, self.default) return getattr(instance, self.attname, self.default)
return default_getter return default_getter
def get_type(self, schema):
return self.type
def internal_type(self, schema): def internal_type(self, schema):
resolver = self.resolver resolver = self.resolver
description = self.description description = self.description
arguments = self.arguments arguments = self.arguments
if not description and resolver: if not description and resolver:
description = resolver.__doc__ description = resolver.__doc__
type = schema.T(self.type) type = schema.T(self.get_type(schema))
type_objecttype = schema.objecttype(type) type_objecttype = schema.objecttype(type)
if type_objecttype and type_objecttype._meta.is_mutation: if type_objecttype and type_objecttype._meta.is_mutation:
assert len(arguments) == 0 assert len(arguments) == 0
@ -120,6 +124,9 @@ class InputField(OrderedType):
self.name = to_camel_case(attname) self.name = to_camel_case(attname)
self.attname = attname self.attname = attname
self.object_type = cls self.object_type = cls
self.mount(cls)
if isinstance(self.type, MountType):
self.type.mount(cls)
cls._meta.add_field(self) cls._meta.add_field(self)
def internal_type(self, schema): def internal_type(self, schema):

View File

@ -3,6 +3,7 @@ from graphql.core.type import GraphQLField, GraphQLInputObjectField, GraphQLStri
from ..field import Field, InputField from ..field import Field, InputField
from ..scalars import String from ..scalars import String
from ..base import LazyType from ..base import LazyType
from ..definitions import List
from graphene.core.types import ObjectType, InputObjectType from graphene.core.types import ObjectType, InputObjectType
from graphene.core.schema import Schema from graphene.core.schema import Schema
@ -59,7 +60,19 @@ def test_field_self():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
my_field = field my_field = field
assert field.type == MyObjectType schema = Schema()
assert schema.T(field).type == schema.T(MyObjectType)
def test_field_mounted():
field = Field(List('MyObjectType'), name='my_customName')
class MyObjectType(ObjectType):
my_field = field
assert field.parent == MyObjectType
assert field.type.parent == MyObjectType
def test_field_string_reference(): def test_field_string_reference():

View File

@ -23,6 +23,7 @@ class ConnectionField(Field):
def wrap_resolved(self, value, instance, args, info): def wrap_resolved(self, value, instance, args, info):
return value return value
def resolve(self, instance, args, info): def resolve(self, instance, args, info):
from graphene.relay.types import PageInfo from graphene.relay.types import PageInfo
schema = info.schema.graphene_schema schema = info.schema.graphene_schema
@ -50,9 +51,10 @@ class ConnectionField(Field):
def get_edge_type(self, node): def get_edge_type(self, node):
return self.edge_type or node.get_edge_type() return self.edge_type or node.get_edge_type()
def internal_type(self, schema): def get_type(self, schema):
from graphene.relay.utils import is_node from graphene.relay.utils import is_node
node = self.get_object_type(schema) type = schema.T(self.type)
node = schema.objecttype(type)
assert is_node(node), 'Only nodes have connections.' assert is_node(node), 'Only nodes have connections.'
schema.register(node) schema.register(node)
connection_type = self.get_connection_type(node) connection_type = self.get_connection_type(node)

View File

@ -1,6 +1,7 @@
from graphene.core.fields import BooleanField, Field, ListField, StringField from graphene.core.fields import BooleanField, Field, ListField, StringField
from graphene.core.types import (InputObjectType, Interface, Mutation, from graphene.core.types import (InputObjectType, Interface, Mutation,
ObjectType) ObjectType)
from graphene.core.types.base import LazyType
from graphene.core.types.argument import ArgumentsGroup from graphene.core.types.argument import ArgumentsGroup
from graphene.core.types.definitions import NonNull from graphene.core.types.definitions import NonNull
from graphene.relay.fields import GlobalIDField from graphene.relay.fields import GlobalIDField
@ -24,7 +25,7 @@ class Edge(ObjectType):
class Meta: class Meta:
type_name = 'DefaultEdge' type_name = 'DefaultEdge'
node = Field(lambda field: field.object_type.node_type, node = Field(LazyType(lambda object_type: object_type.node_type),
description='The item at the end of the edge') description='The item at the end of the edge')
cursor = StringField( cursor = StringField(
required=True, description='A cursor for use in pagination') required=True, description='A cursor for use in pagination')
@ -44,7 +45,7 @@ class Connection(ObjectType):
page_info = Field(PageInfo, required=True, page_info = Field(PageInfo, required=True,
description='The Information to aid in pagination') description='The Information to aid in pagination')
edges = ListField(lambda field: field.object_type.edge_type, edges = ListField(LazyType(lambda object_type: object_type.edge_type),
description='Information to aid in pagination.') description='Information to aid in pagination.')
_connection_data = None _connection_data = None