Added type getter in the schema when accessing its attrs

This commit is contained in:
Syrus Akbary 2016-11-22 23:34:52 -08:00
parent df2900e215
commit 5e0923b560
3 changed files with 73 additions and 3 deletions

View File

@ -6,6 +6,7 @@ from graphql.type.introspection import IntrospectionSchema
from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema
from .definitions import GrapheneGraphQLType
from .typemap import TypeMap, is_graphene_type
@ -46,6 +47,20 @@ class Schema(GraphQLSchema):
def get_subscription_type(self):
return self.get_graphql_type(self._subscription)
def __getattr__(self, type_name):
'''
This function let the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
'''
_type = super(Schema, self).get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
def get_graphql_type(self, _type):
if not _type:
return _type

View File

@ -0,0 +1,54 @@
import pytest
from ..schema import Schema
from ..objecttype import ObjectType
from ..scalars import String
from ..field import Field
class MyOtherType(ObjectType):
field = String()
class Query(ObjectType):
inner = Field(MyOtherType)
def test_schema():
schema = Schema(Query)
assert schema.get_query_type() == schema.get_graphql_type(Query)
def test_schema_get_type():
schema = Schema(Query)
assert schema.Query == Query
assert schema.MyOtherType == MyOtherType
def test_schema_get_type_error():
schema = Schema(Query)
with pytest.raises(AttributeError) as exc_info:
schema.X
assert str(exc_info.value) == 'Type "X" not found in the Schema'
def test_schema_str():
schema = Schema(Query)
assert str(schema) == """schema {
query: Query
}
type MyOtherType {
field: String
}
type Query {
inner: MyOtherType
}
"""
def test_schema_introspect():
schema = Schema(Query)
assert '__schema' in schema.introspect()

View File

@ -13,7 +13,8 @@ from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case
from .definitions import (GrapheneEnumType, GrapheneInputObjectType,
GrapheneInterfaceType, GrapheneObjectType,
GrapheneScalarType, GrapheneUnionType)
GrapheneScalarType, GrapheneUnionType,
GrapheneGraphQLType)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
@ -68,7 +69,7 @@ class TypeMap(GraphQLTypeMap):
return self.reducer(map, type.of_type)
if type._meta.name in map:
_type = map[type._meta.name]
if is_graphene_type(_type):
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type
return map
if issubclass(type, ObjectType):
@ -127,7 +128,7 @@ class TypeMap(GraphQLTypeMap):
def construct_objecttype(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if is_graphene_type(_type):
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type
return map
map[type._meta.name] = GrapheneObjectType(