Merge pull request #181 from spockNinja/fix_many_to_many

Fix ManyToMany schema mapping
This commit is contained in:
Syrus Akbary 2017-06-24 15:08:49 -07:00 committed by GitHub
commit 06f8323fcf
2 changed files with 23 additions and 4 deletions

View File

@ -0,0 +1,12 @@
from ..utils import get_model_fields
from .models import Film, Reporter
def test_get_model_fields_no_duplication():
reporter_fields = get_model_fields(Reporter)
reporter_name_set = set([field[0] for field in reporter_fields])
assert len(reporter_fields) == len(reporter_name_set)
film_fields = get_model_fields(Film)
film_name_set = set([field[0] for field in film_fields])
assert len(film_fields) == len(film_name_set)

View File

@ -21,8 +21,12 @@ except (ImportError, AttributeError):
DJANGO_FILTER_INSTALLED = False DJANGO_FILTER_INSTALLED = False
def get_reverse_fields(model): def get_reverse_fields(model, local_field_names):
for name, attr in model.__dict__.items(): for name, attr in model.__dict__.items():
# Don't duplicate any local fields
if name in local_field_names:
continue
# Django =>1.9 uses 'rel', django <1.9 uses 'related' # Django =>1.9 uses 'rel', django <1.9 uses 'related'
related = getattr(attr, 'rel', None) or \ related = getattr(attr, 'rel', None) or \
getattr(attr, 'related', None) getattr(attr, 'related', None)
@ -44,15 +48,18 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
reverse_fields = get_reverse_fields(model) local_fields = [
all_fields = [
(field.name, field) (field.name, field)
for field for field
in sorted(list(model._meta.fields) + in sorted(list(model._meta.fields) +
list(model._meta.local_many_to_many)) list(model._meta.local_many_to_many))
] ]
all_fields += list(reverse_fields) # Make sure we don't duplicate local fields with "reverse" version
local_field_names = [field[0] for field in local_fields]
reverse_fields = get_reverse_fields(model, local_field_names)
all_fields = local_fields + list(reverse_fields)
return all_fields return all_fields