Improved schema implementation

This commit is contained in:
Syrus Akbary 2016-08-11 01:00:46 -07:00
parent ac0e699069
commit feb8fb9b13
4 changed files with 95 additions and 4 deletions

View File

@ -1,4 +1,4 @@
# import inspect import inspect
from functools import partial from functools import partial
from collections import OrderedDict from collections import OrderedDict
@ -42,7 +42,7 @@ from .structures import NonNull
def source_resolver(source, root, args, context, info): def source_resolver(source, root, args, context, info):
resolved = getattr(root, source, None) resolved = getattr(root, source, None)
if callable(resolved): if inspect.isfunction(resolved):
return resolved() return resolved()
return resolved return resolved
@ -58,7 +58,7 @@ class Field(OrderedType):
# self.parent = None # self.parent = None
if required: if required:
type = NonNull(type) type = NonNull(type)
self.type = type self._type = type
self.args = args or OrderedDict() self.args = args or OrderedDict()
# self.args = to_arguments(args, extra_args) # self.args = to_arguments(args, extra_args)
assert not (source and resolver), ('You cannot provide a source and a ' assert not (source and resolver), ('You cannot provide a source and a '
@ -68,3 +68,9 @@ class Field(OrderedType):
self.resolver = resolver self.resolver = resolver
self.deprecation_reason = deprecation_reason self.deprecation_reason = deprecation_reason
self.description = description self.description = description
@property
def type(self):
if inspect.isfunction(self._type):
return self._type()
return self._type

View File

@ -31,6 +31,8 @@ class ObjectTypeMeta(AbstractTypeMeta):
class ObjectType(six.with_metaclass(ObjectTypeMeta)): class ObjectType(six.with_metaclass(ObjectTypeMeta)):
is_type_of = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# GraphQL ObjectType acting as container # GraphQL ObjectType acting as container
args_len = len(args) args_len = len(args)

View File

@ -6,6 +6,7 @@ from graphql.utils.schema_printer import print_schema
from .objecttype import ObjectType from .objecttype import ObjectType
from .scalars import Scalar
# from ..utils.get_graphql_type import get_graphql_type # from ..utils.get_graphql_type import get_graphql_type
@ -14,6 +15,19 @@ from .objecttype import ObjectType
# from collections import defaultdict # from collections import defaultdict
from collections import Iterable, OrderedDict, defaultdict
from functools import reduce
from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of
from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField,
GraphQLList, GraphQLNonNull, GraphQLObjectType,
GraphQLUnionType)
from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective,
GraphQLSkipDirective)
from graphql.type.introspection import IntrospectionSchema
from graphql.type.schema import assert_object_implements_interface
class Schema(GraphQLSchema): class Schema(GraphQLSchema):
def __init__(self, query=None, mutation=None, subscription=None, directives=None, types=None, executor=None): def __init__(self, query=None, mutation=None, subscription=None, directives=None, types=None, executor=None):
@ -29,6 +43,32 @@ class Schema(GraphQLSchema):
# directives=directives, # directives=directives,
# types=self.types # types=self.types
# ) # )
if directives is None:
directives = [
GraphQLIncludeDirective,
GraphQLSkipDirective
]
assert all(isinstance(d, GraphQLDirective) for d in directives), \
'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format(
directives
)
self._directives = directives
self._possible_type_map = defaultdict(set)
self._type_map = self._build_type_map(types)
# Keep track of all implementations by interface name.
self._implementations = defaultdict(list)
for type in self._type_map.values():
if isinstance(type, GraphQLObjectType):
for interface in type.get_interfaces():
self._implementations[interface.name].append(type)
# Enforce correct interface implementations.
for type in self._type_map.values():
if isinstance(type, GraphQLObjectType):
for interface in type.get_interfaces():
assert_object_implements_interface(self, type, interface)
def execute(self, request_string='', root_value=None, variable_values=None, def execute(self, request_string='', root_value=None, variable_values=None,
context_value=None, operation_name=None, executor=None): context_value=None, operation_name=None, executor=None):
@ -57,12 +97,34 @@ class Schema(GraphQLSchema):
def _type_map_reducer(self, map, type): def _type_map_reducer(self, map, type):
if not type: if not type:
return map return map
if inspect.isclass(type) and issubclass(type, (ObjectType)): if inspect.isclass(type) and issubclass(type, (ObjectType, Scalar)):
return self._type_map_reducer_graphene(map, type) return self._type_map_reducer_graphene(map, type)
return super(Schema, self)._type_map_reducer(map, type) return super(Schema, self)._type_map_reducer(map, type)
def _type_map_reducer_graphene(self, map, type): def _type_map_reducer_graphene(self, map, type):
# from .structures import List, NonNull # from .structures import List, NonNull
from ..generators.definitions import GrapheneObjectType
if issubclass(type, ObjectType):
fields = OrderedDict()
for name, field in type._meta.fields.items():
map = self._type_map_reducer(map, field.type)
field_type = map.get(field.type._meta.name)
_field = GraphQLField(
field_type,
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description
)
fields[name] = _field
map[type._meta.name] = GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=fields,
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces
)
return map return map
# def rebuild(self): # def rebuild(self):

View File

@ -0,0 +1,21 @@
from collections import OrderedDict
from py.test import raises
from ..objecttype import ObjectType
from ..scalars import String, Int, Boolean
from ..field import Field
from ..structures import List
from ..schema import Schema
class Query(ObjectType):
hello = String(resolver=lambda *_: 'World')
def test_query():
hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello }')
print executed.errors