diff --git a/.gitignore b/.gitignore index 150025a..2d9d43c 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__/ # Distribution / packaging .Python env/ +.env/ build/ develop-eggs/ dist/ diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 375d683..1454d53 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,5 +1,6 @@ +import inspect from collections import OrderedDict -from functools import singledispatch, wraps +from functools import partial, singledispatch, wraps from django.db import models from django.utils.encoding import force_str @@ -25,6 +26,7 @@ from graphene import ( ) from graphene.types.json import JSONString from graphene.types.scalars import BigInt +from graphene.types.resolver import get_default_resolver from graphene.utils.str_converters import to_camel_case from graphql import GraphQLError @@ -258,6 +260,9 @@ def convert_time_to_string(field, registry=None): @convert_django_field.register(models.OneToOneRel) def convert_onetoone_field_to_djangomodel(field, registry=None): + from graphene.utils.str_converters import to_snake_case + from .types import DjangoObjectType + model = field.related_model def dynamic_type(): @@ -265,7 +270,69 @@ def convert_onetoone_field_to_djangomodel(field, registry=None): if not _type: return - return Field(_type, required=not field.null) + class CustomField(Field): + def wrap_resolve(self, parent_resolver): + """ + Implements a custom resolver which go through the `get_node` method to insure that + it goes through the `get_queryset` method of the DjangoObjectType. + """ + resolver = super().wrap_resolve(parent_resolver) + + # If `get_queryset` was not overridden in the DjangoObjectType, + # we can just return the default resolver. + if ( + _type.get_queryset.__func__ + is DjangoObjectType.get_queryset.__func__ + ): + return resolver + + def custom_resolver(root, info, **args): + # Note: this function is used to resolve 1:1 relation fields + # it does not differentiate between custom-resolved fields + # and default resolved fields. + + is_resolver_awaitable = inspect.iscoroutinefunction(resolver) + + if is_resolver_awaitable: + fk_obj = resolver(root, info, **args) + # In case the resolver is a custom awaitable resolver that overwrites + # the default Django resolver + return fk_obj + + object_pk = resolver(root, info, **args).pk + instance_from_get_node = _type.get_node(info, object_pk) + + if instance_from_get_node is None: + # no instance to return + return + elif ( + isinstance(resolver, partial) + and resolver.func is get_default_resolver() + ): + return instance_from_get_node + elif resolver is not get_default_resolver(): + # Default resolver is overridden + # For optimization, add the instance to the resolver + field_name = to_snake_case(info.field_name) + setattr(root, field_name, instance_from_get_node) + # Explanation: + # previously, _type.get_node` is called which results in at least one hit to the database. + # But, if we did not pass the instance to the root, calling the resolver will result in + # another call to get the instance which results in at least two database queries in total + # to resolve this node only. + # That's why the value of the object is set in the root so when the object is accessed + # in the resolver (root.field_name) it does not access the database unless queried explicitly. + fk_obj = resolver(root, info, **args) + return fk_obj + else: + return instance_from_get_node + + return custom_resolver + + return CustomField( + _type, + required=not field.null, + ) return Dynamic(dynamic_type) @@ -313,6 +380,9 @@ def convert_field_to_list_or_connection(field, registry=None): @convert_django_field.register(models.OneToOneField) @convert_django_field.register(models.ForeignKey) def convert_field_to_djangomodel(field, registry=None): + from graphene.utils.str_converters import to_snake_case + from .types import DjangoObjectType + model = field.related_model def dynamic_type(): @@ -320,7 +390,77 @@ def convert_field_to_djangomodel(field, registry=None): if not _type: return - return Field( + class CustomField(Field): + def wrap_resolve(self, parent_resolver): + """ + Implements a custom resolver which go through the `get_node` method to ensure that + it goes through the `get_queryset` method of the DjangoObjectType. + """ + resolver = super().wrap_resolve(parent_resolver) + + # If `get_queryset` was not overridden in the DjangoObjectType, + # we can just return the default resolver. + if ( + _type.get_queryset.__func__ + is DjangoObjectType.get_queryset.__func__ + ): + return resolver + + def custom_resolver(root, info, **args): + # Note: this function is used to resolve FK or 1:1 fields + # it does not differentiate between custom-resolved fields + # and default resolved fields. + + # because this is a django foreign key or one-to-one field, the primary-key for + # this node can be accessed from the root node. + # ex: article.reporter_id + + # get the name of the id field from the root's model + field_name = to_snake_case(info.field_name) + db_field_key = root.__class__._meta.get_field(field_name).attname + if hasattr(root, db_field_key): + # get the object's primary-key from root + object_pk = getattr(root, db_field_key) + else: + return None + + is_resolver_awaitable = inspect.iscoroutinefunction(resolver) + + if is_resolver_awaitable: + fk_obj = resolver(root, info, **args) + # In case the resolver is a custom awaitable resolver that overwrites + # the default Django resolver + return fk_obj + + instance_from_get_node = _type.get_node(info, object_pk) + + if instance_from_get_node is None: + # no instance to return + return + elif ( + isinstance(resolver, partial) + and resolver.func is get_default_resolver() + ): + return instance_from_get_node + elif resolver is not get_default_resolver(): + # Default resolver is overridden + # For optimization, add the instance to the resolver + setattr(root, field_name, instance_from_get_node) + # Explanation: + # previously, _type.get_node` is called which results in at least one hit to the database. + # But, if we did not pass the instance to the root, calling the resolver will result in + # another call to get the instance which results in at least two database queries in total + # to resolve this node only. + # That's why the value of the object is set in the root so when the object is accessed + # in the resolver (root.field_name) it does not access the database unless queried explicitly. + fk_obj = resolver(root, info, **args) + return fk_obj + else: + return instance_from_get_node + + return custom_resolver + + return CustomField( _type, description=get_django_field_description(field), required=not field.null, diff --git a/graphene_django/tests/test_get_queryset.py b/graphene_django/tests/test_get_queryset.py index 7cbaa54..44d8142 100644 --- a/graphene_django/tests/test_get_queryset.py +++ b/graphene_django/tests/test_get_queryset.py @@ -8,7 +8,7 @@ from graphql_relay import to_global_id from ..fields import DjangoConnectionField from ..types import DjangoObjectType -from .models import Article, Reporter +from .models import Article, Reporter, FilmDetails, Film class TestShouldCallGetQuerySetOnForeignKey: @@ -127,6 +127,69 @@ class TestShouldCallGetQuerySetOnForeignKey: assert not result.errors assert result.data == {"reporter": {"firstName": "Jane"}} + def test_get_queryset_called_on_foreignkey(self): + # If a user tries to access a reporter through an article they should get our authorization error + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute(query, variables={"id": self.articles[0].id}) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters through an article + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": self.articles[0].id}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + "reporter": {"firstName": "Jane"}, + } + + # An admin user should not be able to access draft article through a reporter + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + articles { + headline + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": self.reporter.id}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["reporter"] == { + "firstName": "Jane", + "articles": [{"headline": "A fantastic article"}], + } + class TestShouldCallGetQuerySetOnForeignKeyNode: """ @@ -233,3 +296,268 @@ class TestShouldCallGetQuerySetOnForeignKeyNode: ) assert not result.errors assert result.data == {"reporter": {"firstName": "Jane"}} + + def test_get_queryset_called_on_foreignkey(self): + # If a user tries to access a reporter through an article they should get our authorization error + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, variables={"id": to_global_id("ArticleType", self.articles[0].id)} + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters through an article + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": to_global_id("ArticleType", self.articles[0].id)}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + "reporter": {"firstName": "Jane"}, + } + + # An admin user should not be able to access draft article through a reporter + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + articles { + edges { + node { + headline + } + } + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": to_global_id("ReporterType", self.reporter.id)}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["reporter"] == { + "firstName": "Jane", + "articles": {"edges": [{"node": {"headline": "A fantastic article"}}]}, + } + + +class TestShouldCallGetQuerySetOnOneToOne: + @pytest.fixture(autouse=True) + def setup_schema(self): + class FilmDetailsType(DjangoObjectType): + class Meta: + model = FilmDetails + + @classmethod + def get_queryset(cls, queryset, info): + if info.context and info.context.get("permission_get_film_details"): + return queryset + raise Exception("Not authorized to access film details.") + + class FilmType(DjangoObjectType): + class Meta: + model = Film + + @classmethod + def get_queryset(cls, queryset, info): + if info.context and info.context.get("permission_get_film"): + return queryset + raise Exception("Not authorized to access film.") + + class Query(graphene.ObjectType): + film_details = graphene.Field( + FilmDetailsType, id=graphene.ID(required=True) + ) + film = graphene.Field(FilmType, id=graphene.ID(required=True)) + + def resolve_film_details(self, info, id): + return ( + FilmDetailsType.get_queryset(FilmDetails.objects, info) + .filter(id=id) + .last() + ) + + def resolve_film(self, info, id): + return FilmType.get_queryset(Film.objects, info).filter(id=id).last() + + self.schema = graphene.Schema(query=Query) + + self.films = [ + Film.objects.create( + genre="do", + ), + Film.objects.create( + genre="ac", + ), + ] + + self.film_details = [ + FilmDetails.objects.create( + film=self.films[0], + ), + FilmDetails.objects.create( + film=self.films[1], + ), + ] + + def test_get_queryset_called_on_field(self): + # A user tries to access a film + query = """ + query getFilm($id: ID!) { + film(id: $id) { + genre + } + } + """ + + # With `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.films[0].id}, + context_value={"permission_get_film": True}, + ) + assert not result.errors + assert result.data["film"] == { + "genre": "DO", + } + + # Without `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.films[1].id}, + context_value={"permission_get_film": False}, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access film." + + # A user tries to access a film details + query = """ + query getFilmDetails($id: ID!) { + filmDetails(id: $id) { + location + } + } + """ + + # With `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.film_details[0].id}, + context_value={"permission_get_film_details": True}, + ) + assert not result.errors + assert result.data == {"filmDetails": {"location": ""}} + + # Without `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.film_details[0].id}, + context_value={"permission_get_film_details": False}, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access film details." + + def test_get_queryset_called_on_foreignkey(self): + # A user tries to access a film details through a film + query = """ + query getFilm($id: ID!) { + film(id: $id) { + genre + details { + location + } + } + } + """ + + # With `permission_get_film_details` + result = self.schema.execute( + query, + variables={"id": self.films[0].id}, + context_value={ + "permission_get_film": True, + "permission_get_film_details": True, + }, + ) + assert not result.errors + assert result.data["film"] == { + "genre": "DO", + "details": {"location": ""}, + } + + # Without `permission_get_film_details` + result = self.schema.execute( + query, + variables={"id": self.films[0].id}, + context_value={ + "permission_get_film": True, + "permission_get_film_details": False, + }, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access film details." + + # A user tries to access a film through a film details + query = """ + query getFilmDetails($id: ID!) { + filmDetails(id: $id) { + location + film { + genre + } + } + } + """ + + # With `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.film_details[0].id}, + context_value={ + "permission_get_film": True, + "permission_get_film_details": True, + }, + ) + assert not result.errors + assert result.data["filmDetails"] == { + "location": "", + "film": {"genre": "DO"}, + } + + # Without `permission_get_film` + result = self.schema.execute( + query, + variables={"id": self.film_details[1].id}, + context_value={ + "permission_get_film": False, + "permission_get_film_details": True, + }, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access film."