Improved schema

This commit is contained in:
Syrus Akbary 2016-08-11 01:25:24 -07:00
parent c4fba3b7ca
commit 3e77f258b4
2 changed files with 39 additions and 14 deletions

View File

@ -6,7 +6,8 @@ from graphql.utils.schema_printer import print_schema
from .objecttype import ObjectType from .objecttype import ObjectType
from .scalars import Scalar from .structures import List, NonNull
from .scalars import Scalar, String
# from ..utils.get_graphql_type import get_graphql_type # from ..utils.get_graphql_type import get_graphql_type
@ -19,13 +20,14 @@ from collections import Iterable, OrderedDict, defaultdict
from functools import reduce from functools import reduce
from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of from graphql.utils.type_comparators import is_equal_type, is_type_sub_type_of
from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField, from graphql.type.definition import (GraphQLInputObjectType, GraphQLInterfaceType, GraphQLField,GraphQLScalarType,
GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLList, GraphQLNonNull, GraphQLObjectType,
GraphQLUnionType) GraphQLUnionType)
from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective,
GraphQLSkipDirective) GraphQLSkipDirective)
from graphql.type.introspection import IntrospectionSchema from graphql.type.introspection import IntrospectionSchema
from graphql.type.schema import assert_object_implements_interface from graphql.type.schema import assert_object_implements_interface
from graphql.type.scalars import GraphQLString
class Schema(GraphQLSchema): class Schema(GraphQLSchema):
@ -97,18 +99,35 @@ class Schema(GraphQLSchema):
def _type_map_reducer(self, map, type): def _type_map_reducer(self, map, type):
if not type: if not type:
return map return map
if inspect.isclass(type) and issubclass(type, (ObjectType, Scalar)): if isinstance(type, List) or (inspect.isclass(type) and issubclass(type, (ObjectType, Scalar))):
return self._type_map_reducer_graphene(map, type) return self._type_map_reducer_graphene(map, type)
return super(Schema, self)._type_map_reducer(map, type) return super(Schema, self)._type_map_reducer(map, type)
def _type_map_reducer_graphene(self, map, type): def _type_map_reducer_graphene(self, map, type):
# from .structures import List, NonNull # from .structures import List, NonNull
from ..generators.definitions import GrapheneObjectType from ..generators.definitions import GrapheneObjectType, GrapheneScalarType
if isinstance(type, List):
return self._type_map_reducer(map, type.of_type)
if issubclass(type, String):
map[type._meta.name] = GraphQLString
return map
if type._meta.name in map:
assert map[type._meta.name].graphene_type == type
return map
if issubclass(type, ObjectType): if issubclass(type, ObjectType):
fields = OrderedDict() fields = OrderedDict()
map[type._meta.name] = GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields={},
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces
)
for name, field in type._meta.fields.items(): for name, field in type._meta.fields.items():
map = self._type_map_reducer(map, field.type) map = self._type_map_reducer(map, field.type)
field_type = map.get(field.type._meta.name) field_type = self.get_field_type(map, field.type)
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=field.args, args=field.args,
@ -117,16 +136,22 @@ class Schema(GraphQLSchema):
description=field.description description=field.description
) )
fields[name] = _field fields[name] = _field
map[type._meta.name] = GrapheneObjectType( map[type._meta.name].fields = fields
graphene_type=type, # map[type._meta.name] = GrapheneScalarType(
name=type._meta.name, # graphene_type=type,
description=type._meta.description, # name=type._meta.name,
fields=fields, # description=type._meta.description,
is_type_of=type.is_type_of,
interfaces=type._meta.interfaces # serialize=getattr(type, 'serialize', None),
) # parse_value=getattr(type, 'parse_value', None),
# parse_literal=getattr(type, 'parse_literal', None),
# )
return map return map
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
return map.get(type._meta.name)
# def rebuild(self): # def rebuild(self):
# self._possible_type_map = defaultdict(set) # self._possible_type_map = defaultdict(set)
# self._type_map = self._build_type_map(self.types) # self._type_map = self._build_type_map(self.types)

View File

@ -64,7 +64,7 @@ def test_defines_a_query_only_schema():
assert issubclass(author_field_type, ObjectType) assert issubclass(author_field_type, ObjectType)
recent_article_field = author_field_type._meta.fields['recent_article'] recent_article_field = author_field_type._meta.fields['recent_article']
assert recent_article_field.type() == Article assert recent_article_field.type == Article
feed_field = Query._meta.fields['feed'] feed_field = Query._meta.fields['feed']
assert feed_field.type.of_type == Article assert feed_field.type.of_type == Article