diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 62b39968..f9659a99 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -62,7 +62,7 @@ class DjangoModelField(Field): self.object_type ) ) - return _type and _type.internal_type(schema) or Field.SKIP + return schema.T(_type) or Field.SKIP def get_object_type(self, schema): return get_type_for_model(schema, self.model) diff --git a/graphene/core/fields.py b/graphene/core/fields.py index a83ecbf0..f00105e5 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -114,7 +114,7 @@ class Field(object): else: object_type = self.get_object_type(schema) if object_type: - field_type = object_type.internal_type(schema) + field_type = schema.T(object_type) field_type = self.type_wrapper(field_type) return field_type diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 8fa3d340..f4a1b000 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -25,6 +25,7 @@ class Schema(object): def __init__(self, query=None, mutation=None, name='Schema', executor=None): self._internal_types = {} + self._types = {} self.mutation = mutation self.query = query self.name = name @@ -34,6 +35,13 @@ class Schema(object): def __repr__(self): return '' % (str(self.name), hash(self)) + def T(self, object_type): + if not object_type: + return + if object_type not in self._types: + self._types[object_type] = object_type.internal_type(self) + return self._types[object_type] + @property def query(self): return self._query @@ -41,7 +49,6 @@ class Schema(object): @query.setter def query(self, query): self._query = query - self._query_type = query and query.internal_type(self) @property def mutation(self): @@ -50,7 +57,6 @@ class Schema(object): @mutation.setter def mutation(self, mutation): self._mutation = mutation - self._mutation_type = mutation and mutation.internal_type(self) @property def executor(self): @@ -62,11 +68,11 @@ class Schema(object): def executor(self, value): self._executor = value - @cached_property + @property def schema(self): - if not self._query_type: + if not self._query: raise Exception('You have to define a base query type') - return GraphQLSchema(self, query=self._query_type, mutation=self._mutation_type) + return GraphQLSchema(self, query=self.T(self._query), mutation=self.T(self._mutation)) def associate_internal_type(self, internal_type, object_type): self._internal_types[internal_type.name] = object_type @@ -76,6 +82,7 @@ class Schema(object): return object_type def get_type(self, type_name): + self.schema._build_type_map() if type_name not in self._internal_types: raise Exception('Type %s not found in %r' % (type_name, self)) return self._internal_types[type_name] diff --git a/graphene/core/types.py b/graphene/core/types.py index d95ab4e8..e9bf78ec 100644 --- a/graphene/core/types.py +++ b/graphene/core/types.py @@ -179,15 +179,14 @@ class BaseObjectType(object): @classmethod def resolve_objecttype(cls, schema, instance, *_): - return instance + return instance.__class__ @classmethod def resolve_type(cls, schema, instance, *_): objecttype = cls.resolve_objecttype(schema, instance, *_) - return objecttype.internal_type(schema) + return schema.T(objecttype) @classmethod - @memoize @register_internal_type def internal_type(cls, schema): fields = lambda: OrderedDict([(f.name, f.internal_field(schema)) @@ -203,7 +202,7 @@ class BaseObjectType(object): return GraphQLObjectType( cls._meta.type_name, description=cls._meta.description, - interfaces=[i.internal_type(schema) for i in cls._meta.interfaces], + interfaces=[schema.T(i) for i in cls._meta.interfaces], fields=fields, is_type_of=getattr(cls, 'is_type_of', None) ) @@ -225,7 +224,6 @@ class Mutation(six.with_metaclass(ObjectTypeMeta, BaseObjectType)): class InputObjectType(ObjectType): @classmethod - @memoize @register_internal_type def internal_type(cls, schema): fields = lambda: OrderedDict([(f.name, f.internal_field(schema)) diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 68233b22..8d115d83 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -63,8 +63,8 @@ class ConnectionField(Field): node = self.get_object_type(schema) assert is_node(node), 'Only nodes have connections.' schema.register(node) - - return self.get_connection_type(node).internal_type(schema) + connection_type = self.get_connection_type(node) + return schema.T(connection_type) class NodeField(Field): diff --git a/tests/contrib_django/test_types.py b/tests/contrib_django/test_types.py index 62da63c1..6ef741fd 100644 --- a/tests/contrib_django/test_types.py +++ b/tests/contrib_django/test_types.py @@ -50,7 +50,7 @@ def test_django_interface(): def test_pseudo_interface(): - object_type = Character.internal_type(schema) + object_type = schema.T(Character) assert Character._meta.is_interface is True assert isinstance(object_type, GraphQLInterfaceType) assert Character._meta.model == Reporter @@ -81,7 +81,7 @@ def test_interface_resolve_type(): def test_object_type(): - object_type = Human.internal_type(schema) + object_type = schema.T(Human) fields_map = Human._meta.fields_map assert Human._meta.is_interface is False assert isinstance(object_type, GraphQLObjectType) @@ -95,7 +95,7 @@ def test_object_type(): # 'reporter': fields_map['reporter'].internal_field(schema), # 'pubDate': fields_map['pub_date'].internal_field(schema), # } - assert DjangoNode.internal_type(schema) in object_type.get_interfaces() + assert schema.T(DjangoNode) in object_type.get_interfaces() def test_node_notinterface(): diff --git a/tests/core/test_query.py b/tests/core/test_query.py index 33cc60cf..075cdea2 100644 --- a/tests/core/test_query.py +++ b/tests/core/test_query.py @@ -46,7 +46,7 @@ class Human(Character): schema = Schema() -Human_type = Human.internal_type(schema) +Human_type = schema.T(Human) def test_type(): diff --git a/tests/core/test_schema.py b/tests/core/test_schema.py index f6e3d849..a310648f 100644 --- a/tests/core/test_schema.py +++ b/tests/core/test_schema.py @@ -125,9 +125,23 @@ def test_schema_register(): class MyType(ObjectType): type = StringField(resolve=lambda *_: 'Dog') + schema.query = MyType + assert schema.get_type('MyType') == MyType +def test_schema_register(): + schema = Schema(name='My own schema') + + @schema.register + class MyType(ObjectType): + type = StringField(resolve=lambda *_: 'Dog') + + with raises(Exception) as excinfo: + schema.get_type('MyType') + assert 'base query type' in str(excinfo.value) + + def test_schema_introspect(): schema = Schema(name='My own schema') @@ -138,4 +152,3 @@ def test_schema_introspect(): introspection = schema.introspect() assert '__schema' in introspection - diff --git a/tests/core/test_types.py b/tests/core/test_types.py index 7d5a17ed..a8370b72 100644 --- a/tests/core/test_types.py +++ b/tests/core/test_types.py @@ -38,7 +38,7 @@ schema = Schema() def test_interface(): - object_type = Character.internal_type(schema) + object_type = schema.T(Character) assert Character._meta.is_interface is True assert isinstance(object_type, GraphQLInterfaceType) assert Character._meta.type_name == 'core_Character' @@ -54,7 +54,7 @@ def test_interface_resolve_type(): def test_object_type(): - object_type = Human.internal_type(schema) + object_type = schema.T(Human) assert Human._meta.is_interface is False assert Human._meta.type_name == 'core_Human' assert isinstance(object_type, GraphQLObjectType) @@ -62,7 +62,7 @@ def test_object_type(): assert list(object_type.get_fields().keys()) == ['name', 'friends'] # assert object_type.get_fields() == {'name': Human._meta.fields_map['name'].internal_field( # schema), 'friends': Human._meta.fields_map['friends'].internal_field(schema)} - assert object_type.get_interfaces() == [Character.internal_type(schema)] + assert object_type.get_interfaces() == [schema.T(Character)] assert Human._meta.fields_map['name'].object_type == Human