diff --git a/graphene/core/types/base.py b/graphene/core/types/base.py index becf6883..b0acf52d 100644 --- a/graphene/core/types/base.py +++ b/graphene/core/types/base.py @@ -1,3 +1,4 @@ +import six from functools import total_ordering @@ -7,20 +8,36 @@ class BaseType(object): return getattr(cls, 'T', None) -class LazyType(BaseType): - def __init__(self, type_str): - self.type_str = type_str +class MountType(BaseType): + parent = None + def mount(self, cls): + self.parent = cls + + +class LazyType(MountType): + def __init__(self, type): + self.type = type + + @property def is_self(self): - return self.type_str == 'self' + return self.type == 'self' def internal_type(self, schema): - type = schema.get_type(self.type_str) + type = None + if callable(self.type): + type = self.type(self.parent) + elif isinstance(self.type, six.string_types): + if self.is_self: + type = self.parent + else: + type = schema.get_type(self.type) + assert type, 'Type in %s %r cannot be none' % (self.type, self.parent) return schema.T(type) @total_ordering -class OrderedType(BaseType): +class OrderedType(MountType): creation_counter = 0 def __init__(self, _creation_counter=None): @@ -44,6 +61,12 @@ class OrderedType(BaseType): return self.creation_counter < other.creation_counter return NotImplemented + def __gt__(self, other): + # This is needed because bisect does not take a comparison function. + if type(self) == type(other): + return self.creation_counter > other.creation_counter + return NotImplemented + def __hash__(self): return hash((self.creation_counter)) diff --git a/graphene/core/types/definitions.py b/graphene/core/types/definitions.py index 573be0f9..15b03cd8 100644 --- a/graphene/core/types/definitions.py +++ b/graphene/core/types/definitions.py @@ -1,7 +1,7 @@ import six from graphql.core.type import (GraphQLList, GraphQLNonNull) -from .base import MountedType, LazyType +from .base import MountType, MountedType, LazyType class OfType(MountedType): @@ -14,6 +14,11 @@ class OfType(MountedType): def internal_type(self, schema): return self.T(schema.T(self.of_type)) + def mount(self, cls): + self.parent = cls + if isinstance(self.of_type, MountType): + self.of_type.mount(cls) + class List(OfType): T = GraphQLList diff --git a/graphene/core/types/field.py b/graphene/core/types/field.py index a51fb967..eae3ae69 100644 --- a/graphene/core/types/field.py +++ b/graphene/core/types/field.py @@ -4,7 +4,7 @@ from functools import wraps from graphql.core.type import GraphQLField, GraphQLInputObjectField -from .base import LazyType, OrderedType +from .base import MountType, LazyType, OrderedType from .argument import ArgumentsGroup from .definitions import NonNull from ...utils import to_camel_case, ProxySnakeDict @@ -47,8 +47,9 @@ class Field(OrderedType): self.name = to_camel_case(attname) self.attname = attname self.object_type = cls - if isinstance(self.type, LazyType) and self.type.is_self(): - self.type = cls + self.mount(cls) + if isinstance(self.type, MountType): + self.type.mount(cls) cls._meta.add_field(self) @property @@ -68,13 +69,16 @@ class Field(OrderedType): return getattr(instance, self.attname, self.default) return default_getter + def get_type(self, schema): + return self.type + def internal_type(self, schema): resolver = self.resolver description = self.description arguments = self.arguments if not description and resolver: description = resolver.__doc__ - type = schema.T(self.type) + type = schema.T(self.get_type(schema)) type_objecttype = schema.objecttype(type) if type_objecttype and type_objecttype._meta.is_mutation: assert len(arguments) == 0 @@ -120,6 +124,9 @@ class InputField(OrderedType): self.name = to_camel_case(attname) self.attname = attname self.object_type = cls + self.mount(cls) + if isinstance(self.type, MountType): + self.type.mount(cls) cls._meta.add_field(self) def internal_type(self, schema): diff --git a/graphene/core/types/tests/test_field.py b/graphene/core/types/tests/test_field.py index c86918a5..1a0e90c7 100644 --- a/graphene/core/types/tests/test_field.py +++ b/graphene/core/types/tests/test_field.py @@ -3,6 +3,7 @@ from graphql.core.type import GraphQLField, GraphQLInputObjectField, GraphQLStri from ..field import Field, InputField from ..scalars import String from ..base import LazyType +from ..definitions import List from graphene.core.types import ObjectType, InputObjectType from graphene.core.schema import Schema @@ -59,7 +60,19 @@ def test_field_self(): class MyObjectType(ObjectType): my_field = field - assert field.type == MyObjectType + schema = Schema() + + assert schema.T(field).type == schema.T(MyObjectType) + + +def test_field_mounted(): + field = Field(List('MyObjectType'), name='my_customName') + + class MyObjectType(ObjectType): + my_field = field + + assert field.parent == MyObjectType + assert field.type.parent == MyObjectType def test_field_string_reference(): diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 78ef0cc6..13e76559 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -23,6 +23,7 @@ class ConnectionField(Field): def wrap_resolved(self, value, instance, args, info): return value + def resolve(self, instance, args, info): from graphene.relay.types import PageInfo schema = info.schema.graphene_schema @@ -50,9 +51,10 @@ class ConnectionField(Field): def get_edge_type(self, node): return self.edge_type or node.get_edge_type() - def internal_type(self, schema): + def get_type(self, schema): from graphene.relay.utils import is_node - node = self.get_object_type(schema) + type = schema.T(self.type) + node = schema.objecttype(type) assert is_node(node), 'Only nodes have connections.' schema.register(node) connection_type = self.get_connection_type(node) diff --git a/graphene/relay/types.py b/graphene/relay/types.py index dfbd51c1..40f04d20 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -1,6 +1,7 @@ from graphene.core.fields import BooleanField, Field, ListField, StringField from graphene.core.types import (InputObjectType, Interface, Mutation, ObjectType) +from graphene.core.types.base import LazyType from graphene.core.types.argument import ArgumentsGroup from graphene.core.types.definitions import NonNull from graphene.relay.fields import GlobalIDField @@ -24,7 +25,7 @@ class Edge(ObjectType): class Meta: type_name = 'DefaultEdge' - node = Field(lambda field: field.object_type.node_type, + node = Field(LazyType(lambda object_type: object_type.node_type), description='The item at the end of the edge') cursor = StringField( required=True, description='A cursor for use in pagination') @@ -44,7 +45,7 @@ class Connection(ObjectType): page_info = Field(PageInfo, required=True, description='The Information to aid in pagination') - edges = ListField(lambda field: field.object_type.edge_type, + edges = ListField(LazyType(lambda object_type: object_type.edge_type), description='Information to aid in pagination.') _connection_data = None