Merge branch 'v2' into main

This commit is contained in:
Jason Kraus 2020-12-30 22:31:41 -08:00
commit 8324d47999
6 changed files with 97 additions and 19 deletions

View File

@ -1,9 +1,11 @@
import pytest import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema from graphene import ObjectType, Schema
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet from graphene_django.tests.models import Pet, Person
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = [] pytestmark = []
@ -28,8 +30,27 @@ class PetNode(DjangoObjectType):
} }
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
class Query(ObjectType): class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode) pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
def test_string_in_filter(): def test_string_in_filter():
@ -61,6 +82,33 @@ def test_string_in_filter():
] ]
def test_string_in_filter_with_filterset_class():
"""Test in filter on a string field with a custom filterset class."""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=Query)
query = """
query {
people (names: ["John", "Michael"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter(): def test_int_in_filter():
""" """
Test in filter on an integer field. Test in filter on an integer field.

View File

@ -17,6 +17,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
model = filterset_class._meta.model model = filterset_class._meta.model
for name, filter_field in filterset_class.base_filters.items(): for name, filter_field in filterset_class.base_filters.items():
form_field = None form_field = None
filter_type = filter_field.lookup_expr
if name in filterset_class.declared_filters: if name in filterset_class.declared_filters:
# Get the filter field from the explicitly declared filter # Get the filter field from the explicitly declared filter
@ -25,7 +26,6 @@ def get_filtering_args_from_filterset(filterset_class, type):
else: else:
# Get the filter field with no explicit type declaration # Get the filter field with no explicit type declaration
model_field = get_model_field(model, filter_field.field_name) model_field = get_model_field(model, filter_field.field_name)
filter_type = filter_field.lookup_expr
if filter_type != "isnull" and hasattr(model_field, "formfield"): if filter_type != "isnull" and hasattr(model_field, "formfield"):
form_field = model_field.formfield( form_field = model_field.formfield(
required=filter_field.extra.get("required", False) required=filter_field.extra.get("required", False)
@ -38,14 +38,14 @@ def get_filtering_args_from_filterset(filterset_class, type):
field = convert_form_field(form_field) field = convert_form_field(form_field)
if filter_type in ["in", "range"]: if filter_type in ["in", "range"]:
# Replace CSV filters (`in`, `range`) argument type to be a list of the same type as the field. # Replace CSV filters (`in`, `range`) argument type to be a list of
# See comments in `replace_csv_filters` method for more details. # the same type as the field. See comments in
field = List(field.get_type()) # `replace_csv_filters` method for more details.
field = List(field.get_type())
field_type = field.Argument() field_type = field.Argument()
field_type.description = str(filter_field.label) if filter_field.label else None field_type.description = str(filter_field.label) if filter_field.label else None
args[name] = field_type args[name] = field_type
return args return args
@ -78,10 +78,7 @@ def replace_csv_filters(filterset_class):
""" """
for name, filter_field in list(filterset_class.base_filters.items()): for name, filter_field in list(filterset_class.base_filters.items()):
filter_type = filter_field.lookup_expr filter_type = filter_field.lookup_expr
if ( if filter_type in ["in", "range"]:
filter_type in ["in", "range"]
and name not in filterset_class.declared_filters
):
assert isinstance(filter_field, BaseCSVFilter) assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = Filter( filterset_class.base_filters[name] = Filter(
field_name=filter_field.field_name, field_name=filter_field.field_name,

View File

@ -6,6 +6,10 @@ from django.utils.translation import gettext_lazy as _
CHOICES = ((1, "this"), (2, _("that"))) CHOICES = ((1, "this"), (2, _("that")))
class Person(models.Model):
name = models.CharField(max_length=30)
class Pet(models.Model): class Pet(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
age = models.PositiveIntegerField() age = models.PositiveIntegerField()

View File

@ -51,6 +51,7 @@ def test_graphql_test_case_operation_name(post_mock):
pass pass
tc = TestClass() tc = TestClass()
tc._pre_setup()
tc.setUpClass() tc.setUpClass()
tc.query("query { }", operation_name="QueryName") tc.query("query { }", operation_name="QueryName")
body = json.loads(post_mock.call_args.args[1]) body = json.loads(post_mock.call_args.args[1])

View File

@ -1,6 +1,7 @@
import json import json
import warnings
from django.test import TestCase, Client from django.test import Client, TestCase
DEFAULT_GRAPHQL_URL = "/graphql/" DEFAULT_GRAPHQL_URL = "/graphql/"
@ -68,12 +69,6 @@ class GraphQLTestCase(TestCase):
# URL to graphql endpoint # URL to graphql endpoint
GRAPHQL_URL = DEFAULT_GRAPHQL_URL GRAPHQL_URL = DEFAULT_GRAPHQL_URL
@classmethod
def setUpClass(cls):
super(GraphQLTestCase, cls).setUpClass()
cls._client = Client()
def query( def query(
self, query, operation_name=None, input_data=None, variables=None, headers=None self, query, operation_name=None, input_data=None, variables=None, headers=None
): ):
@ -101,10 +96,19 @@ class GraphQLTestCase(TestCase):
input_data=input_data, input_data=input_data,
variables=variables, variables=variables,
headers=headers, headers=headers,
client=self._client, client=self.client,
graphql_url=self.GRAPHQL_URL, graphql_url=self.GRAPHQL_URL,
) )
@property
def _client(self):
warnings.warn(
"Using `_client` is deprecated in favour of `client`.",
PendingDeprecationWarning,
stacklevel=2,
)
return self.client
def assertResponseNoErrors(self, resp, msg=None): def assertResponseNoErrors(self, resp, msg=None):
""" """
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`, Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,

View File

@ -0,0 +1,24 @@
import pytest
from .. import GraphQLTestCase
from ...tests.test_types import with_local_registry
@with_local_registry
def test_graphql_test_case_deprecated_client():
"""
Test that `GraphQLTestCase._client`'s should raise pending deprecation warning.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
with pytest.warns(PendingDeprecationWarning):
tc._client