diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 94f0fe16..5228fb44 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -311,20 +311,17 @@ class TypeMap(dict): if isinstance(arg.type, NonNull) else arg.default_value, ) - resolve = field.get_resolver( - self.get_resolver(graphene_type, name, field.default_value) - ) _field = GraphQLField( field_type, args=args, resolve=field.get_resolver( self.get_resolver_for_type( - type_, "resolve_{}", name, field.default_value + graphene_type, "resolve_{}", name, field.default_value ) ), subscribe=field.get_resolver( self.get_resolver_for_type( - type_, "subscribe_{}", name, field.default_value + graphene_type, "subscribe_{}", name, field.default_value ) ), deprecation_reason=field.deprecation_reason, @@ -334,11 +331,11 @@ class TypeMap(dict): fields[field_name] = _field return fields - def get_resolver_for_type(self, type_, pattern, name, default_value): - if not issubclass(type_, ObjectType): + def get_resolver_for_type(self, graphene_type, pattern, name, default_value): + if not issubclass(graphene_type, ObjectType): return func_name = pattern.format(name) - resolver = getattr(type_, func_name, None) + resolver = getattr(graphene_type, func_name, None) if not resolver: # If we don't find the resolver in the ObjectType class, then try to # find it in each of the interfaces @@ -360,6 +357,48 @@ class TypeMap(dict): ) return partial(default_resolver, name, default_value) + def resolve_type(self, resolve_type_func, type_name, root, info, _type): + type_ = resolve_type_func(root, info) + + if not type_: + return_type = self[type_name] + return default_type_resolver(root, info, return_type) + + if inspect.isclass(type_) and issubclass(type_, ObjectType): + graphql_type = self.get(type_._meta.name) + assert graphql_type, "Can't find type {} in schema".format(type_._meta.name) + assert graphql_type.graphene_type == type_, ( + "The type {} does not match with the associated graphene type {}." + ).format(type_, graphql_type.graphene_type) + return graphql_type + + return type_ + + def get_resolver(self, graphene_type, name, default_value): + if not issubclass(graphene_type, ObjectType): + return + resolver = getattr(graphene_type, "resolve_{}".format(name), None) + if not resolver: + # If we don't find the resolver in the ObjectType class, then try to + # find it in each of the interfaces + interface_resolver = None + for interface in graphene_type._meta.interfaces: + if name not in interface._meta.fields: + continue + interface_resolver = getattr(interface, "resolve_{}".format(name), None) + if interface_resolver: + break + resolver = interface_resolver + + # Only if is not decorated with classmethod + if resolver: + return get_unbound_function(resolver) + + default_resolver = ( + graphene_type._meta.default_resolver or get_default_resolver() + ) + return partial(default_resolver, name, default_value) + class Schema: """Schema Definition.