From 912888935358c7823266936a0866654451098d7a Mon Sep 17 00:00:00 2001 From: Olivia Rodriguez Valdes Date: Thu, 11 Apr 2019 09:03:17 -0400 Subject: [PATCH] Add load-many support --- graphene_django/fields.py | 17 +++++++++----- graphene_django/tests/test_fields.py | 33 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 2de0659..c9b4942 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -171,17 +171,19 @@ class DjangoField(Field): class DataLoaderField(DjangoField): - """Class to manage access to dataloader when resolve the field""" + """Class to manage access to data-loader when resolve the field""" - def __init__(self, data_loader, source_loader, type, *args, **kwargs): + def __init__(self, type, data_loader, source_loader, load_many=False, *args, **kwargs): """ - Initialization of dataloader to resolve field - :param data_loader: dataloader to resolve field - :param source_loader: field to obtain the key for dataloading + Initialization of data-loader to resolve field + :param data_loader: data-loader to resolve field + :param source_loader: field to obtain the key for data-loading + :param load_many: Whether the resolver should try tu obtain one element or multiple elements :param kwargs: Extra arguments """ self.data_loader = data_loader self.source_loader = source_loader + self.load_many = load_many super(DataLoaderField, self).__init__(type, *args, **kwargs) @@ -191,7 +193,10 @@ class DataLoaderField(DjangoField): def resolver_data_loader(self, root, info, *args, **kwargs): """Resolve field through dataloader""" if root: - source_loader = getattr(root, self.source_loader) + source_loader = reduce(lambda x, y: getattr(x, y), self.source_loader.split('.'), root) else: source_loader = kwargs.get(self.source_loader) + + if self.load_many: + return self.data_loader.load_many(source_loader) return self.data_loader.load(source_loader) diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py index 8abd9b3..0978110 100644 --- a/graphene_django/tests/test_fields.py +++ b/graphene_django/tests/test_fields.py @@ -9,6 +9,11 @@ from promise import Promise class MyInstance(object): value = "value" key = 1 + keys = [1, 2, 3] + + class InnerClass(object): + key = 2 + keys = [4, 5, 6] def resolver(self): return "resolver method" @@ -62,6 +67,34 @@ class DataLoaderFieldTests(TestCase): self.assertEqual(resolver(instance, None).get(), instance.key) + def test_dataloaderfield_many(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='keys', type=MyType, load_many=True) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.keys) + + def test_dataloaderfield_inner_prop(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='InnerClass.key', type=MyType) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.InnerClass.key) + + def test_dataloaderfield_many_inner_prop(self): + MyType = object() + data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='InnerClass.keys', type=MyType, + load_many=True) + + resolver = data_loader_field.get_resolver(None) + instance = MyInstance() + + self.assertEqual(resolver(instance, None).get(), instance.InnerClass.keys) + def test_dataloaderfield_permissions(self): MyType = object() data_loader_field = DataLoaderField(data_loader=data_loader, source_loader='key', type=MyType,