diff --git a/rest_framework/compat.py b/rest_framework/compat.py index e435618a2..759cd1fbc 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -6,6 +6,8 @@ versions of Django/Python, and compatibility wrappers around optional packages. # flake8: noqa from __future__ import unicode_literals +import inspect + import django from django.conf import settings from django.db import connection, transaction @@ -136,11 +138,43 @@ if six.PY3: SHORT_SEPARATORS = (',', ':') LONG_SEPARATORS = (', ', ': ') INDENT_SEPARATORS = (',', ': ') + + def is_simple_callable(obj): + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + # when we drop support of python3.2, we should replace getfullargspec with singnature + # signature = inspect.signature(obj) + # defaults = [p for p in signature.parameters.values() if p.default is not inspect.Parameter.empty] + # return len(signature.parameters) <= len(defaults) + function = inspect.isfunction(obj) + args, _, _, defaults, _, kwonly, kwdefaults = inspect.getfullargspec(obj) + len_args = (len(args) if function else len(args) - 1) + len(kwonly or ()) + len(kwdefaults or ()) + len_defaults = (len(defaults) if defaults else 0) + len(kwdefaults or ()) + return len_args <= len_defaults + else: SHORT_SEPARATORS = (b',', b':') LONG_SEPARATORS = (b', ', b': ') INDENT_SEPARATORS = (b',', b': ') + def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults + try: # DecimalValidator is unavailable in Django < 1.9 from django.core.validators import DecimalValidator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 643aa762f..0e79cf918 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -30,7 +30,7 @@ from django.utils.ipv6 import clean_ipv6_address from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 -from rest_framework.compat import unicode_repr, unicode_to_repr +from rest_framework.compat import unicode_repr, unicode_to_repr, is_simple_callable from rest_framework.exceptions import ValidationError from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -46,24 +46,6 @@ class empty: pass -def is_simple_callable(obj): - """ - True if the object is a callable that takes no arguments. - """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) - - if not (function or method): - return False - if six.PY2: - args, _, _, defaults = inspect.getargspec(obj) - else: - args, _, _, defaults, _, _, _ = inspect.getfullargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults - - def get_attribute(instance, attrs): """ Similar to Python's built in `getattr(instance, attr)`, diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 572b69170..554a8c943 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -14,8 +14,9 @@ from django.utils.encoding import smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import is_simple_callable from rest_framework.fields import ( - Field, empty, get_attribute, is_simple_callable, iter_options + Field, empty, get_attribute, iter_options ) from rest_framework.reverse import reverse from rest_framework.utils import html diff --git a/tests/compat/__init__.py b/tests/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compat/test_compat_py35.py b/tests/compat/test_compat_py35.py new file mode 100644 index 000000000..19b849b0c --- /dev/null +++ b/tests/compat/test_compat_py35.py @@ -0,0 +1,17 @@ +class FunctionSimplicityCheckPy35Mixin: + def get_good_cases(self): + def annotated_simple() -> int: + return 0 + + def annotated_defaults(x: int = 0) -> int: + return 0 + + def kwonly_defaults(*, x=0): + pass + return super().get_good_cases() + (annotated_simple, annotated_defaults, kwonly_defaults) + + def get_bad_cases(self): + def kwonly(*, x): + pass + + return super().get_bad_cases() + (kwonly,) diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 000000000..b0c40ce34 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,53 @@ +from __future__ import unicode_literals + +import pytest +import sys + +from rest_framework.compat import is_simple_callable + + +class TestFunctionSimplicityCheck: + def get_good_cases(self): + def simple(): + pass + + def simple_with_default(x=0): + pass + + class SimpleMethods(object): + def simple(self): + pass + + def simple_with_default(self, x=0): + pass + + return simple, simple_with_default, SimpleMethods().simple, SimpleMethods().simple_with_default + + def get_bad_cases(self): + def positional(x): + pass + + def many_positional_and_defaults(x, y, z=0): + pass + + nofunc = 0 + + class Callable: + pass + + return positional, many_positional_and_defaults, nofunc, Callable + + def test_good_cases(self): + for case in self.get_good_cases(): + assert is_simple_callable(case) + + def test_bad_cases(self): + for case in self.get_bad_cases(): + assert not is_simple_callable(case) + + +if sys.version_info >= (3, 5): + from tests.compat.test_compat_py35 import FunctionSimplicityCheckPy35Mixin + + class TestFunctionSimplicityCheckPy35(FunctionSimplicityCheckPy35Mixin, TestFunctionSimplicityCheck): + pass