From 74e4e1aa77365a39595b61a7cdf2327cecc8dd05 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Thu, 25 May 2017 11:01:29 -0500 Subject: [PATCH] Drop reverse flag, remove duplicates in get_model_fields --- graphene_django/tests/test_utils.py | 16 +--------------- graphene_django/utils.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index 944324b..becd031 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -1,21 +1,7 @@ - - -from ..utils import get_model_fields, get_reverse_fields +from ..utils import get_model_fields from .models import Film, Reporter -def test_get_reverse_fields_correct(): - reporter_reverse_fields = get_reverse_fields(Reporter) - reporter_field_names = [field[0] for field in reporter_reverse_fields] - assert sorted(reporter_field_names) == [ - 'articles', 'films' - ] - - film_reverse_fields = get_reverse_fields(Film) - film_field_names = [field[0] for field in film_reverse_fields] - assert film_field_names == ['details'] - - def test_get_model_fields_no_duplication(): reporter_fields = get_model_fields(Reporter) reporter_name_set = set([field[0] for field in reporter_fields]) diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 97f6fcb..677a206 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -33,7 +33,7 @@ def get_reverse_fields(model): yield (name, new_related) elif isinstance(related, models.ManyToOneRel): yield (name, related) - elif isinstance(related, models.ManyToManyRel) and attr.reverse and not related.symmetrical: + elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: yield (name, related) @@ -52,7 +52,14 @@ def get_model_fields(model): list(model._meta.local_many_to_many)) ] - all_fields += list(reverse_fields) + # Make sure we don't duplicate local fields with "reverse" version + all_field_names = [field[0] for field in all_fields] + actual_reverse_fields = [ + reverse_field for reverse_field in reverse_fields + if reverse_field[0] not in all_field_names + ] + + all_fields += list(actual_reverse_fields) return all_fields