Added TypeMap

This commit is contained in:
Syrus Akbary 2016-08-11 19:44:37 -07:00
parent 3620e2ad4e
commit 7923f45595
3 changed files with 388 additions and 88 deletions

View File

@ -1,6 +1,6 @@
import inspect
from graphql import GraphQLSchema, graphql
from graphql import GraphQLSchema, graphql, is_type
from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema
@ -16,18 +16,10 @@ from .scalars import Scalar, String
# from collections import defaultdict
from collections import Iterable, OrderedDict, defaultdict
from functools import reduce
from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of
from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField,GraphQLScalarType,
GraphQLList, GraphQLNonNull, GraphQLObjectType,
GraphQLUnionType)
from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective,
GraphQLSkipDirective)
from graphql.type.introspection import IntrospectionSchema
from graphql.type.schema import assert_object_implements_interface
from graphql.type.scalars import GraphQLString
from .typemap import TypeMap, is_graphene_type
class Schema(GraphQLSchema):
@ -38,13 +30,6 @@ class Schema(GraphQLSchema):
self._subscription = subscription
self.types = types
self._executor = executor
# super(Schema, self).__init__(
# query=query,
# mutation=mutation,
# subscription=subscription,
# directives=directives,
# types=self.types
# )
if directives is None:
directives = [
GraphQLIncludeDirective,
@ -57,20 +42,34 @@ class Schema(GraphQLSchema):
)
self._directives = directives
self._possible_type_map = defaultdict(set)
self._type_map = self._build_type_map(types)
# Keep track of all implementations by interface name.
self._implementations = defaultdict(list)
for type in self._type_map.values():
if isinstance(type, GraphQLObjectType):
for interface in type.get_interfaces():
self._implementations[interface.name].append(type)
initial_types = [
query,
mutation,
subscription,
IntrospectionSchema
]
if types:
initial_types += types
self._type_map = TypeMap(initial_types)
# Enforce correct interface implementations.
for type in self._type_map.values():
if isinstance(type, GraphQLObjectType):
for interface in type.get_interfaces():
assert_object_implements_interface(self, type, interface)
def get_query_type(self):
return self.get_graphql_type(self._query)
def get_mutation_type(self):
return self.get_graphql_type(self._mutation)
def get_subscription_type(self):
return self.get_graphql_type(self._subscription)
def get_graphql_type(self, _type):
if is_type(_type):
return _type
if is_graphene_type(_type):
graphql_type = self.get_type(_type._meta.name)
assert graphql_type, "Type {} not found in this schema.".format(_type._meta.name)
assert graphql_type.graphene_type == _type
return graphql_type
raise Exception("{} is not a valid GraphQL type.".format(_type))
def execute(self, request_string='', root_value=None, variable_values=None,
context_value=None, operation_name=None, executor=None):
@ -96,62 +95,6 @@ class Schema(GraphQLSchema):
def lazy(self, _type):
return lambda: self.get_type(_type)
def _type_map_reducer(self, map, type):
if not type:
return map
if isinstance(type, List) or (inspect.isclass(type) and issubclass(type, (ObjectType, Scalar))):
return self._type_map_reducer_graphene(map, type)
return super(Schema, self)._type_map_reducer(map, type)
def _type_map_reducer_graphene(self, map, type):
# from .structures import List, NonNull
from ..generators.definitions import GrapheneObjectType, GrapheneScalarType
if isinstance(type, List):
return self._type_map_reducer(map, type.of_type)
if issubclass(type, String):
map[type._meta.name] = GraphQLString
return map
if type._meta.name in map:
assert map[type._meta.name].graphene_type == type
return map
if issubclass(type, ObjectType):
fields = OrderedDict()
map[type._meta.name] = GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields={},
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces
)
for name, field in type._meta.fields.items():
map = self._type_map_reducer(map, field.type)
field_type = self.get_field_type(map, field.type)
_field = GraphQLField(
field_type,
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description
)
fields[name] = _field
map[type._meta.name].fields = fields
# map[type._meta.name] = GrapheneScalarType(
# graphene_type=type,
# name=type._meta.name,
# description=type._meta.description,
# serialize=getattr(type, 'serialize', None),
# parse_value=getattr(type, 'parse_value', None),
# parse_literal=getattr(type, 'parse_literal', None),
# )
return map
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
return map.get(type._meta.name)
# def rebuild(self):
# self._possible_type_map = defaultdict(set)
# self._type_map = self._build_type_map(self.types)

