Check for filters defined on base filterset classes

This commit is contained in:
Kike Isidoro 2019-08-02 16:31:28 +02:00
parent 59f4f134b5
commit 0425985dab
2 changed files with 96 additions and 9 deletions

View File

@ -818,3 +818,86 @@ def test_integer_field_filter_type():
}
"""
)
def test_filter_filterset_based_on_mixin():
class ArticleFilterMixin:
@classmethod
def get_filters(cls):
filters = super().get_filters()
filters.update({
'viewer__email__in': django_filters.CharFilter(
method='filter_email_in',
field_name='reporter__email__in',
),
})
return filters
class NewArticleFilter(ArticleFilterMixin, ArticleFilter):
pass
class NewReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
class NewArticleFilterNode(DjangoObjectType):
viewer = Field(NewReporterNode)
class Meta:
model = Article
interfaces = (Node,)
filterset_class = NewArticleFilter
def resolve_viewer(self, info):
return self.reporter
class Query(ObjectType):
all_articles = DjangoFilterConnectionField(NewArticleFilterNode)
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com")
article = Article.objects.create(
headline="Hello",
reporter=reporter,
editor=reporter,
pub_date=datetime.now(),
pub_date_time=datetime.now())
schema = Schema(query=Query)
query = """
query NodeFilteringQuery {
allArticles {
edges {
node {
viewer {
email
}
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{
"node": {
"viewer": {
"email": reporter.email,
}
}
}
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -13,21 +13,25 @@ def get_filtering_args_from_filterset(filterset_class, type):
args = {}
model = filterset_class._meta.model
for name, filter_field in six.iteritems(filterset_class.base_filters):
form_field = None
if name in filterset_class.declared_filters:
form_field = filter_field.field
else:
field_name = name.split("__", 1)[0]
model_field = model._meta.get_field(field_name)
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required=filter_field.extra.get("required", False)
)
if hasattr(model, field_name):
model_field = model._meta.get_field(field_name)
# Fallback to field defined on filter if we can't get it from the
# model field
if not form_field:
form_field = filter_field.field
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required=filter_field.extra.get("required", False)
)
# Fallback to field defined on filter if we can't get it from the
# model field
if not form_field:
form_field = filter_field.field
field_type = convert_form_field(form_field).Argument()
field_type.description = filter_field.label