diff --git a/graphene/utils/annotate.py b/graphene/utils/annotate.py index a02677ca..c58c7544 100644 --- a/graphene/utils/annotate.py +++ b/graphene/utils/annotate.py @@ -31,5 +31,6 @@ def annotate(_func=None, _trigger_warning=True, **annotations): _func.__annotations__ = annotations else: _func.__annotations__.update(annotations) - + + _func._is_annotated = True return _func diff --git a/graphene/utils/auto_resolver.py b/graphene/utils/auto_resolver.py new file mode 100644 index 00000000..6308271d --- /dev/null +++ b/graphene/utils/auto_resolver.py @@ -0,0 +1,12 @@ +from .resolver_from_annotations import resolver_from_annotations + + +def auto_resolver(func=None): + annotations = getattr(func, '__annotations__', {}) + is_annotated = getattr(func, '_is_annotated', False) + + if annotations or is_annotated: + # Is a Graphene 2.0 resolver function + return resolver_from_annotations(func) + else: + return func diff --git a/graphene/utils/annotated_resolver.py b/graphene/utils/resolver_from_annotations.py similarity index 96% rename from graphene/utils/annotated_resolver.py rename to graphene/utils/resolver_from_annotations.py index 277ac94d..d1c6fb94 100644 --- a/graphene/utils/annotated_resolver.py +++ b/graphene/utils/resolver_from_annotations.py @@ -4,7 +4,7 @@ from functools import wraps from ..types import Context, ResolveInfo -def annotated_resolver(func): +def resolver_from_annotations(func): func_signature = signature(func) _context_var = None diff --git a/graphene/utils/tests/test_auto_resolver.py b/graphene/utils/tests/test_auto_resolver.py new file mode 100644 index 00000000..93b7a991 --- /dev/null +++ b/graphene/utils/tests/test_auto_resolver.py @@ -0,0 +1,36 @@ +import pytest +from ..annotate import annotate +from ..auto_resolver import auto_resolver + +from ...types import Context, ResolveInfo + + +def resolver(root, args, context, info): + return root, args, context, info + + +@annotate +def resolver_annotated(root, **args): + return root, args, None, None + + +@annotate(context=Context, info=ResolveInfo) +def resolver_with_context_and_info(root, context, info, **args): + return root, args, context, info + + +def test_auto_resolver_non_annotated(): + decorated_resolver = auto_resolver(resolver) + # We make sure the function is not wrapped + assert decorated_resolver == resolver + assert decorated_resolver(1, {}, 2, 3) == (1, {}, 2, 3) + + +def test_auto_resolver_annotated(): + decorated_resolver = auto_resolver(resolver_annotated) + assert decorated_resolver(1, {}, 2, 3) == (1, {}, None, None) + + +def test_auto_resolver_annotated_with_context_and_info(): + decorated_resolver = auto_resolver(resolver_with_context_and_info) + assert decorated_resolver(1, {}, 2, 3) == (1, {}, 2, 3) diff --git a/graphene/utils/tests/test_annotated_resolver.py b/graphene/utils/tests/test_resolver_from_annotations.py similarity index 83% rename from graphene/utils/tests/test_annotated_resolver.py rename to graphene/utils/tests/test_resolver_from_annotations.py index 7bb5f3ba..226b387c 100644 --- a/graphene/utils/tests/test_annotated_resolver.py +++ b/graphene/utils/tests/test_resolver_from_annotations.py @@ -1,21 +1,25 @@ import pytest from ..annotate import annotate -from ..annotated_resolver import annotated_resolver +from ..resolver_from_annotations import resolver_from_annotations from ...types import Context, ResolveInfo + @annotate def func(root, **args): return root, args, None, None + @annotate(context=Context) def func_with_context(root, context, **args): return root, args, context, None + @annotate(info=ResolveInfo) def func_with_info(root, info, **args): return root, args, None, info + @annotate(context=Context, info=ResolveInfo) def func_with_context_and_info(root, context, info, **args): return root, args, context, info @@ -27,13 +31,14 @@ args = { context = 2 info = 3 + @pytest.mark.parametrize("func,expected", [ (func, (1, {'arg': 0}, None, None)), (func_with_context, (1, {'arg': 0}, 2, None)), (func_with_info, (1, {'arg': 0}, None, 3)), (func_with_context_and_info, (1, {'arg': 0}, 2, 3)), ]) -def test_annotated_resolver(func, expected): - resolver_func = annotated_resolver(func) +def test_resolver_from_annotations(func, expected): + resolver_func = resolver_from_annotations(func) resolved = resolver_func(root, args, context, info) assert resolved == expected