From 7a765537e9d0a996f81faf218663aa4b8edb8018 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Mon, 22 May 2017 16:43:20 -0500 Subject: [PATCH 1/5] Only evaluate reverse M2Ms in get_reverse_fields --- graphene_django/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 7b4cd21..97f6fcb 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -33,7 +33,7 @@ def get_reverse_fields(model): yield (name, new_related) elif isinstance(related, models.ManyToOneRel): yield (name, related) - elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: + elif isinstance(related, models.ManyToManyRel) and attr.reverse and not related.symmetrical: yield (name, related) From cfe38ae20899816455c131c0669ce850a2e11ea5 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Mon, 22 May 2017 17:20:56 -0500 Subject: [PATCH 2/5] Add tests --- graphene_django/tests/test_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 graphene_django/tests/test_utils.py diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py new file mode 100644 index 0000000..b83a7d8 --- /dev/null +++ b/graphene_django/tests/test_utils.py @@ -0,0 +1,26 @@ + + +from ..utils import get_model_fields, get_reverse_fields +from .models import Film, Reporter + + +def test_get_reverse_fields_correct(): + reporter_reverse_fields = get_reverse_fields(Reporter) + reporter_field_names = [field[0] for field in reporter_reverse_fields] + assert reporter_field_names == [ + 'articles', 'films' + ] + + film_reverse_fields = get_reverse_fields(Film) + film_field_names = [field[0] for field in film_reverse_fields] + assert film_field_names == ['details'] + + +def test_get_model_fields_no_duplication(): + reporter_fields = get_model_fields(Reporter) + reporter_name_set = set([field[0] for field in reporter_fields]) + assert len(reporter_fields) == len(reporter_name_set) + + film_fields = get_model_fields(Film) + film_name_set = set([field[0] for field in film_fields]) + assert len(film_fields) == len(film_name_set) From ca06d741956b5e2b8ae088ee3e93092216bdd283 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Mon, 22 May 2017 17:36:00 -0500 Subject: [PATCH 3/5] Make test sort for stable comparison --- graphene_django/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index b83a7d8..944324b 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -7,7 +7,7 @@ from .models import Film, Reporter def test_get_reverse_fields_correct(): reporter_reverse_fields = get_reverse_fields(Reporter) reporter_field_names = [field[0] for field in reporter_reverse_fields] - assert reporter_field_names == [ + assert sorted(reporter_field_names) == [ 'articles', 'films' ] From 74e4e1aa77365a39595b61a7cdf2327cecc8dd05 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Thu, 25 May 2017 11:01:29 -0500 Subject: [PATCH 4/5] Drop reverse flag, remove duplicates in get_model_fields --- graphene_django/tests/test_utils.py | 16 +--------------- graphene_django/utils.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index 944324b..becd031 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -1,21 +1,7 @@ - - -from ..utils import get_model_fields, get_reverse_fields +from ..utils import get_model_fields from .models import Film, Reporter -def test_get_reverse_fields_correct(): - reporter_reverse_fields = get_reverse_fields(Reporter) - reporter_field_names = [field[0] for field in reporter_reverse_fields] - assert sorted(reporter_field_names) == [ - 'articles', 'films' - ] - - film_reverse_fields = get_reverse_fields(Film) - film_field_names = [field[0] for field in film_reverse_fields] - assert film_field_names == ['details'] - - def test_get_model_fields_no_duplication(): reporter_fields = get_model_fields(Reporter) reporter_name_set = set([field[0] for field in reporter_fields]) diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 97f6fcb..677a206 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -33,7 +33,7 @@ def get_reverse_fields(model): yield (name, new_related) elif isinstance(related, models.ManyToOneRel): yield (name, related) - elif isinstance(related, models.ManyToManyRel) and attr.reverse and not related.symmetrical: + elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: yield (name, related) @@ -52,7 +52,14 @@ def get_model_fields(model): list(model._meta.local_many_to_many)) ] - all_fields += list(reverse_fields) + # Make sure we don't duplicate local fields with "reverse" version + all_field_names = [field[0] for field in all_fields] + actual_reverse_fields = [ + reverse_field for reverse_field in reverse_fields + if reverse_field[0] not in all_field_names + ] + + all_fields += list(actual_reverse_fields) return all_fields From 95510987f177d24ae6f79956ef1a6fa64319a783 Mon Sep 17 00:00:00 2001 From: Jacob Foster Date: Thu, 25 May 2017 11:15:13 -0500 Subject: [PATCH 5/5] Reorganize for clarity --- graphene_django/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 677a206..3ea4d0d 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -21,8 +21,12 @@ except (ImportError, AttributeError): DJANGO_FILTER_INSTALLED = False -def get_reverse_fields(model): +def get_reverse_fields(model, local_field_names): for name, attr in model.__dict__.items(): + # Don't duplicate any local fields + if name in local_field_names: + continue + # Django =>1.9 uses 'rel', django <1.9 uses 'related' related = getattr(attr, 'rel', None) or \ getattr(attr, 'related', None) @@ -44,8 +48,7 @@ def maybe_queryset(value): def get_model_fields(model): - reverse_fields = get_reverse_fields(model) - all_fields = [ + local_fields = [ (field.name, field) for field in sorted(list(model._meta.fields) + @@ -53,13 +56,10 @@ def get_model_fields(model): ] # Make sure we don't duplicate local fields with "reverse" version - all_field_names = [field[0] for field in all_fields] - actual_reverse_fields = [ - reverse_field for reverse_field in reverse_fields - if reverse_field[0] not in all_field_names - ] + local_field_names = [field[0] for field in local_fields] + reverse_fields = get_reverse_fields(model, local_field_names) - all_fields += list(actual_reverse_fields) + all_fields = local_fields + list(reverse_fields) return all_fields