diff --git a/graphene/core/fields.py b/graphene/core/fields.py index bdff8cc1..5bbbb5a1 100644 --- a/graphene/core/fields.py +++ b/graphene/core/fields.py @@ -1,6 +1,6 @@ import inspect import six -from functools import total_ordering +from functools import total_ordering, wraps from graphql.core.type import ( GraphQLField, GraphQLList, @@ -49,12 +49,26 @@ class Field(object): cls._meta.add_field(self) def resolve(self, instance, args, info): - if self.resolve_fn: - resolve_fn = self.resolve_fn + resolve_fn = self.get_resolve_fn() + if resolve_fn: + return resolve_fn(instance, args, info) else: - resolve_fn = lambda root, args, info: root.resolve( - self.field_name, args, info) - return resolve_fn(instance, args, info) + return instance.get_field(self.field_name) + + @memoize + def get_resolve_fn(self): + if self.resolve_fn: + return self.resolve_fn + else: + custom_resolve_fn_name = 'resolve_%s' % self.field_name + if hasattr(self.object_type, custom_resolve_fn_name): + resolve_fn = getattr(self.object_type, custom_resolve_fn_name) + + @wraps(resolve_fn) + def custom_resolve_fn(instance, args, info): + custom_fn = getattr(instance, custom_resolve_fn_name) + return custom_fn(args, info) + return custom_resolve_fn def get_object_type(self, schema): field_type = self.field_type @@ -110,11 +124,18 @@ class Field(object): if not internal_type: raise Exception("Internal type for field %s is None" % self) + resolve_fn = self.get_resolve_fn() + if resolve_fn: + @wraps(resolve_fn) + def resolver(*args): + return self.resolve(*args) + else: + resolver = self.resolve return GraphQLField( internal_type, description=self.description, args=self.args, - resolver=self.resolve, + resolver=resolver, ) def __str__(self): diff --git a/graphene/core/types.py b/graphene/core/types.py index 46560b4b..c49de4f4 100644 --- a/graphene/core/types.py +++ b/graphene/core/types.py @@ -132,13 +132,6 @@ class BaseObjectType(object): def get_field(self, field): return getattr(self.instance, field, None) - def resolve(self, field_name, args, info): - custom_resolve_fn = 'resolve_%s' % field_name - if hasattr(self, custom_resolve_fn): - resolve_fn = getattr(self, custom_resolve_fn) - return resolve_fn(args, info) - return self.get_field(field_name) - @classmethod def resolve_objecttype(cls, schema, instance, *_): return instance diff --git a/tests/core/test_query.py b/tests/core/test_query.py index 66269e01..473bc3d4 100644 --- a/tests/core/test_query.py +++ b/tests/core/test_query.py @@ -46,6 +46,8 @@ schema = object() Human_type = Human.internal_type(schema) +def test_type(): + assert Human._meta.fields_map['name'].resolve(Human(object()), 1, 2) == 'Peter' def test_query(): schema = GraphQLSchema(query=Human_type) diff --git a/tests/core/test_types.py b/tests/core/test_types.py index 9ce272e4..b9414131 100644 --- a/tests/core/test_types.py +++ b/tests/core/test_types.py @@ -1,24 +1,24 @@ from py.test import raises -from collections import namedtuple from pytest import raises from graphene.core.fields import ( - Field, IntField, StringField, ) +from graphql.core.execution.middlewares.utils import ( + tag_resolver, + resolver_has_tag +) from graphql.core.type import ( GraphQLObjectType, GraphQLInterfaceType ) from graphene.core.types import ( - Interface, - ObjectType + Interface ) class Character(Interface): - '''Character description''' name = StringField() @@ -27,7 +27,6 @@ class Character(Interface): class Human(Character): - '''Human description''' friends = StringField() @@ -70,8 +69,22 @@ def test_field_clashes(): with raises(Exception) as excinfo: class Droid(Character): name = IntField() + assert 'clashes' in str(excinfo.value) def test_fields_inherited_should_be_different(): assert Character._meta.fields_map['name'] != Human._meta.fields_map['name'] + + +def test_field_mantain_resolver_tags(): + class Droid(Character): + name = StringField() + + def resolve_name(self, *args): + return 'My Droid' + + tag_resolver(resolve_name, 'test') + + field = Droid._meta.fields_map['name'].internal_field(schema) + assert resolver_has_tag(field.resolver, 'test')