diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 1da68d5..a5db45f 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -30,7 +30,7 @@ from graphql import GraphQLError, assert_valid_name from graphql.pyutils import register_description from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField -from .fields import DjangoListField, DjangoConnectionField +from .fields import DjangoListField, DjangoConnectionField, DjangoInstanceField from .settings import graphene_settings from .utils.str_converters import to_const @@ -297,10 +297,11 @@ def convert_field_to_djangomodel(field, registry=None): if not _type: return - return Field( + return DjangoInstanceField( _type, description=get_django_field_description(field), required=not field.null, + is_foreign_key=True, ) return Dynamic(dynamic_type) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 8d6e995..f51db60 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -246,3 +246,97 @@ class DjangoConnectionField(ConnectionField): def get_queryset_resolver(self): return self.resolve_queryset + + +class DjangoInstanceField(Field): + def __init__(self, _type, *args, **kwargs): + from .types import DjangoObjectType + + self.unique_fields = kwargs.pop("unique_fields", ("id",)) + self.is_foreign_key = kwargs.pop("is_foreign_key", False) + + assert not isinstance( + self.unique_fields, list + ), "unique_fields argument needs to be a list" + + if isinstance(_type, NonNull): + _type = _type.of_type + + super(DjangoInstanceField, self).__init__(_type, *args, **kwargs) + + assert issubclass( + self._underlying_type, DjangoObjectType + ), "DjangoInstanceField only accepts DjangoObjectType types" + + @property + def _underlying_type(self): + _type = self._type + while hasattr(_type, "of_type"): + _type = _type.of_type + return _type + + @property + def model(self): + return self._underlying_type._meta.model + + def get_manager(self): + return self.model._default_manager + + @staticmethod + def instance_resolver( + django_object_type, + unique_fields, + resolver, + default_manager, + is_foreign_key, + root, + info, + **args + ): + + queryset = None + unique_filter = {} + if is_foreign_key: + pk = getattr(root, "{}_id".format(info.field_name)) + if pk is not None: + unique_filter["pk"] = pk + unique_fields = () + else: + return None + else: + queryset = maybe_queryset(resolver(root, info, **args)) + + if queryset is None: + queryset = maybe_queryset(default_manager) + + if isinstance(queryset, QuerySet): + # Pass queryset to the DjangoObjectType get_queryset method + queryset = maybe_queryset(django_object_type.get_queryset(queryset, info)) + for field in unique_fields: + key = field if field != "id" else "pk" + value = args.get(field) + + if value is not None: + unique_filter[key] = value + + assert len(unique_filter.keys()) > 0, ( + "You need to model unique arguments. The declared unique fields are: {}." + ).format(", ".join(unique_fields)) + + try: + return queryset.get(**unique_filter) + except django_object_type._meta.model.DoesNotExist: + return None + + return queryset + + def wrap_resolve(self, parent_resolver): + resolver = super(DjangoInstanceField, self).wrap_resolve(parent_resolver) + return partial( + self.instance_resolver, + self._underlying_type, + self.unique_fields, + resolver, + self.get_manager(), + self.is_foreign_key, + ) diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 180acc5..deda018 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -13,6 +13,9 @@ class Person(models.Model): class Pet(models.Model): name = models.CharField(max_length=30) age = models.PositiveIntegerField() + owner = models.ForeignKey( + "Person", on_delete=models.CASCADE, null=True, blank=True, related_name="pets" + ) class FilmDetails(models.Model): @@ -91,8 +94,8 @@ class CNNReporter(Reporter): class Article(models.Model): headline = models.CharField(max_length=100) - pub_date = models.DateField() - pub_date_time = models.DateTimeField() + pub_date = models.DateField(auto_now_add=True) + pub_date_time = models.DateTimeField(auto_now_add=True) reporter = models.ForeignKey( Reporter, on_delete=models.CASCADE, related_name="articles" ) diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index f68470e..eca9d99 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -5,7 +5,7 @@ import pytest from graphene import List, NonNull, ObjectType, Schema, String -from ..fields import DjangoListField +from ..fields import DjangoListField, DjangoInstanceField from ..types import DjangoObjectType from .models import Article as ArticleModel from .models import Reporter as ReporterModel @@ -302,6 +302,149 @@ class TestDjangoListField: assert not result.errors assert result.data == {"reporters": [{"firstName": "Tara"}]} + def test_get_queryset_filter_instance(self): + """Resolving prefilter list to get instance""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporter = DjangoInstanceField( + Reporter, + unique_fields=("first_name",), + first_name=String(required=True), + ) + + schema = Schema(query=Query) + + query = """ + query { + reporter(firstName: "Tara") { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporter": {"firstName": "Tara"}} + + def test_get_queryset_filter_instance_null(self): + """Resolving prefilter list with no results""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporter = DjangoInstanceField( + Reporter, + unique_fields=("first_name",), + first_name=String(required=True), + ) + + schema = Schema(query=Query) + + query = """ + query { + reporter(firstName: "Debra") { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporter": None} + + def test_get_queryset_filter_instance_plain(self): + """Resolving a plain object should work (and not call get_queryset)""" + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + @classmethod + def get_queryset(cls, queryset, info): + # Only get reporters with at least 1 article + return queryset.annotate(article_count=Count("articles")).filter( + article_count__gt=0 + ) + + class Query(ObjectType): + reporter = DjangoInstanceField(Reporter, first_name=String(required=True)) + + def resolve_reporter(_, info, first_name): + return ReporterModel.objects.get(first_name=first_name) + + schema = Schema(query=Query) + + query = """ + query { + reporter(firstName: "Debra") { + firstName + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporter": {"firstName": "Debra"}} + def test_resolve_list(self): """Resolving a plain list should work (and not call get_queryset)""" diff --git a/graphene_django/tests/test_get_queryset.py b/graphene_django/tests/test_get_queryset.py new file mode 100644 index 0000000..03457e3 --- /dev/null +++ b/graphene_django/tests/test_get_queryset.py @@ -0,0 +1,345 @@ +import pytest + +import graphene +from graphene.relay import Node + +from graphql_relay import to_global_id + +from ..fields import DjangoConnectionField, DjangoInstanceField +from ..types import DjangoObjectType + +from .models import Article, Reporter + + +class TestShouldCallGetQuerySetOnForeignKey: + """ + Check that the get_queryset method is called in both forward and reversed direction + of a foreignkey on types. + (see issue #1111) + """ + + @pytest.fixture(autouse=True) + def setup_schema(self): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + + @classmethod + def get_queryset(cls, queryset, info): + if info.context and info.context.get("admin"): + return queryset + raise Exception("Not authorized to access reporters.") + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + + @classmethod + def get_queryset(cls, queryset, info): + return queryset.exclude(headline__startswith="Draft") + + class Query(graphene.ObjectType): + reporter = DjangoInstanceField(ReporterType, id=graphene.ID(required=True)) + article = DjangoInstanceField(ArticleType, id=graphene.ID(required=True)) + + self.schema = graphene.Schema(query=Query) + + self.reporter = Reporter.objects.create( + first_name="Jane", last_name="Doe", pk=2 + ) + + self.articles = [ + Article.objects.create( + headline="A fantastic article", + reporter=self.reporter, + editor=self.reporter, + ), + Article.objects.create( + headline="Draft: My next best seller", + reporter=self.reporter, + editor=self.reporter, + ), + ] + + def test_get_queryset_called_on_field(self): + # If a user tries to access an article it is fine as long as it's not a draft one + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + } + } + """ + # Non-draft + result = self.schema.execute(query, variables={"id": self.articles[0].id}) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + } + # Draft + result = self.schema.execute(query, variables={"id": self.articles[1].id}) + assert not result.errors + assert result.data["article"] is None + + # If a non admin user tries to access a reporter they should get our authorization error + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + } + } + """ + + result = self.schema.execute(query, variables={"id": self.reporter.id}) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + } + } + """ + + result = self.schema.execute( + query, variables={"id": self.reporter.id}, context_value={"admin": True}, + ) + assert not result.errors + assert result.data == {"reporter": {"firstName": "Jane"}} + + def test_get_queryset_called_on_foreignkey(self): + # If a user tries to access a reporter through an article they should get our authorization error + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute(query, variables={"id": self.articles[0].id}) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters through an article + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, variables={"id": self.articles[0].id}, context_value={"admin": True}, + ) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + "reporter": {"firstName": "Jane"}, + } + + # An admin user should not be able to access draft article through a reporter + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + articles { + headline + } + } + } + """ + + result = self.schema.execute( + query, variables={"id": self.reporter.id}, context_value={"admin": True}, + ) + assert not result.errors + assert result.data["reporter"] == { + "firstName": "Jane", + "articles": [{"headline": "A fantastic article"}], + } + + +class TestShouldCallGetQuerySetOnForeignKeyNode: + """ + Check that the get_queryset method is called in both forward and reversed direction + of a foreignkey on types using a node interface. + (see issue #1111) + """ + + @pytest.fixture(autouse=True) + def setup_schema(self): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_queryset(cls, queryset, info): + if info.context and info.context.get("admin"): + return queryset + raise Exception("Not authorized to access reporters.") + + class ArticleType(DjangoObjectType): + class Meta: + model = Article + interfaces = (Node,) + + @classmethod + def get_queryset(cls, queryset, info): + return queryset.exclude(headline__startswith="Draft") + + class Query(graphene.ObjectType): + reporter = Node.Field(ReporterType) + article = Node.Field(ArticleType) + + self.schema = graphene.Schema(query=Query) + + self.reporter = Reporter.objects.create(first_name="Jane", last_name="Doe") + + self.articles = [ + Article.objects.create( + headline="A fantastic article", + reporter=self.reporter, + editor=self.reporter, + ), + Article.objects.create( + headline="Draft: My next best seller", + reporter=self.reporter, + editor=self.reporter, + ), + ] + + def test_get_queryset_called_on_node(self): + # If a user tries to access an article it is fine as long as it's not a draft one + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + } + } + """ + # Non-draft + result = self.schema.execute( + query, variables={"id": to_global_id("ArticleType", self.articles[0].id)} + ) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + } + # Draft + result = self.schema.execute( + query, variables={"id": to_global_id("ArticleType", self.articles[1].id)} + ) + assert not result.errors + assert result.data["article"] is None + + # If a non admin user tries to access a reporter they should get our authorization error + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + } + } + """ + + result = self.schema.execute( + query, variables={"id": to_global_id("ReporterType", self.reporter.id)} + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + } + } + """ + + result = self.schema.execute( + query, + variables={"id": to_global_id("ReporterType", self.reporter.id)}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data == {"reporter": {"firstName": "Jane"}} + + def test_get_queryset_called_on_foreignkey(self): + # If a user tries to access a reporter through an article they should get our authorization error + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, variables={"id": to_global_id("ArticleType", self.articles[0].id)} + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "Not authorized to access reporters." + + # An admin user should be able to get reporters through an article + query = """ + query getArticle($id: ID!) { + article(id: $id) { + headline + reporter { + firstName + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": to_global_id("ArticleType", self.articles[0].id)}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["article"] == { + "headline": "A fantastic article", + "reporter": {"firstName": "Jane"}, + } + + # An admin user should not be able to access draft article through a reporter + query = """ + query getReporter($id: ID!) { + reporter(id: $id) { + firstName + articles { + edges { + node { + headline + } + } + } + } + } + """ + + result = self.schema.execute( + query, + variables={"id": to_global_id("ReporterType", self.reporter.id)}, + context_value={"admin": True}, + ) + assert not result.errors + assert result.data["reporter"] == { + "firstName": "Jane", + "articles": {"edges": [{"node": {"headline": "A fantastic article"}}]}, + } diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 699814d..accb0a5 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -15,7 +15,7 @@ from ..compat import IntegerRangeField, MissingType from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..utils import DJANGO_FILTER_INSTALLED -from .models import Article, CNNReporter, Film, FilmDetails, Reporter +from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter def test_should_query_only_fields(): @@ -251,8 +251,8 @@ def test_should_node(): def test_should_query_onetoone_fields(): - film = Film(id=1) - film_details = FilmDetails(id=1, film=film) + film = Film.objects.create(id=1) + film_details = FilmDetails.objects.create(id=1, film=film) class FilmNode(DjangoObjectType): class Meta: @@ -1251,6 +1251,7 @@ class TestBackwardPagination: class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1455,6 +1456,7 @@ def test_connection_should_enable_offset_filtering(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1494,6 +1496,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit( class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1527,6 +1530,7 @@ def test_connection_should_forbid_offset_filtering_with_before(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1561,6 +1565,7 @@ def test_connection_should_allow_offset_filtering_with_after(): class Meta: model = Reporter interfaces = (Node,) + fields = "__all__" class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -1586,3 +1591,69 @@ def test_connection_should_allow_offset_filtering_with_after(): "allReporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}},]} } assert result.data == expected + + +def test_should_query_nullable_foreign_key(): + class PetType(DjangoObjectType): + class Meta: + model = Pet + fields = "__all__" + + class PersonType(DjangoObjectType): + class Meta: + model = Person + fields = "__all__" + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType, name=graphene.String(required=True)) + person = graphene.Field(PersonType, name=graphene.String(required=True)) + + def resolve_pet(self, info, name): + return Pet.objects.filter(name=name).first() + + def resolve_person(self, info, name): + return Person.objects.filter(name=name).first() + + schema = graphene.Schema(query=Query) + + person = Person.objects.create(name="Jane") + pets = [ + Pet.objects.create(name="Stray dog", age=1), + Pet.objects.create(name="Jane's dog", owner=person, age=1), + ] + + query_pet = """ + query getPet($name: String!) { + pet(name: $name) { + owner { + name + } + } + } + """ + result = schema.execute(query_pet, variables={"name": "Stray dog"}) + assert not result.errors + assert result.data["pet"] == { + "owner": None, + } + + result = schema.execute(query_pet, variables={"name": "Jane's dog"}) + assert not result.errors + assert result.data["pet"] == { + "owner": {"name": "Jane"}, + } + + query_owner = """ + query getOwner($name: String!) { + person(name: $name) { + pets { + name + } + } + } + """ + result = schema.execute(query_owner, variables={"name": "Jane"}) + assert not result.errors + assert result.data["person"] == { + "pets": [{"name": "Jane's dog"}], + }