diff --git a/graphene/new_types/schema.py b/graphene/new_types/schema.py index 7a8315f2..624ce9e1 100644 --- a/graphene/new_types/schema.py +++ b/graphene/new_types/schema.py @@ -1,6 +1,6 @@ import inspect -from graphql import GraphQLSchema, graphql +from graphql import GraphQLSchema, graphql, is_type from graphql.utils.introspection_query import introspection_query from graphql.utils.schema_printer import print_schema @@ -16,18 +16,10 @@ from .scalars import Scalar, String # from collections import defaultdict -from collections import Iterable, OrderedDict, defaultdict -from functools import reduce - -from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of -from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField,GraphQLScalarType, - GraphQLList, GraphQLNonNull, GraphQLObjectType, - GraphQLUnionType) from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, - GraphQLSkipDirective) + GraphQLSkipDirective) from graphql.type.introspection import IntrospectionSchema -from graphql.type.schema import assert_object_implements_interface -from graphql.type.scalars import GraphQLString +from .typemap import TypeMap, is_graphene_type class Schema(GraphQLSchema): @@ -38,13 +30,6 @@ class Schema(GraphQLSchema): self._subscription = subscription self.types = types self._executor = executor - # super(Schema, self).__init__( - # query=query, - # mutation=mutation, - # subscription=subscription, - # directives=directives, - # types=self.types - # ) if directives is None: directives = [ GraphQLIncludeDirective, @@ -57,20 +42,34 @@ class Schema(GraphQLSchema): ) self._directives = directives - self._possible_type_map = defaultdict(set) - self._type_map = self._build_type_map(types) - # Keep track of all implementations by interface name. - self._implementations = defaultdict(list) - for type in self._type_map.values(): - if isinstance(type, GraphQLObjectType): - for interface in type.get_interfaces(): - self._implementations[interface.name].append(type) + initial_types = [ + query, + mutation, + subscription, + IntrospectionSchema + ] + if types: + initial_types += types + self._type_map = TypeMap(initial_types) - # Enforce correct interface implementations. - for type in self._type_map.values(): - if isinstance(type, GraphQLObjectType): - for interface in type.get_interfaces(): - assert_object_implements_interface(self, type, interface) + def get_query_type(self): + return self.get_graphql_type(self._query) + + def get_mutation_type(self): + return self.get_graphql_type(self._mutation) + + def get_subscription_type(self): + return self.get_graphql_type(self._subscription) + + def get_graphql_type(self, _type): + if is_type(_type): + return _type + if is_graphene_type(_type): + graphql_type = self.get_type(_type._meta.name) + assert graphql_type, "Type {} not found in this schema.".format(_type._meta.name) + assert graphql_type.graphene_type == _type + return graphql_type + raise Exception("{} is not a valid GraphQL type.".format(_type)) def execute(self, request_string='', root_value=None, variable_values=None, context_value=None, operation_name=None, executor=None): @@ -96,62 +95,6 @@ class Schema(GraphQLSchema): def lazy(self, _type): return lambda: self.get_type(_type) - def _type_map_reducer(self, map, type): - if not type: - return map - if isinstance(type, List) or (inspect.isclass(type) and issubclass(type, (ObjectType, Scalar))): - return self._type_map_reducer_graphene(map, type) - return super(Schema, self)._type_map_reducer(map, type) - - def _type_map_reducer_graphene(self, map, type): - # from .structures import List, NonNull - from ..generators.definitions import GrapheneObjectType, GrapheneScalarType - if isinstance(type, List): - return self._type_map_reducer(map, type.of_type) - if issubclass(type, String): - map[type._meta.name] = GraphQLString - return map - - if type._meta.name in map: - assert map[type._meta.name].graphene_type == type - return map - if issubclass(type, ObjectType): - fields = OrderedDict() - map[type._meta.name] = GrapheneObjectType( - graphene_type=type, - name=type._meta.name, - description=type._meta.description, - fields={}, - is_type_of=type.is_type_of, - interfaces=type._meta.interfaces - ) - for name, field in type._meta.fields.items(): - map = self._type_map_reducer(map, field.type) - field_type = self.get_field_type(map, field.type) - _field = GraphQLField( - field_type, - args=field.args, - resolver=field.resolver, - deprecation_reason=field.deprecation_reason, - description=field.description - ) - fields[name] = _field - map[type._meta.name].fields = fields - # map[type._meta.name] = GrapheneScalarType( - # graphene_type=type, - # name=type._meta.name, - # description=type._meta.description, - - # serialize=getattr(type, 'serialize', None), - # parse_value=getattr(type, 'parse_value', None), - # parse_literal=getattr(type, 'parse_literal', None), - # ) - return map - - def get_field_type(self, map, type): - if isinstance(type, List): - return GraphQLList(self.get_field_type(map, type.of_type)) - return map.get(type._meta.name) # def rebuild(self): # self._possible_type_map = defaultdict(set) # self._type_map = self._build_type_map(self.types) diff --git a/graphene/new_types/tests/test_definition.py b/graphene/new_types/tests/test_definition.py index f100562e..fb7e9b1c 100644 --- a/graphene/new_types/tests/test_definition.py +++ b/graphene/new_types/tests/test_definition.py @@ -47,7 +47,7 @@ class Subscription(ObjectType): def test_defines_a_query_only_schema(): blog_schema = Schema(Query) - assert blog_schema.get_query_type() == Query + assert blog_schema.get_query_type().graphene_type == Query article_field = Query._meta.fields['article'] assert article_field.type == Article @@ -68,3 +68,254 @@ def test_defines_a_query_only_schema(): feed_field = Query._meta.fields['feed'] assert feed_field.type.of_type == Article + + +def test_defines_a_mutation_schema(): + blog_schema = Schema(Query, mutation=Mutation) + + assert blog_schema.get_mutation_type().graphene_type == Mutation + + write_mutation = Mutation._meta.fields['write_article'] + assert write_mutation.type == Article + assert write_mutation.type._meta.name == 'Article' + + +def test_defines_a_subscription_schema(): + blog_schema = Schema(Query, subscription=Subscription) + + assert blog_schema.get_subscription_type().graphene_type == Subscription + + subscription = Subscription._meta.fields['article_subscribe'] + assert subscription.type == Article + assert subscription.type._meta.name == 'Article' + + +# def test_includes_nested_input_objects_in_the_map(): +# NestedInputObject = GraphQLInputObjectType( +# name='NestedInputObject', +# fields={'value': GraphQLInputObjectField(GraphQLString)} +# ) + +# SomeInputObject = GraphQLInputObjectType( +# name='SomeInputObject', +# fields={'nested': GraphQLInputObjectField(NestedInputObject)} +# ) + +# SomeMutation = GraphQLObjectType( +# name='SomeMutation', +# fields={ +# 'mutateSomething': GraphQLField( +# type=BlogArticle, +# args={ +# 'input': GraphQLArgument(SomeInputObject) +# } +# ) +# } +# ) +# SomeSubscription = GraphQLObjectType( +# name='SomeSubscription', +# fields={ +# 'subscribeToSomething': GraphQLField( +# type=BlogArticle, +# args={ +# 'input': GraphQLArgument(SomeInputObject) +# } +# ) +# } +# ) + +# schema = GraphQLSchema( +# query=BlogQuery, +# mutation=SomeMutation, +# subscription=SomeSubscription +# ) + +# assert schema.get_type_map()['NestedInputObject'] is NestedInputObject + + +# def test_includes_interfaces_thunk_subtypes_in_the_type_map(): +# SomeInterface = GraphQLInterfaceType( +# name='SomeInterface', +# fields={ +# 'f': GraphQLField(GraphQLInt) +# } +# ) + +# SomeSubtype = GraphQLObjectType( +# name='SomeSubtype', +# fields={ +# 'f': GraphQLField(GraphQLInt) +# }, +# interfaces=lambda: [SomeInterface], +# is_type_of=lambda: True +# ) + +# schema = GraphQLSchema(query=GraphQLObjectType( +# name='Query', +# fields={ +# 'iface': GraphQLField(SomeInterface) +# } +# ), types=[SomeSubtype]) + +# assert schema.get_type_map()['SomeSubtype'] is SomeSubtype + + +# def test_includes_interfaces_subtypes_in_the_type_map(): +# SomeInterface = GraphQLInterfaceType('SomeInterface', fields={'f': GraphQLField(GraphQLInt)}) +# SomeSubtype = GraphQLObjectType( +# name='SomeSubtype', +# fields={'f': GraphQLField(GraphQLInt)}, +# interfaces=[SomeInterface], +# is_type_of=lambda: None +# ) +# schema = GraphQLSchema( +# query=GraphQLObjectType( +# name='Query', +# fields={ +# 'iface': GraphQLField(SomeInterface)}), +# types=[SomeSubtype]) +# assert schema.get_type_map()['SomeSubtype'] == SomeSubtype + + +# def test_stringifies_simple_types(): +# assert str(GraphQLInt) == 'Int' +# assert str(BlogArticle) == 'Article' +# assert str(InterfaceType) == 'Interface' +# assert str(UnionType) == 'Union' +# assert str(EnumType) == 'Enum' +# assert str(InputObjectType) == 'InputObject' +# assert str(GraphQLNonNull(GraphQLInt)) == 'Int!' +# assert str(GraphQLList(GraphQLInt)) == '[Int]' +# assert str(GraphQLNonNull(GraphQLList(GraphQLInt))) == '[Int]!' +# assert str(GraphQLList(GraphQLNonNull(GraphQLInt))) == '[Int!]' +# assert str(GraphQLList(GraphQLList(GraphQLInt))) == '[[Int]]' + + +# def test_identifies_input_types(): +# expected = ( +# (GraphQLInt, True), +# (ObjectType, False), +# (InterfaceType, False), +# (UnionType, False), +# (EnumType, True), +# (InputObjectType, True) +# ) + +# for type, answer in expected: +# assert is_input_type(type) == answer +# assert is_input_type(GraphQLList(type)) == answer +# assert is_input_type(GraphQLNonNull(type)) == answer + + +# def test_identifies_output_types(): +# expected = ( +# (GraphQLInt, True), +# (ObjectType, True), +# (InterfaceType, True), +# (UnionType, True), +# (EnumType, True), +# (InputObjectType, False) +# ) + +# for type, answer in expected: +# assert is_output_type(type) == answer +# assert is_output_type(GraphQLList(type)) == answer +# assert is_output_type(GraphQLNonNull(type)) == answer + + +# def test_prohibits_nesting_nonnull_inside_nonnull(): +# with raises(Exception) as excinfo: +# GraphQLNonNull(GraphQLNonNull(GraphQLInt)) + +# assert 'Can only create NonNull of a Nullable GraphQLType but got: Int!.' in str(excinfo.value) + + +# def test_prohibits_putting_non_object_types_in_unions(): +# bad_union_types = [ +# GraphQLInt, +# GraphQLNonNull(GraphQLInt), +# GraphQLList(GraphQLInt), +# InterfaceType, +# UnionType, +# EnumType, +# InputObjectType +# ] +# for x in bad_union_types: +# with raises(Exception) as excinfo: +# GraphQLSchema(GraphQLObjectType('Root', fields={'union': GraphQLField(GraphQLUnionType('BadUnion', [x]))})) + +# assert 'BadUnion may only contain Object types, it cannot contain: ' + str(x) + '.' \ +# == str(excinfo.value) + + +# def test_does_not_mutate_passed_field_definitions(): +# fields = { +# 'field1': GraphQLField(GraphQLString), +# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}), +# } + +# TestObject1 = GraphQLObjectType(name='Test1', fields=fields) +# TestObject2 = GraphQLObjectType(name='Test1', fields=fields) + +# assert TestObject1.get_fields() == TestObject2.get_fields() +# assert fields == { +# 'field1': GraphQLField(GraphQLString), +# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}), +# } + +# input_fields = { +# 'field1': GraphQLInputObjectField(GraphQLString), +# 'field2': GraphQLInputObjectField(GraphQLString), +# } + +# TestInputObject1 = GraphQLInputObjectType(name='Test1', fields=input_fields) +# TestInputObject2 = GraphQLInputObjectType(name='Test2', fields=input_fields) + +# assert TestInputObject1.get_fields() == TestInputObject2.get_fields() + +# assert input_fields == { +# 'field1': GraphQLInputObjectField(GraphQLString), +# 'field2': GraphQLInputObjectField(GraphQLString), +# } + + +# def test_sorts_fields_and_argument_keys_if_not_using_ordered_dict(): +# fields = { +# 'b': GraphQLField(GraphQLString), +# 'c': GraphQLField(GraphQLString), +# 'a': GraphQLField(GraphQLString), +# 'd': GraphQLField(GraphQLString, args={ +# 'q': GraphQLArgument(GraphQLString), +# 'x': GraphQLArgument(GraphQLString), +# 'v': GraphQLArgument(GraphQLString), +# 'a': GraphQLArgument(GraphQLString), +# 'n': GraphQLArgument(GraphQLString) +# }) +# } + +# test_object = GraphQLObjectType(name='Test', fields=fields) +# ordered_fields = test_object.get_fields() +# assert list(ordered_fields.keys()) == ['a', 'b', 'c', 'd'] +# field_with_args = test_object.get_fields().get('d') +# assert [a.name for a in field_with_args.args] == ['a', 'n', 'q', 'v', 'x'] + + +# def test_does_not_sort_fields_and_argument_keys_when_using_ordered_dict(): +# fields = OrderedDict([ +# ('b', GraphQLField(GraphQLString)), +# ('c', GraphQLField(GraphQLString)), +# ('a', GraphQLField(GraphQLString)), +# ('d', GraphQLField(GraphQLString, args=OrderedDict([ +# ('q', GraphQLArgument(GraphQLString)), +# ('x', GraphQLArgument(GraphQLString)), +# ('v', GraphQLArgument(GraphQLString)), +# ('a', GraphQLArgument(GraphQLString)), +# ('n', GraphQLArgument(GraphQLString)) +# ]))) +# ]) + +# test_object = GraphQLObjectType(name='Test', fields=fields) +# ordered_fields = test_object.get_fields() +# assert list(ordered_fields.keys()) == ['b', 'c', 'a', 'd'] +# field_with_args = test_object.get_fields().get('d') +# assert [a.name for a in field_with_args.args] == ['q', 'x', 'v', 'a', 'n'] diff --git a/graphene/new_types/typemap.py b/graphene/new_types/typemap.py new file mode 100644 index 00000000..132ddccc --- /dev/null +++ b/graphene/new_types/typemap.py @@ -0,0 +1,106 @@ +import inspect +from collections import OrderedDict + +from graphql.type.typemap import GraphQLTypeMap + +from .objecttype import ObjectType +from .structures import List, NonNull +from .scalars import Scalar, String, Boolean, Int, Float, ID + +from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull + + +def is_graphene_type(_type): + if isinstance(_type, (List, NonNull)): + return True + if inspect.isclass(_type) and issubclass(_type, (ObjectType, Scalar)): + return True + + +class TypeMap(GraphQLTypeMap): + + @classmethod + def reducer(cls, map, type): + if not type: + return map + if is_graphene_type(type): + return cls.graphene_reducer(map, type) + return super(TypeMap, cls).reducer(map, type) + + @classmethod + def graphene_reducer(cls, map, type): + if isinstance(type, List): + return cls.reducer(map, type.of_type) + return map + if type._meta.name in map: + _type = map[type._meta.name] + if is_graphene_type(_type): + assert _type.graphene_type == type + return map + if issubclass(type, ObjectType): + return cls.construct_objecttype(map, type) + if issubclass(type, Scalar): + return cls.construct_scalar(map, type) + return map + + @classmethod + def construct_scalar(cls, map, type): + from ..generators.definitions import GrapheneScalarType + _scalars = { + String: GraphQLString, + Int: GraphQLInt, + Float: GraphQLFloat, + Boolean: GraphQLBoolean, + ID: GraphQLID + } + if type in _scalars: + map[type._meta.name] = _scalars[type] + else: + map[type._meta.name] = GrapheneScalarType( + graphene_type=type, + name=type._meta.name, + description=type._meta.description, + + serialize=getattr(type, 'serialize', None), + parse_value=getattr(type, 'parse_value', None), + parse_literal=getattr(type, 'parse_literal', None), + ) + return map + + @classmethod + def construct_objecttype(cls, map, type): + from ..generators.definitions import GrapheneObjectType + map[type._meta.name] = GrapheneObjectType( + graphene_type=type, + name=type._meta.name, + description=type._meta.description, + fields={}, + is_type_of=type.is_type_of, + interfaces=type._meta.interfaces + ) + map[type._meta.name].fields = cls.construct_fields_for_type(map, type) + return map + + @classmethod + def construct_fields_for_type(cls, map, type): + fields = OrderedDict() + for name, field in type._meta.fields.items(): + map = cls.reducer(map, field.type) + field_type = cls.get_field_type(map, field.type) + _field = GraphQLField( + field_type, + args=field.args, + resolver=field.resolver, + deprecation_reason=field.deprecation_reason, + description=field.description + ) + fields[name] = _field + return fields + + @classmethod + def get_field_type(self, map, type): + if isinstance(type, List): + return GraphQLList(self.get_field_type(map, type.of_type)) + if isinstance(type, NonNull): + return GraphQLNonNull(self.get_field_type(map, type.of_type)) + return map.get(type._meta.name)