From feb8fb9b132c9f4a007836edc7cbdc2a3736a79f Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 11 Aug 2016 01:00:46 -0700 Subject: [PATCH] Improved schema implementation --- graphene/new_types/field.py | 12 +++-- graphene/new_types/objecttype.py | 2 + graphene/new_types/schema.py | 64 +++++++++++++++++++++++++- graphene/new_types/tests/test_query.py | 21 +++++++++ 4 files changed, 95 insertions(+), 4 deletions(-) create mode 100644 graphene/new_types/tests/test_query.py diff --git a/graphene/new_types/field.py b/graphene/new_types/field.py index 3f2f975a..f3d998e0 100644 --- a/graphene/new_types/field.py +++ b/graphene/new_types/field.py @@ -1,4 +1,4 @@ -# import inspect +import inspect from functools import partial from collections import OrderedDict @@ -42,7 +42,7 @@ from .structures import NonNull def source_resolver(source, root, args, context, info): resolved = getattr(root, source, None) - if callable(resolved): + if inspect.isfunction(resolved): return resolved() return resolved @@ -58,7 +58,7 @@ class Field(OrderedType): # self.parent = None if required: type = NonNull(type) - self.type = type + self._type = type self.args = args or OrderedDict() # self.args = to_arguments(args, extra_args) assert not (source and resolver), ('You cannot provide a source and a ' @@ -68,3 +68,9 @@ class Field(OrderedType): self.resolver = resolver self.deprecation_reason = deprecation_reason self.description = description + + @property + def type(self): + if inspect.isfunction(self._type): + return self._type() + return self._type diff --git a/graphene/new_types/objecttype.py b/graphene/new_types/objecttype.py index 92c38e42..59e527e4 100644 --- a/graphene/new_types/objecttype.py +++ b/graphene/new_types/objecttype.py @@ -31,6 +31,8 @@ class ObjectTypeMeta(AbstractTypeMeta): class ObjectType(six.with_metaclass(ObjectTypeMeta)): + is_type_of = None + def __init__(self, *args, **kwargs): # GraphQL ObjectType acting as container args_len = len(args) diff --git a/graphene/new_types/schema.py b/graphene/new_types/schema.py index 7be15fb5..619c05bc 100644 --- a/graphene/new_types/schema.py +++ b/graphene/new_types/schema.py @@ -6,6 +6,7 @@ from graphql.utils.schema_printer import print_schema from .objecttype import ObjectType +from .scalars import Scalar # from ..utils.get_graphql_type import get_graphql_type @@ -14,6 +15,19 @@ from .objecttype import ObjectType # from collections import defaultdict +from collections import Iterable, OrderedDict, defaultdict +from functools import reduce + +from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of +from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField, + GraphQLList, GraphQLNonNull, GraphQLObjectType, + GraphQLUnionType) +from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, + GraphQLSkipDirective) +from graphql.type.introspection import IntrospectionSchema +from graphql.type.schema import assert_object_implements_interface + + class Schema(GraphQLSchema): def __init__(self, query=None, mutation=None, subscription=None, directives=None, types=None, executor=None): @@ -29,6 +43,32 @@ class Schema(GraphQLSchema): # directives=directives, # types=self.types # ) + if directives is None: + directives = [ + GraphQLIncludeDirective, + GraphQLSkipDirective + ] + + assert all(isinstance(d, GraphQLDirective) for d in directives), \ + 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( + directives + ) + + self._directives = directives + self._possible_type_map = defaultdict(set) + self._type_map = self._build_type_map(types) + # Keep track of all implementations by interface name. + self._implementations = defaultdict(list) + for type in self._type_map.values(): + if isinstance(type, GraphQLObjectType): + for interface in type.get_interfaces(): + self._implementations[interface.name].append(type) + + # Enforce correct interface implementations. + for type in self._type_map.values(): + if isinstance(type, GraphQLObjectType): + for interface in type.get_interfaces(): + assert_object_implements_interface(self, type, interface) def execute(self, request_string='', root_value=None, variable_values=None, context_value=None, operation_name=None, executor=None): @@ -57,12 +97,34 @@ class Schema(GraphQLSchema): def _type_map_reducer(self, map, type): if not type: return map - if inspect.isclass(type) and issubclass(type, (ObjectType)): + if inspect.isclass(type) and issubclass(type, (ObjectType, Scalar)): return self._type_map_reducer_graphene(map, type) return super(Schema, self)._type_map_reducer(map, type) def _type_map_reducer_graphene(self, map, type): # from .structures import List, NonNull + from ..generators.definitions import GrapheneObjectType + if issubclass(type, ObjectType): + fields = OrderedDict() + for name, field in type._meta.fields.items(): + map = self._type_map_reducer(map, field.type) + field_type = map.get(field.type._meta.name) + _field = GraphQLField( + field_type, + args=field.args, + resolver=field.resolver, + deprecation_reason=field.deprecation_reason, + description=field.description + ) + fields[name] = _field + map[type._meta.name] = GrapheneObjectType( + graphene_type=type, + name=type._meta.name, + description=type._meta.description, + fields=fields, + is_type_of=type.is_type_of, + interfaces=type._meta.interfaces + ) return map # def rebuild(self): diff --git a/graphene/new_types/tests/test_query.py b/graphene/new_types/tests/test_query.py new file mode 100644 index 00000000..7bdf3651 --- /dev/null +++ b/graphene/new_types/tests/test_query.py @@ -0,0 +1,21 @@ +from collections import OrderedDict + +from py.test import raises + +from ..objecttype import ObjectType +from ..scalars import String, Int, Boolean +from ..field import Field +from ..structures import List + +from ..schema import Schema + + +class Query(ObjectType): + hello = String(resolver=lambda *_: 'World') + + +def test_query(): + hello_schema = Schema(Query) + + executed = hello_schema.execute('{ hello }') + print executed.errors