diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 8066de3e..63211dc6 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -21,7 +21,8 @@ class Schema(GraphQLSchema): ''' def __init__(self, query=None, mutation=None, subscription=None, - directives=None, types=None, auto_camelcase=True): + directives=None, types=None, auto_camelcase=True, + resolvers=None): assert inspect.isclass(query) and issubclass(query, ObjectType), ( 'Schema query must be Object Type but got: {}.' ).format(query) @@ -41,6 +42,7 @@ class Schema(GraphQLSchema): directives ) self._directives = directives + self._resolvers = resolvers self.build_typemap() def get_query_type(self): @@ -102,4 +104,4 @@ class Schema(GraphQLSchema): ] if self.types: initial_types += self.types - self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self) + self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self, resolvers=self._resolvers) diff --git a/graphene/types/tests/test_schema.py b/graphene/types/tests/test_schema.py index af9bc14c..52ea3759 100644 --- a/graphene/types/tests/test_schema.py +++ b/graphene/types/tests/test_schema.py @@ -6,12 +6,12 @@ from ..scalars import String from ..field import Field -class MyOtherType(ObjectType): +class InnerType(ObjectType): field = String() class Query(ObjectType): - inner = Field(MyOtherType) + inner = Field(InnerType) def test_schema(): @@ -22,7 +22,7 @@ def test_schema(): def test_schema_get_type(): schema = Schema(Query) assert schema.Query == Query - assert schema.MyOtherType == MyOtherType + assert schema.InnerType == InnerType def test_schema_get_type_error(): @@ -39,12 +39,12 @@ def test_schema_str(): query: Query } -type MyOtherType { +type InnerType { field: String } type Query { - inner: MyOtherType + inner: InnerType } """ @@ -52,3 +52,26 @@ type Query { def test_schema_introspect(): schema = Schema(Query) assert '__schema' in schema.introspect() + + +def test_schema_external_resolution(): + class InnerTypeResolvers(object): + def resolve_field(self, args, context, info): + return self['key'] + + class QueryResolvers(object): + def resolve_inner(self, args, context, info): + return {'key': 'value'} + + schema = Schema(Query, resolvers={ + 'Query': QueryResolvers, + 'InnerType': InnerTypeResolvers, + }) + + result = schema.execute('{ inner { field } }') + assert not result.errors + assert result.data == { + 'inner': { + 'field': 'value' + } + } diff --git a/graphene/types/typemap.py b/graphene/types/typemap.py index 9dc17242..f86f2920 100644 --- a/graphene/types/typemap.py +++ b/graphene/types/typemap.py @@ -54,9 +54,10 @@ def resolve_type(resolve_type_func, map, type_name, root, context, info): class TypeMap(GraphQLTypeMap): - def __init__(self, types, auto_camelcase=True, schema=None): + def __init__(self, types, auto_camelcase=True, schema=None, resolvers=None): self.auto_camelcase = auto_camelcase self.schema = schema + self.resolvers = resolvers super(TypeMap, self).__init__(types) def reducer(self, map, type): @@ -245,10 +246,16 @@ class TypeMap(GraphQLTypeMap): fields[field_name] = _field return fields + def get_resolver_from_type(self, type, name): + if self.resolvers and type._meta.name in self.resolvers: + resolver_type = self.resolvers[type._meta.name] + return getattr(resolver_type, 'resolve_{}'.format(name), None) + return getattr(type, 'resolve_{}'.format(name), None) + def get_resolver_for_type(self, type, name, default_value): if not issubclass(type, ObjectType): return - resolver = getattr(type, 'resolve_{}'.format(name), None) + resolver = self.get_resolver_from_type(type, name) if not resolver: # If we don't find the resolver in the ObjectType class, then try to # find it in each of the interfaces @@ -256,7 +263,7 @@ class TypeMap(GraphQLTypeMap): for interface in type._meta.interfaces: if name not in interface._meta.fields: continue - interface_resolver = getattr(interface, 'resolve_{}'.format(name), None) + interface_resolver = self.get_resolver_from_type(interface, name) if interface_resolver: break resolver = interface_resolver