diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 386103a..77839dd 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,3 +1,5 @@ +import functools +import inspect from collections import OrderedDict from functools import singledispatch, wraps @@ -25,6 +27,7 @@ from graphene import ( ) from graphene.types.json import JSONString from graphene.types.scalars import BigInt +from graphene.types.resolver import get_default_resolver from graphene.utils.str_converters import to_camel_case from graphql import GraphQLError @@ -324,13 +327,54 @@ def convert_field_to_djangomodel(field, registry=None): resolver = super().wrap_resolve(parent_resolver) def custom_resolver(root, info, **args): - fk_obj = resolver(root, info, **args) - if not isinstance(fk_obj, model): - # In case the resolver is a custom one that overwrites + # Note: this function is used to resolve FK or 1:1 fields + # it does not differentiate between custom-resolved fields + # and default resolved fields. + + # because this is a django foreign key or one-to-one field, the primary-key for + # this node can be accessed from the root node. + # ex: article.reporter_id + from graphene.utils.str_converters import to_snake_case + + # get the name of the id field from the root's model + field_name = to_snake_case(info.field_name) + db_field_key = root.__class__._meta.get_field(field_name).attname + if hasattr(root, db_field_key): + # get the object's primary-key from root + object_pk = getattr(root, db_field_key) + else: + return None + + is_resolver_awaitable = inspect.iscoroutinefunction(resolver) + + if is_resolver_awaitable: + fk_obj = resolver(root, info, **args) + # In case the resolver is a custom awaitable resolver overwrites # the default Django resolver - # This happens, for example, when using custom awaitable resolvers. return fk_obj - return _type.get_node(info, fk_obj.pk) + + instance_from_get_node = _type.get_node(info, object_pk) + + if instance_from_get_node is None: + # no instance to return + return + elif isinstance(resolver, functools.partial) and resolver.func is get_default_resolver(): + return instance_from_get_node + elif resolver is not get_default_resolver(): + # Default resolver is overriden + # For optimization, add the instance to the resolver + setattr(root, field_name, instance_from_get_node) + # Explanation: + # previously, _type.get_node` is called which results in at least one hit to the database. + # But, if we did not pass the instance to the root, calling the resolver will result in + # another call to get the instance which results in at least two database queries in total + # to resolve this node only. + # That's why the value of the object is set in the root so when the object is accessed + # in the resolver (root.field_name) it does not access the database unless queried explicitly. + fk_obj = resolver(root, info, **args) + return fk_obj + else: + return instance_from_get_node return custom_resolver