diff --git a/graphene_django/fields.py b/graphene_django/fields.py index e3129c6..6383258 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,13 +1,11 @@ from functools import partial from django.db.models.query import QuerySet - -from promise import Promise - from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo from graphene.utils.get_unbound_function import get_unbound_function from graphql_relay.connection.arrayconnection import connection_from_list_slice +from promise import Promise from .settings import graphene_settings from .utils import maybe_queryset, auth_resolver @@ -170,3 +168,37 @@ class DjangoField(Field): return partial(get_unbound_function(self.permissions_resolver), parent_resolver, self.permissions, None, None, True) return parent_resolver + + +class DataLoaderField(DjangoField): + """Class to manage access to data-loader when resolve the field""" + + def __init__(self, type, data_loader, source_loader, load_many=False, *args, **kwargs): + """ + Initialization of data-loader to resolve field + :param data_loader: data-loader to resolve field + :param source_loader: field to obtain the key for data-loading + :param load_many: Whether the resolver should try tu obtain one element or multiple elements + :param kwargs: Extra arguments + """ + self.data_loader = data_loader + self.source_loader = source_loader + self.load_many = load_many + + super(DataLoaderField, self).__init__(type, *args, **kwargs) + + # If no resolver is explicitly provided, use dataloader + self.resolver = self.resolver or self.resolver_data_loader + + def resolver_data_loader(self, root, info, *args, **kwargs): + """Resolve field through dataloader""" + if root: + source_loader = reduce(lambda x, y: getattr(x, y), self.source_loader.split('.'), root) + else: + source_loader = kwargs.get(self.source_loader) + + if self.load_many: + return self.data_loader.load_many(source_loader) + if source_loader: + return self.data_loader.load(source_loader) + return None diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index e0478bd..0978110 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -1,15 +1,31 @@ +from mock import mock from unittest import TestCase from django.core.exceptions import PermissionDenied -from graphene_django.fields import DjangoField +from graphene_django.fields import DjangoField, DataLoaderField +from promise.dataloader import DataLoader +from promise import Promise class MyInstance(object): value = "value" + key = 1 + keys = [1, 2, 3] + + class InnerClass(object): + key = 2 + keys = [4, 5, 6] def resolver(self): return "resolver method" +def batch_load_fn(keys): + return Promise.all(keys) + + +data_loader = DataLoader(batch_load_fn=batch_load_fn) + + class PermissionFieldTests(TestCase): def test_permission_field(self): @@ -21,12 +37,9 @@ class PermissionFieldTests(TestCase): def has_perm(self, perm): return perm == 'perm2' - class Info(object): - class Context(object): - user = Viewer() - context = Context() + info = mock.Mock(context=mock.Mock(user=Viewer())) - self.assertEqual(resolver(MyInstance(), Info()), MyInstance().resolver()) + self.assertEqual(resolver(MyInstance(), info), MyInstance().resolver()) def test_permission_field_without_permission(self): MyType = object() @@ -37,10 +50,79 @@ class PermissionFieldTests(TestCase): def has_perm(self, perm): return False - class Info(object): - class Context(object): - user = Viewer() - context = Context() + info = mock.Mock(context=mock.Mock(user=Viewer())) with self.assertRaises(PermissionDenied): - resolver(MyInstance(), Info()) + resolver(MyInstance(), info) + + +class DataLoaderFieldTests(TestCase): + + def test_dataloaderfield(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='key', type=MyType) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.key) + + def test_dataloaderfield_many(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='keys', type=MyType, load_many=True) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.keys) + + def test_dataloaderfield_inner_prop(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='InnerClass.key', type=MyType) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.InnerClass.key) + + def test_dataloaderfield_many_inner_prop(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='InnerClass.keys', type=MyType, + load_many=True) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.InnerClass.keys) + + def test_dataloaderfield_permissions(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='key', type=MyType, + permissions=['perm1', 'perm2']) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + class Viewer(object): + def has_perm(self, perm): + return perm == 'perm2' + + info = mock.Mock(context=mock.Mock(user=Viewer())) + + self.assertEqual(resolver(instance, info).get(), instance.key) + + def test_dataloaderfield_without_permissions(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='key', type=MyType, + permissions=['perm1', 'perm2']) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + class Viewer(object): + def has_perm(self, perm): + return False + + info = mock.Mock(context=mock.Mock(user=Viewer())) + with self.assertRaises(PermissionDenied): + resolver(instance, info)