Added type getter in the schema when accessing its attrs

This commit is contained in:
Syrus Akbary 2016-11-22 23:34:52 -08:00
parent df2900e215
commit 5e0923b560
3 changed files with 73 additions and 3 deletions

View File

@ -6,6 +6,7 @@ from graphql.type.introspection import IntrospectionSchema
from graphql.utils.introspection_query import introspection_query from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema from graphql.utils.schema_printer import print_schema
from .definitions import GrapheneGraphQLType
from .typemap import TypeMap, is_graphene_type from .typemap import TypeMap, is_graphene_type
@ -46,6 +47,20 @@ class Schema(GraphQLSchema):
def get_subscription_type(self): def get_subscription_type(self):
return self.get_graphql_type(self._subscription) return self.get_graphql_type(self._subscription)
def __getattr__(self, type_name):
'''
This function let the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
'''
_type = super(Schema, self).get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
def get_graphql_type(self, _type): def get_graphql_type(self, _type):
if not _type: if not _type:
return _type return _type

View File

@ -0,0 +1,54 @@
import pytest
from ..schema import Schema
from ..objecttype import ObjectType
from ..scalars import String
from ..field import Field
class MyOtherType(ObjectType):
field = String()
class Query(ObjectType):
inner = Field(MyOtherType)
def test_schema():
schema = Schema(Query)
assert schema.get_query_type() == schema.get_graphql_type(Query)
def test_schema_get_type():
schema = Schema(Query)
assert schema.Query == Query
assert schema.MyOtherType == MyOtherType
def test_schema_get_type_error():
schema = Schema(Query)
with pytest.raises(AttributeError) as exc_info:
schema.X
assert str(exc_info.value) == 'Type "X" not found in the Schema'
def test_schema_str():
schema = Schema(Query)
assert str(schema) == """schema {
query: Query
}
type MyOtherType {
field: String
}
type Query {
inner: MyOtherType
}
"""
def test_schema_introspect():
schema = Schema(Query)
assert '__schema' in schema.introspect()

View File

@ -13,7 +13,8 @@ from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
from .definitions import (GrapheneEnumType, GrapheneInputObjectType, from .definitions import (GrapheneEnumType, GrapheneInputObjectType,
GrapheneInterfaceType, GrapheneObjectType, GrapheneInterfaceType, GrapheneObjectType,
GrapheneScalarType, GrapheneUnionType) GrapheneScalarType, GrapheneUnionType,
GrapheneGraphQLType)
from .dynamic import Dynamic from .dynamic import Dynamic
from .enum import Enum from .enum import Enum
from .field import Field from .field import Field
@ -68,7 +69,7 @@ class TypeMap(GraphQLTypeMap):
return self.reducer(map, type.of_type) return self.reducer(map, type.of_type)
if type._meta.name in map: if type._meta.name in map:
_type = map[type._meta.name] _type = map[type._meta.name]
if is_graphene_type(_type): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type assert _type.graphene_type == type
return map return map
if issubclass(type, ObjectType): if issubclass(type, ObjectType):
@ -127,7 +128,7 @@ class TypeMap(GraphQLTypeMap):
def construct_objecttype(self, map, type): def construct_objecttype(self, map, type):
if type._meta.name in map: if type._meta.name in map:
_type = map[type._meta.name] _type = map[type._meta.name]
if is_graphene_type(_type): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type assert _type.graphene_type == type
return map return map
map[type._meta.name] = GrapheneObjectType( map[type._meta.name] = GrapheneObjectType(