diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 0c62f28..406d184 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -22,6 +22,9 @@ class Film(models.Model): reporters = models.ManyToManyField('Reporter', related_name='films') +class DoeReporterManager(models.Manager): + def get_queryset(self): + return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") class Reporter(models.Model): first_name = models.CharField(max_length=30) @@ -29,6 +32,8 @@ class Reporter(models.Model): email = models.EmailField() pets = models.ManyToManyField('self') a_choice = models.CharField(max_length=30, choices=CHOICES) + objects = models.Manager() + doe_objects = DoeReporterManager() def __str__(self): # __unicode__ on Python 2 return "%s %s" % (self.first_name, self.last_name) diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index a785a49..c4c26f5 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -743,6 +743,61 @@ def test_should_query_connectionfields_with_last(): assert not result.errors assert result.data == expected +def test_should_query_connectionfields_with_manager(): + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + + r = Reporter.objects.create( + first_name='John', + last_name='NotDoe', + email='johndoe@example.com', + a_choice=1 + ) + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType, on='doe_objects') + + def resolve_all_reporters(self, info, **args): + return Reporter.objects.all() + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterLastQuery { + allReporters(first: 2) { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=' + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + def test_should_query_dataloader_fields(): from promise import Promise