diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py new file mode 100644 index 0000000..becd031 --- /dev/null +++ b/graphene_django/tests/test_utils.py @@ -0,0 +1,12 @@ +from ..utils import get_model_fields +from .models import Film, Reporter + + +def test_get_model_fields_no_duplication(): + reporter_fields = get_model_fields(Reporter) + reporter_name_set = set([field[0] for field in reporter_fields]) + assert len(reporter_fields) == len(reporter_name_set) + + film_fields = get_model_fields(Film) + film_name_set = set([field[0] for field in film_fields]) + assert len(film_fields) == len(film_name_set) diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 7b4cd21..3ea4d0d 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -21,8 +21,12 @@ except (ImportError, AttributeError): DJANGO_FILTER_INSTALLED = False -def get_reverse_fields(model): +def get_reverse_fields(model, local_field_names): for name, attr in model.__dict__.items(): + # Don't duplicate any local fields + if name in local_field_names: + continue + # Django =>1.9 uses 'rel', django <1.9 uses 'related' related = getattr(attr, 'rel', None) or \ getattr(attr, 'related', None) @@ -44,15 +48,18 @@ def maybe_queryset(value): def get_model_fields(model): - reverse_fields = get_reverse_fields(model) - all_fields = [ + local_fields = [ (field.name, field) for field in sorted(list(model._meta.fields) + list(model._meta.local_many_to_many)) ] - all_fields += list(reverse_fields) + # Make sure we don't duplicate local fields with "reverse" version + local_field_names = [field[0] for field in local_fields] + reverse_fields = get_reverse_fields(model, local_field_names) + + all_fields = local_fields + list(reverse_fields) return all_fields