View File

@ -47,7 +47,7 @@ class Subscription(ObjectType):
def test_defines_a_query_only_schema():
blog_schema = Schema(Query)
assert blog_schema.get_query_type() == Query
assert blog_schema.get_query_type().graphene_type == Query
article_field = Query._meta.fields['article']
assert article_field.type == Article
@ -68,3 +68,254 @@ def test_defines_a_query_only_schema():
feed_field = Query._meta.fields['feed']
assert feed_field.type.of_type == Article
def test_defines_a_mutation_schema():
blog_schema = Schema(Query, mutation=Mutation)
assert blog_schema.get_mutation_type().graphene_type == Mutation
write_mutation = Mutation._meta.fields['write_article']
assert write_mutation.type == Article
assert write_mutation.type._meta.name == 'Article'
def test_defines_a_subscription_schema():
blog_schema = Schema(Query, subscription=Subscription)
assert blog_schema.get_subscription_type().graphene_type == Subscription
subscription = Subscription._meta.fields['article_subscribe']
assert subscription.type == Article
assert subscription.type._meta.name == 'Article'
# def test_includes_nested_input_objects_in_the_map():
# NestedInputObject = GraphQLInputObjectType(
# name='NestedInputObject',
# fields={'value': GraphQLInputObjectField(GraphQLString)}
# )
# SomeInputObject = GraphQLInputObjectType(
# name='SomeInputObject',
# fields={'nested': GraphQLInputObjectField(NestedInputObject)}
# )
# SomeMutation = GraphQLObjectType(
# name='SomeMutation',
# fields={
# 'mutateSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
# SomeSubscription = GraphQLObjectType(
# name='SomeSubscription',
# fields={
# 'subscribeToSomething': GraphQLField(
# type=BlogArticle,
# args={
# 'input': GraphQLArgument(SomeInputObject)
# }
# )
# }
# )
# schema = GraphQLSchema(
# query=BlogQuery,
# mutation=SomeMutation,
# subscription=SomeSubscription
# )
# assert schema.get_type_map()['NestedInputObject'] is NestedInputObject
# def test_includes_interfaces_thunk_subtypes_in_the_type_map():
# SomeInterface = GraphQLInterfaceType(
# name='SomeInterface',
# fields={
# 'f': GraphQLField(GraphQLInt)
# }
# )
# SomeSubtype = GraphQLObjectType(
# name='SomeSubtype',
# fields={
# 'f': GraphQLField(GraphQLInt)
# },
# interfaces=lambda: [SomeInterface],
# is_type_of=lambda: True
# )
# schema = GraphQLSchema(query=GraphQLObjectType(
# name='Query',
# fields={
# 'iface': GraphQLField(SomeInterface)
# }
# ), types=[SomeSubtype])
# assert schema.get_type_map()['SomeSubtype'] is SomeSubtype
# def test_includes_interfaces_subtypes_in_the_type_map():
# SomeInterface = GraphQLInterfaceType('SomeInterface', fields={'f': GraphQLField(GraphQLInt)})
# SomeSubtype = GraphQLObjectType(
# name='SomeSubtype',
# fields={'f': GraphQLField(GraphQLInt)},
# interfaces=[SomeInterface],
# is_type_of=lambda: None
# )
# schema = GraphQLSchema(
# query=GraphQLObjectType(
# name='Query',
# fields={
# 'iface': GraphQLField(SomeInterface)}),
# types=[SomeSubtype])
# assert schema.get_type_map()['SomeSubtype'] == SomeSubtype
# def test_stringifies_simple_types():
# assert str(GraphQLInt) == 'Int'
# assert str(BlogArticle) == 'Article'
# assert str(InterfaceType) == 'Interface'
# assert str(UnionType) == 'Union'
# assert str(EnumType) == 'Enum'
# assert str(InputObjectType) == 'InputObject'
# assert str(GraphQLNonNull(GraphQLInt)) == 'Int!'
# assert str(GraphQLList(GraphQLInt)) == '[Int]'
# assert str(GraphQLNonNull(GraphQLList(GraphQLInt))) == '[Int]!'
# assert str(GraphQLList(GraphQLNonNull(GraphQLInt))) == '[Int!]'
# assert str(GraphQLList(GraphQLList(GraphQLInt))) == '[[Int]]'
# def test_identifies_input_types():
# expected = (
# (GraphQLInt, True),
# (ObjectType, False),
# (InterfaceType, False),
# (UnionType, False),
# (EnumType, True),
# (InputObjectType, True)
# )
# for type, answer in expected:
# assert is_input_type(type) == answer
# assert is_input_type(GraphQLList(type)) == answer
# assert is_input_type(GraphQLNonNull(type)) == answer
# def test_identifies_output_types():
# expected = (
# (GraphQLInt, True),
# (ObjectType, True),
# (InterfaceType, True),
# (UnionType, True),
# (EnumType, True),
# (InputObjectType, False)
# )
# for type, answer in expected:
# assert is_output_type(type) == answer
# assert is_output_type(GraphQLList(type)) == answer
# assert is_output_type(GraphQLNonNull(type)) == answer
# def test_prohibits_nesting_nonnull_inside_nonnull():
# with raises(Exception) as excinfo:
# GraphQLNonNull(GraphQLNonNull(GraphQLInt))
# assert 'Can only create NonNull of a Nullable GraphQLType but got: Int!.' in str(excinfo.value)
# def test_prohibits_putting_non_object_types_in_unions():
# bad_union_types = [
# GraphQLInt,
# GraphQLNonNull(GraphQLInt),
# GraphQLList(GraphQLInt),
# InterfaceType,
# UnionType,
# EnumType,
# InputObjectType
# ]
# for x in bad_union_types:
# with raises(Exception) as excinfo:
# GraphQLSchema(GraphQLObjectType('Root', fields={'union': GraphQLField(GraphQLUnionType('BadUnion', [x]))}))
# assert 'BadUnion may only contain Object types, it cannot contain: ' + str(x) + '.' \
# == str(excinfo.value)
# def test_does_not_mutate_passed_field_definitions():
# fields = {
# 'field1': GraphQLField(GraphQLString),
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}),
# }
# TestObject1 = GraphQLObjectType(name='Test1', fields=fields)
# TestObject2 = GraphQLObjectType(name='Test1', fields=fields)
# assert TestObject1.get_fields() == TestObject2.get_fields()
# assert fields == {
# 'field1': GraphQLField(GraphQLString),
# 'field2': GraphQLField(GraphQLString, args={'id': GraphQLArgument(GraphQLString)}),
# }
# input_fields = {
# 'field1': GraphQLInputObjectField(GraphQLString),
# 'field2': GraphQLInputObjectField(GraphQLString),
# }
# TestInputObject1 = GraphQLInputObjectType(name='Test1', fields=input_fields)
# TestInputObject2 = GraphQLInputObjectType(name='Test2', fields=input_fields)
# assert TestInputObject1.get_fields() == TestInputObject2.get_fields()
# assert input_fields == {
# 'field1': GraphQLInputObjectField(GraphQLString),
# 'field2': GraphQLInputObjectField(GraphQLString),
# }
# def test_sorts_fields_and_argument_keys_if_not_using_ordered_dict():
# fields = {
# 'b': GraphQLField(GraphQLString),
# 'c': GraphQLField(GraphQLString),
# 'a': GraphQLField(GraphQLString),
# 'd': GraphQLField(GraphQLString, args={
# 'q': GraphQLArgument(GraphQLString),
# 'x': GraphQLArgument(GraphQLString),
# 'v': GraphQLArgument(GraphQLString),
# 'a': GraphQLArgument(GraphQLString),
# 'n': GraphQLArgument(GraphQLString)
# })
# }
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['a', 'b', 'c', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['a', 'n', 'q', 'v', 'x']
# def test_does_not_sort_fields_and_argument_keys_when_using_ordered_dict():
# fields = OrderedDict([
# ('b', GraphQLField(GraphQLString)),
# ('c', GraphQLField(GraphQLString)),
# ('a', GraphQLField(GraphQLString)),
# ('d', GraphQLField(GraphQLString, args=OrderedDict([
# ('q', GraphQLArgument(GraphQLString)),
# ('x', GraphQLArgument(GraphQLString)),
# ('v', GraphQLArgument(GraphQLString)),
# ('a', GraphQLArgument(GraphQLString)),
# ('n', GraphQLArgument(GraphQLString))
# ])))
# ])
# test_object = GraphQLObjectType(name='Test', fields=fields)
# ordered_fields = test_object.get_fields()
# assert list(ordered_fields.keys()) == ['b', 'c', 'a', 'd']
# field_with_args = test_object.get_fields().get('d')
# assert [a.name for a in field_with_args.args] == ['q', 'x', 'v', 'a', 'n']

View File

@ -0,0 +1,106 @@
import inspect
from collections import OrderedDict
from graphql.type.typemap import GraphQLTypeMap
from .objecttype import ObjectType
from .structures import List, NonNull
from .scalars import Scalar, String, Boolean, Int, Float, ID
from graphql import GraphQLString, GraphQLField, GraphQLList, GraphQLBoolean, GraphQLInt, GraphQLFloat, GraphQLID, GraphQLNonNull
def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)):
return True
if inspect.isclass(_type) and issubclass(_type, (ObjectType, Scalar)):
return True
class TypeMap(GraphQLTypeMap):
@classmethod
def reducer(cls, map, type):
if not type:
return map
if is_graphene_type(type):
return cls.graphene_reducer(map, type)
return super(TypeMap, cls).reducer(map, type)
@classmethod
def graphene_reducer(cls, map, type):
if isinstance(type, List):
return cls.reducer(map, type.of_type)
return map
if type._meta.name in map:
_type = map[type._meta.name]
if is_graphene_type(_type):
assert _type.graphene_type == type
return map
if issubclass(type, ObjectType):
return cls.construct_objecttype(map, type)
if issubclass(type, Scalar):
return cls.construct_scalar(map, type)
return map
@classmethod
def construct_scalar(cls, map, type):
from ..generators.definitions import GrapheneScalarType
_scalars = {
String: GraphQLString,
Int: GraphQLInt,
Float: GraphQLFloat,
Boolean: GraphQLBoolean,
ID: GraphQLID
}
if type in _scalars:
map[type._meta.name] = _scalars[type]
else:
map[type._meta.name] = GrapheneScalarType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
serialize=getattr(type, 'serialize', None),
parse_value=getattr(type, 'parse_value', None),
parse_literal=getattr(type, 'parse_literal', None),
)
return map
@classmethod
def construct_objecttype(cls, map, type):
from ..generators.definitions import GrapheneObjectType
map[type._meta.name] = GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields={},
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces
)
map[type._meta.name].fields = cls.construct_fields_for_type(map, type)
return map
@classmethod
def construct_fields_for_type(cls, map, type):
fields = OrderedDict()
for name, field in type._meta.fields.items():
map = cls.reducer(map, field.type)
field_type = cls.get_field_type(map, field.type)
_field = GraphQLField(
field_type,
args=field.args,
resolver=field.resolver,
deprecation_reason=field.deprecation_reason,
description=field.description
)
fields[name] = _field
return fields
@classmethod
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
if isinstance(type, NonNull):
return GraphQLNonNull(self.get_field_type(map, type.of_type))
return map.get(type._meta.name)