From 40b88bc87bc9fd276ba74613c49706cef500d9dd Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Tue, 13 Oct 2015 19:16:42 -0700 Subject: [PATCH] Added support for resolver tagging. Fixed #6 --- graphene/core/fields.py | 37 +++++++++++++++++++++++++++++-------- graphene/core/types.py | 7 ------- tests/core/test_types.py | 12 ++++++++++++ 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/graphene/core/fields.py b/graphene/core/fields.py index 9c883567..2ce6cb3b 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -1,6 +1,6 @@ import inspect import six -from functools import total_ordering +from functools import total_ordering, wraps from graphql.core.type import ( GraphQLField, GraphQLList, @@ -49,12 +49,26 @@ class Field(object): cls._meta.add_field(self) def resolve(self, instance, args, info): - if self.resolve_fn: - resolve_fn = self.resolve_fn + resolve_fn = self.get_resolve_fn() + if resolve_fn: + return resolve_fn(instance, args, info) else: - resolve_fn = lambda root, args, info: root.resolve( - self.field_name, args, info) - return resolve_fn(instance, args, info) + return instance.get_field(self.field_name) + + @memoize + def get_resolve_fn(self): + if self.resolve_fn: + return self.resolve_fn + else: + custom_resolve_fn_name = 'resolve_%s' % self.field_name + if hasattr(self.object_type, custom_resolve_fn_name): + resolve_fn = getattr(self.object_type, custom_resolve_fn_name) + + @wraps(resolve_fn) + def custom_resolve_fn(instance, args, info): + custom_fn = getattr(instance, custom_resolve_fn_name) + return custom_fn(args, info) + return custom_resolve_fn def get_object_type(self, schema): field_type = self.field_type @@ -110,11 +124,18 @@ class Field(object): if not internal_type: raise Exception("Internal type for field %s is None" % self) + resolve_fn = self.get_resolve_fn() + if resolve_fn: + @wraps(resolve_fn) + def resolver(*args): + return self.resolve(*args) + else: + resolver = self.resolve return GraphQLField( internal_type, description=self.description, args=self.args, - resolver=self.resolve, + resolver=resolver, ) def __str__(self): @@ -144,7 +165,7 @@ class Field(object): return NotImplemented def __hash__(self): - return hash(self.creation_counter) + return hash((self.creation_counter, self.object_type)) def __copy__(self): # We need to avoid hitting __reduce__, so define this diff --git a/graphene/core/types.py b/graphene/core/types.py index 636d9ede..ff238705 100644 --- a/graphene/core/types.py +++ b/graphene/core/types.py @@ -132,13 +132,6 @@ class BaseObjectType(object): def get_field(self, field): return getattr(self.instance, field, None) - def resolve(self, field_name, args, info): - custom_resolve_fn = 'resolve_%s' % field_name - if hasattr(self, custom_resolve_fn): - resolve_fn = getattr(self, custom_resolve_fn) - return resolve_fn(args, info) - return self.get_field(field_name) - @classmethod def resolve_type(cls, schema, instance, *_): return instance.internal_type(schema) diff --git a/tests/core/test_types.py b/tests/core/test_types.py index 7ca75d32..06628847 100644 --- a/tests/core/test_types.py +++ b/tests/core/test_types.py @@ -69,3 +69,15 @@ def test_field_clashes(): class Droid(Character): name = IntField() assert 'clashes' in str(excinfo.value) + + +def test_field_mantain_resolver_tags(): + class Droid(Character): + name = StringField() + + def resolve_name(self, *args): + return 'My Droid' + resolve_name.custom_tag = True + + field = Droid._meta.fields_map['name'].internal_field(schema) + assert field.resolver.custom_tag