Merge remote-tracking branch 'up/v2' into fix-blank-field-enum

This commit is contained in:
Jason Kraus 2021-01-09 21:57:45 -08:00
commit bf29f3a74c
43 changed files with 1444 additions and 153 deletions

View File

@ -3,13 +3,13 @@
A [Django](https://www.djangoproject.com/) integration for [Graphene](http://graphene-python.org/). A [Django](https://www.djangoproject.com/) integration for [Graphene](http://graphene-python.org/).
[![travis][travis-image]][travis-url] [![build][build-image]][build-url]
[![pypi][pypi-image]][pypi-url] [![pypi][pypi-image]][pypi-url]
[![Anaconda-Server Badge][conda-image]][conda-url] [![Anaconda-Server Badge][conda-image]][conda-url]
[![coveralls][coveralls-image]][coveralls-url] [![coveralls][coveralls-image]][coveralls-url]
[travis-image]: https://travis-ci.org/graphql-python/graphene-django.svg?branch=master&style=flat [build-image]: https://github.com/graphql-python/graphene-django/workflows/Tests/badge.svg
[travis-url]: https://travis-ci.org/graphql-python/graphene-django [build-url]: https://github.com/graphql-python/graphene-django/actions
[pypi-image]: https://img.shields.io/pypi/v/graphene-django.svg?style=flat [pypi-image]: https://img.shields.io/pypi/v/graphene-django.svg?style=flat
[pypi-url]: https://pypi.org/project/graphene-django/ [pypi-url]: https://pypi.org/project/graphene-django/
[coveralls-image]: https://coveralls.io/repos/github/graphql-python/graphene-django/badge.svg?branch=master [coveralls-image]: https://coveralls.io/repos/github/graphql-python/graphene-django/badge.svg?branch=master
@ -110,6 +110,11 @@ To learn more check out the following [examples](examples/):
* **Relay Schema**: [Starwars Relay example](examples/starwars) * **Relay Schema**: [Starwars Relay example](examples/starwars)
## GraphQL testing clients
- [Firecamp](https://firecamp.io/graphql)
- [GraphiQL](https://github.com/graphql/graphiql)
## Contributing ## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) See [CONTRIBUTING.md](CONTRIBUTING.md)

View File

@ -114,8 +114,8 @@ Contributing
See `CONTRIBUTING.md <CONTRIBUTING.md>`__. See `CONTRIBUTING.md <CONTRIBUTING.md>`__.
.. |Graphene Logo| image:: http://graphene-python.org/favicon.png .. |Graphene Logo| image:: http://graphene-python.org/favicon.png
.. |Build Status| image:: https://travis-ci.org/graphql-python/graphene-django.svg?branch=master .. |Build Status| image:: https://github.com/graphql-python/graphene-django/workflows/Tests/badge.svg
:target: https://travis-ci.org/graphql-python/graphene-django :target: https://github.com/graphql-python/graphene-django/actions
.. |PyPI version| image:: https://badge.fury.io/py/graphene-django.svg .. |PyPI version| image:: https://badge.fury.io/py/graphene-django.svg
:target: https://badge.fury.io/py/graphene-django :target: https://badge.fury.io/py/graphene-django
.. |Coverage Status| image:: https://coveralls.io/repos/graphql-python/graphene-django/badge.svg?branch=master&service=github .. |Coverage Status| image:: https://coveralls.io/repos/graphql-python/graphene-django/badge.svg?branch=master&service=github

View File

@ -3,7 +3,7 @@ Django Debug Middleware
You can debug your GraphQL queries in a similar way to You can debug your GraphQL queries in a similar way to
`django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__, `django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__,
but outputing in the results in GraphQL response as fields, instead of but outputting in the results in GraphQL response as fields, instead of
the graphical HTML interface. the graphical HTML interface.
For that, you will need to add the plugin in your graphene schema. For that, you will need to add the plugin in your graphene schema.
@ -43,7 +43,7 @@ And in your ``settings.py``:
Querying Querying
-------- --------
You can query it for outputing all the sql transactions that happened in You can query it for outputting all the sql transactions that happened in
the GraphQL request, like: the GraphQL request, like:
.. code:: .. code::

View File

@ -36,7 +36,8 @@ Simple example
# The class attributes define the response of the mutation # The class attributes define the response of the mutation
question = graphene.Field(QuestionType) question = graphene.Field(QuestionType)
def mutate(self, info, text, id): @classmethod
def mutate(cls, root, info, text, id):
question = Question.objects.get(pk=id) question = Question.objects.get(pk=id)
question.text = text question.text = text
question.save() question.save()
@ -229,3 +230,121 @@ This argument is also sent back to the client with the mutation result
(you do not have to do anything). For services that manage (you do not have to do anything). For services that manage
a pool of many GraphQL requests in bulk, the ``clientIDMutation`` a pool of many GraphQL requests in bulk, the ``clientIDMutation``
allows you to match up a specific mutation with the response. allows you to match up a specific mutation with the response.
Django Database Transactions
----------------------------
Django gives you a few ways to control how database transactions are managed.
Tying transactions to HTTP requests
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A common way to handle transactions in Django is to wrap each request in a transaction.
Set ``ATOMIC_REQUESTS`` settings to ``True`` in the configuration of each database for
which you want to enable this behavior.
It works like this. Before calling ``GraphQLView`` Django starts a transaction. If the
response is produced without problems, Django commits the transaction. If the view, a
``DjangoFormMutation`` or a ``DjangoModelFormMutation`` produces an exception, Django
rolls back the transaction.
.. warning::
While the simplicity of this transaction model is appealing, it also makes it
inefficient when traffic increases. Opening a transaction for every request has some
overhead. The impact on performance depends on the query patterns of your application
and on how well your database handles locking.
Check the next section for a better solution.
Tying transactions to mutations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A mutation can contain multiple fields, just like a query. There's one important
distinction between queries and mutations, other than the name:
..
`While query fields are executed in parallel, mutation fields run in series, one
after the other.`
This means that if we send two ``incrementCredits`` mutations in one request, the first
is guaranteed to finish before the second begins, ensuring that we don't end up with a
race condition with ourselves.
On the other hand, if the first ``incrementCredits`` runs successfully but the second
one does not, the operation cannot be retried as it is. That's why is a good idea to
run all mutation fields in a transaction, to guarantee all occur or nothing occurs.
To enable this behavior for all databases set the graphene ``ATOMIC_MUTATIONS`` settings
to ``True`` in your settings file:
.. code:: python
GRAPHENE = {
# ...
"ATOMIC_MUTATIONS": True,
}
On the contrary, if you want to enable this behavior for a specific database, set
``ATOMIC_MUTATIONS`` to ``True`` in your database settings:
.. code:: python
DATABASES = {
"default": {
# ...
"ATOMIC_MUTATIONS": True,
},
# ...
}
Now, given the following example mutation:
.. code::
mutation IncreaseCreditsTwice {
increaseCredits1: increaseCredits(input: { amount: 10 }) {
balance
errors {
field
messages
}
}
increaseCredits2: increaseCredits(input: { amount: -1 }) {
balance
errors {
field
messages
}
}
}
The server is going to return something like:
.. code:: json
{
"data": {
"increaseCredits1": {
"balance": 10.0,
"errors": []
},
"increaseCredits2": {
"balance": null,
"errors": [
{
"field": "amount",
"message": "Amount should be a positive number"
}
]
},
}
}
But the balance will remain the same.

View File

@ -287,7 +287,7 @@ Where "foo" is the name of the field declared in the ``Query`` object.
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
foo = graphene.List(QuestionType) foo = graphene.List(QuestionType)
def resolve_foo(root, info): def resolve_foo(root, info, **kwargs):
id = kwargs.get("id") id = kwargs.get("id")
return Question.objects.get(id) return Question.objects.get(id)

View File

@ -1,7 +1,7 @@
from .fields import DjangoConnectionField, DjangoListField from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType from .types import DjangoObjectType
__version__ = "2.13.0" __version__ = "2.15.0"
__all__ = [ __all__ = [
"__version__", "__version__",

View File

@ -0,0 +1 @@
MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors"

View File

@ -21,6 +21,7 @@ from graphene import (
NonNull, NonNull,
String, String,
Time, Time,
Decimal,
) )
from graphene.types.resolver import get_default_resolver from graphene.types.resolver import get_default_resolver
from graphene.types.json import JSONString from graphene.types.json import JSONString
@ -185,6 +186,10 @@ def convert_field_to_boolean(field, registry=None):
@convert_django_field.register(models.DecimalField) @convert_django_field.register(models.DecimalField)
def convert_field_to_decimal(field, registry=None):
return Decimal(description=field.help_text, required=not field.null)
@convert_django_field.register(models.FloatField) @convert_django_field.register(models.FloatField)
@convert_django_field.register(models.DurationField) @convert_django_field.register(models.DurationField)
def convert_field_to_float(field, registry=None): def convert_field_to_float(field, registry=None):

View File

@ -43,16 +43,16 @@ class DjangoListField(Field):
def model(self): def model(self):
return self._underlying_type._meta.model return self._underlying_type._meta.model
def get_default_queryset(self): def get_manager(self):
return self.model._default_manager.get_queryset() return self.model._default_manager
@staticmethod @staticmethod
def list_resolver( def list_resolver(
django_object_type, resolver, default_queryset, root, info, **args django_object_type, resolver, default_manager, root, info, **args
): ):
queryset = maybe_queryset(resolver(root, info, **args)) queryset = maybe_queryset(resolver(root, info, **args))
if queryset is None: if queryset is None:
queryset = default_queryset queryset = maybe_queryset(default_manager)
if isinstance(queryset, QuerySet): if isinstance(queryset, QuerySet):
# Pass queryset to the DjangoObjectType get_queryset method # Pass queryset to the DjangoObjectType get_queryset method
@ -66,10 +66,7 @@ class DjangoListField(Field):
_type = _type.of_type _type = _type.of_type
django_object_type = _type.of_type.of_type django_object_type = _type.of_type.of_type
return partial( return partial(
self.list_resolver, self.list_resolver, django_object_type, parent_resolver, self.get_manager(),
django_object_type,
parent_resolver,
self.get_default_queryset(),
) )
@ -147,9 +144,6 @@ class DjangoConnectionField(ConnectionField):
if isinstance(iterable, QuerySet): if isinstance(iterable, QuerySet):
list_length = iterable.count() list_length = iterable.count()
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
else: else:
list_length = len(iterable) list_length = len(iterable)
list_slice_length = ( list_slice_length = (
@ -162,6 +156,10 @@ class DjangoConnectionField(ConnectionField):
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length) after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
if max_limit is not None and "first" not in args: if max_limit is not None and "first" not in args:
if "last" in args:
args["first"] = list_length
list_slice_length = list_length
else:
args["first"] = max_limit args["first"] = max_limit
connection = connection_from_list_slice( connection = connection_from_list_slice(

View File

@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
) )
else: else:
from .fields import DjangoFilterConnectionField from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = [ __all__ = [
"DjangoFilterConnectionField", "DjangoFilterConnectionField",

View File

@ -3,6 +3,7 @@ from functools import partial
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -21,6 +22,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self._fields = fields self._fields = fields
self._provided_filterset_class = filterset_class self._provided_filterset_class = filterset_class
self._filterset_class = None self._filterset_class = None
self._filtering_args = None
self._extra_filter_meta = extra_filter_meta self._extra_filter_meta = extra_filter_meta
self._base_args = None self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
@ -50,18 +52,31 @@ class DjangoFilterConnectionField(DjangoConnectionField):
@property @property
def filtering_args(self): def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type) if not self._filtering_args:
self._filtering_args = get_filtering_args_from_filterset(
self.filterset_class, self.node_type
)
return self._filtering_args
@classmethod @classmethod
def resolve_queryset( def resolve_queryset(
cls, connection, iterable, info, args, filtering_args, filterset_class cls, connection, iterable, info, args, filtering_args, filterset_class
): ):
def filter_kwargs():
kwargs = {}
for k, v in args.items():
if k in filtering_args:
if k == "order_by":
v = to_snake_case(v)
kwargs[k] = v
return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset( qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
connection, iterable, info, args connection, iterable, info, args
) )
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
filterset = filterset_class( filterset = filterset_class(
data=filter_kwargs, queryset=qs, request=info.context data=filter_kwargs(), queryset=qs, request=info.context
) )
if filterset.form.is_valid(): if filterset.form.is_valid():
return filterset.qs return filterset.qs

View File

@ -0,0 +1,75 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
class InFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super(InFilter, self).filter(qs, value)
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(Filter):
field_class = RangeField

View File

@ -1,32 +1,11 @@
import itertools import itertools
from django.db import models from django.db import models
from django_filters import Filter, MultipleChoiceFilter, VERSION from django_filters import VERSION
from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
from graphql_relay.node.node import from_global_id from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
GRAPHENE_FILTER_SET_OVERRIDES = { GRAPHENE_FILTER_SET_OVERRIDES = {

View File

@ -21,7 +21,7 @@ class ReporterFilter(django_filters.FilterSet):
model = Reporter model = Reporter
fields = ["first_name", "last_name", "email", "pets"] fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=("pub_date",)) order_by = OrderingFilter(fields=("first_name",))
class PetFilter(django_filters.FilterSet): class PetFilter(django_filters.FilterSet):

View File

@ -713,6 +713,73 @@ def test_should_query_filter_node_limit():
assert result.data == expected assert result.data == expected
def test_order_by():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType, filterset_class=ReporterFilter
)
Reporter.objects.create(first_name="b")
Reporter.objects.create(first_name="a")
schema = Schema(query=Query)
query = """
query NodeFilteringQuery {
allReporters(orderBy: "-firstName") {
edges {
node {
firstName
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [{"node": {"firstName": "b"}}, {"node": {"firstName": "a"}}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
query = """
query NodeFilteringQuery {
allReporters(orderBy: "-first_name") {
edges {
node {
firstName
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data == expected
query = """
query NodeFilteringQuery {
allReporters(orderBy: "-firtsnaMe") {
edges {
node {
firstName
}
}
}
}
"""
result = schema.execute(query)
assert result.errors
def test_order_by_is_perserved(): def test_order_by_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:

View File

@ -0,0 +1,183 @@
import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
def test_string_in_filter():
"""
Test in filter on a string field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
query = """
query {
pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Brutus"}},
{"node": {"name": "Jojo, the rabbit"}},
]
def test_string_in_filter_with_filterset_class():
"""Test in filter on a string field with a custom filterset class."""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=Query)
query = """
query {
people (names: ["John", "Michael"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter():
"""
Test in filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
query = """
query {
pets (age_In: [3]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Jojo, the rabbit"}},
]
query = """
query {
pets (age_In: [3, 12]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Brutus"}},
{"node": {"name": "Mimi"}},
{"node": {"name": "Jojo, the rabbit"}},
]
def test_in_filter_with_empty_list():
"""
Check that using a in filter with an empty list provided as input returns no objects.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query {
pets (name_In: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert len(result.data["pets"]["edges"]) == 0

View File

@ -0,0 +1,115 @@
import ast
import json
import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
def test_int_range_filter():
"""
Test range filter on an integer field.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query {
pets (age_Range: [4, 9]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["pets"]["edges"] == [
{"node": {"name": "Mimi"}},
{"node": {"name": "Picotin"}},
]
def test_range_filter_with_invalid_input():
"""
Test range filter used with invalid inputs raise an error.
"""
Pet.objects.create(name="Brutus", age=12)
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Jojo, the rabbit", age=3)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
query = """
query ($rangeValue: [Int]) {
pets (age_Range: $rangeValue) {
edges {
node {
name
}
}
}
}
"""
expected_error = json.dumps(
{
"age__range": [
{
"message": "Invalid range specified: it needs to contain 2 values.",
"code": "invalid",
}
]
}
)
# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert ast.literal_eval(result.errors[0].message)[0] == expected_error

View File

@ -1,7 +1,12 @@
import six import six
from graphene import List
from django_filters.utils import get_model_field from django_filters.utils import get_model_field
from django_filters.filters import Filter, BaseCSVFilter
from .filterset import custom_filterset_factory, setup_filterset from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter
def get_filtering_args_from_filterset(filterset_class, type): def get_filtering_args_from_filterset(filterset_class, type):
@ -15,12 +20,15 @@ def get_filtering_args_from_filterset(filterset_class, type):
model = filterset_class._meta.model model = filterset_class._meta.model
for name, filter_field in six.iteritems(filterset_class.base_filters): for name, filter_field in six.iteritems(filterset_class.base_filters):
form_field = None form_field = None
filter_type = filter_field.lookup_expr
if name in filterset_class.declared_filters: if name in filterset_class.declared_filters:
# Get the filter field from the explicitly declared filter
form_field = filter_field.field form_field = filter_field.field
field = convert_form_field(form_field)
else: else:
# Get the filter field with no explicit type declaration
model_field = get_model_field(model, filter_field.field_name) model_field = get_model_field(model, filter_field.field_name)
filter_type = filter_field.lookup_expr
if filter_type != "isnull" and hasattr(model_field, "formfield"): if filter_type != "isnull" and hasattr(model_field, "formfield"):
form_field = model_field.formfield( form_field = model_field.formfield(
required=filter_field.extra.get("required", False) required=filter_field.extra.get("required", False)
@ -31,7 +39,15 @@ def get_filtering_args_from_filterset(filterset_class, type):
if not form_field: if not form_field:
form_field = filter_field.field form_field = filter_field.field
field_type = convert_form_field(form_field).Argument() field = convert_form_field(form_field)
if filter_type in ["in", "range"]:
# Replace CSV filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in
# `replace_csv_filters` method for more details.
field = List(field.get_type())
field_type = field.Argument()
field_type.description = filter_field.label field_type.description = filter_field.label
args[name] = field_type args[name] = field_type
@ -39,9 +55,50 @@ def get_filtering_args_from_filterset(filterset_class, type):
def get_filterset_class(filterset_class, **meta): def get_filterset_class(filterset_class, **meta):
"""Get the class to be used as the FilterSet""" """
Get the class to be used as the FilterSet.
"""
if filterset_class: if filterset_class:
# If were given a FilterSet class, then set it up and # If were given a FilterSet class, then set it up.
# return it graphene_filterset_class = setup_filterset(filterset_class)
return setup_filterset(filterset_class) else:
return custom_filterset_factory(**meta) # Otherwise create one.
graphene_filterset_class = custom_filterset_factory(**meta)
replace_csv_filters(graphene_filterset_class)
return graphene_filterset_class
def replace_csv_filters(filterset_class):
"""
Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but regular Filter objects that simply use the input value as filter argument on the queryset.
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
can actually have a list as input and have a proper type verification of each value in the list.
See issue https://github.com/graphql-python/graphene-django/issues/1068.
"""
for name, filter_field in six.iteritems(filterset_class.base_filters):
filter_type = filter_field.lookup_expr
if filter_type == "in":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = InFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)
if filter_type == "range":
assert isinstance(filter_field, BaseCSVFilter)
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
method=filter_field.method,
exclude=filter_field.exclude,
**filter_field.extra
)

View File

@ -63,6 +63,11 @@ def convert_form_field_to_list(field):
return List(ID, required=field.required) return List(ID, required=field.required)
@convert_form_field.register(forms.MultipleChoiceField)
def convert_form_field_to_string_list(field):
return List(String, required=field.required)
@convert_form_field.register(forms.DateField) @convert_form_field.register(forms.DateField)
def convert_form_field_to_date(field): def convert_form_field_to_date(field):
return Date(description=field.help_text, required=field.required) return Date(description=field.help_text, required=field.required)

View File

@ -11,8 +11,13 @@ from graphene.types.mutation import MutationOptions
# InputObjectType, # InputObjectType,
# ) # )
from graphene.types.utils import yank_fields_from_attrs from graphene.types.utils import yank_fields_from_attrs
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.registry import get_global_registry from graphene_django.registry import get_global_registry
from django.core.exceptions import ValidationError
from django.db import connection
from ..types import ErrorType from ..types import ErrorType
from .converter import convert_form_field from .converter import convert_form_field
@ -46,6 +51,7 @@ class BaseDjangoFormMutation(ClientIDMutation):
return cls.perform_mutate(form, info) return cls.perform_mutate(form, info)
else: else:
errors = ErrorType.from_errors(form.errors) errors = ErrorType.from_errors(form.errors)
_set_errors_flag_to_context(info)
return cls(errors=errors, **form.data) return cls(errors=errors, **form.data)
@ -170,6 +176,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
return cls.perform_mutate(form, info) return cls.perform_mutate(form, info)
else: else:
errors = ErrorType.from_errors(form.errors) errors = ErrorType.from_errors(form.errors)
_set_errors_flag_to_context(info)
return cls(errors=errors) return cls(errors=errors)
@ -178,3 +185,9 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
obj = form.save() obj = form.save()
kwargs = {cls._meta.return_field_name: obj} kwargs = {cls._meta.return_field_name: obj}
return cls(errors=[], **kwargs) return cls(errors=[], **kwargs)
def _set_errors_flag_to_context(info):
# This is not ideal but necessary to keep the response errors empty
if info and info.context:
setattr(info.context, MUTATION_ERRORS_FLAG, True)

View File

@ -101,7 +101,14 @@ def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, Float) assert_conversion(forms.DecimalField, Float)
def test_should_multiple_choice_convert_connectionorlist(): def test_should_multiple_choice_convert_list():
field = forms.MultipleChoiceField()
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == String
def test_should_model_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(queryset=None) field = forms.ModelMultipleChoiceField(queryset=None)
graphene_type = convert_form_field(field) graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List) assert isinstance(graphene_type, List)

View File

@ -5,21 +5,13 @@ from py.test import raises
from graphene import Field, ObjectType, Schema, String from graphene import Field, ObjectType, Schema, String
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.tests.forms import PetForm
from graphene_django.tests.models import Pet from graphene_django.tests.models import Pet
from graphene_django.tests.mutations import PetMutation
from ..mutation import DjangoFormMutation, DjangoModelFormMutation from ..mutation import DjangoFormMutation, DjangoModelFormMutation
@pytest.fixture()
def pet_type():
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
return PetType
class MyForm(forms.Form): class MyForm(forms.Form):
text = forms.CharField() text = forms.CharField()
@ -33,18 +25,6 @@ class MyForm(forms.Form):
pass pass
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = "__all__"
def clean_age(self):
age = self.cleaned_data["age"]
if age >= 99:
raise ValidationError("Too old")
return age
def test_needs_form_class(): def test_needs_form_class():
with raises(Exception) as exc: with raises(Exception) as exc:
@ -70,11 +50,18 @@ def test_has_input_fields():
assert "text" in MyMutation.Input._meta.fields assert "text" in MyMutation.Input._meta.fields
def test_mutation_error_camelcased(pet_type, graphene_settings): def test_mutation_error_camelcased(graphene_settings):
class ExtraPetForm(PetForm): class ExtraPetForm(PetForm):
test_field = forms.CharField(required=True) test_field = forms.CharField(required=True)
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta: class Meta:
form_class = ExtraPetForm form_class = ExtraPetForm
@ -146,21 +133,13 @@ def test_form_valid_input():
assert result.data["myMutation"]["text"] == "VALID_INPUT" assert result.data["myMutation"]["text"] == "VALID_INPUT"
def test_default_meta_fields(pet_type): def test_default_meta_fields():
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
assert PetMutation._meta.model is Pet assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "pet" assert PetMutation._meta.return_field_name == "pet"
assert "pet" in PetMutation._meta.fields assert "pet" in PetMutation._meta.fields
def test_default_input_meta_fields(pet_type): def test_default_input_meta_fields():
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
assert PetMutation._meta.model is Pet assert PetMutation._meta.model is Pet
assert PetMutation._meta.return_field_name == "pet" assert PetMutation._meta.return_field_name == "pet"
assert "name" in PetMutation.Input._meta.fields assert "name" in PetMutation.Input._meta.fields
@ -168,8 +147,15 @@ def test_default_input_meta_fields(pet_type):
assert "id" in PetMutation.Input._meta.fields assert "id" in PetMutation.Input._meta.fields
def test_exclude_fields_input_meta_fields(pet_type): def test_exclude_fields_input_meta_fields():
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta: class Meta:
form_class = PetForm form_class = PetForm
exclude_fields = ["id"] exclude_fields = ["id"]
@ -182,8 +168,15 @@ def test_exclude_fields_input_meta_fields(pet_type):
assert "id" not in PetMutation.Input._meta.fields assert "id" not in PetMutation.Input._meta.fields
def test_custom_return_field_name(pet_type): def test_custom_return_field_name():
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"
class PetMutation(DjangoModelFormMutation): class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta: class Meta:
form_class = PetForm form_class = PetForm
model = Pet model = Pet
@ -194,13 +187,7 @@ def test_custom_return_field_name(pet_type):
assert "animal" in PetMutation._meta.fields assert "animal" in PetMutation._meta.fields
def test_model_form_mutation_mutate_existing(pet_type): def test_model_form_mutation_mutate_existing():
class PetMutation(DjangoModelFormMutation):
pet = Field(pet_type)
class Meta:
form_class = PetForm
class Mutation(ObjectType): class Mutation(ObjectType):
pet_mutation = PetMutation.Field() pet_mutation = PetMutation.Field()
@ -229,13 +216,7 @@ def test_model_form_mutation_mutate_existing(pet_type):
assert pet.name == "Mia" assert pet.name == "Mia"
def test_model_form_mutation_creates_new(pet_type): def test_model_form_mutation_creates_new():
class PetMutation(DjangoModelFormMutation):
pet = Field(pet_type)
class Meta:
form_class = PetForm
class Mutation(ObjectType): class Mutation(ObjectType):
pet_mutation = PetMutation.Field() pet_mutation = PetMutation.Field()
@ -265,13 +246,7 @@ def test_model_form_mutation_creates_new(pet_type):
assert pet.age == 10 assert pet.age == 10
def test_model_form_mutation_invalid_input(pet_type): def test_model_form_mutation_invalid_input():
class PetMutation(DjangoModelFormMutation):
pet = Field(pet_type)
class Meta:
form_class = PetForm
class Mutation(ObjectType): class Mutation(ObjectType):
pet_mutation = PetMutation.Field() pet_mutation = PetMutation.Field()
@ -301,11 +276,7 @@ def test_model_form_mutation_invalid_input(pet_type):
assert Pet.objects.count() == 0 assert Pet.objects.count() == 0
def test_model_form_mutation_mutate_invalid_form(pet_type): def test_model_form_mutation_mutate_invalid_form():
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None) result = PetMutation.mutate_and_get_payload(None, None)
# A pet was not created # A pet was not created
@ -317,3 +288,98 @@ def test_model_form_mutation_mutate_invalid_form(pet_type):
assert result.errors[1].messages == ["This field is required."] assert result.errors[1].messages == ["This field is required."]
assert "age" in fields_w_error assert "age" in fields_w_error
assert "name" in fields_w_error assert "name" in fields_w_error
def test_model_form_mutation_multiple_creation_valid():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
"""
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 10 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation1"]["pet"] == {"name": "Mia", "age": 10}
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 2
pet1 = Pet.objects.first()
assert pet1.name == "Mia"
assert pet1.age == 10
pet2 = Pet.objects.last()
assert pet2.name == "Enzo"
assert pet2.age == 0
def test_model_form_mutation_multiple_creation_invalid():
class Mutation(ObjectType):
pet_mutation = PetMutation.Field()
schema = Schema(query=MockQuery, mutation=Mutation)
result = schema.execute(
"""
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
)
assert result.errors is None
assert result.data["petMutation1"]["pet"] is None
assert result.data["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0

View File

@ -60,8 +60,10 @@ class Command(CommandArguments):
def get_schema(self, schema, out, indent): def get_schema(self, schema, out, indent):
schema_dict = {"data": schema.introspect()} schema_dict = {"data": schema.introspect()}
if out == "-": if out == "-" or out == "-.json":
self.stdout.write(json.dumps(schema_dict, indent=indent, sort_keys=True)) self.stdout.write(json.dumps(schema_dict, indent=indent, sort_keys=True))
elif out == "-.graphql":
self.stdout.write(print_schema(schema))
else: else:
# Determine format # Determine format
_, file_extension = os.path.splitext(out) _, file_extension = os.path.splitext(out)

View File

@ -18,6 +18,7 @@ class SerializerMutationOptions(MutationOptions):
model_class = None model_class = None
model_operations = ["create", "update"] model_operations = ["create", "update"]
serializer_class = None serializer_class = None
optional_fields = ()
def fields_for_serializer( def fields_for_serializer(
@ -27,6 +28,7 @@ def fields_for_serializer(
is_input=False, is_input=False,
convert_choices_to_enum=True, convert_choices_to_enum=True,
lookup_field=None, lookup_field=None,
optional_fields=(),
): ):
fields = OrderedDict() fields = OrderedDict()
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
@ -44,9 +46,13 @@ def fields_for_serializer(
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
continue continue
is_optional = name in optional_fields
fields[name] = convert_serializer_field( fields[name] = convert_serializer_field(
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum field,
is_input=is_input,
convert_choices_to_enum=convert_choices_to_enum,
force_optional=is_optional,
) )
return fields return fields
@ -70,6 +76,7 @@ class SerializerMutation(ClientIDMutation):
exclude_fields=(), exclude_fields=(),
convert_choices_to_enum=True, convert_choices_to_enum=True,
_meta=None, _meta=None,
optional_fields=(),
**options **options
): ):
@ -95,6 +102,7 @@ class SerializerMutation(ClientIDMutation):
is_input=True, is_input=True,
convert_choices_to_enum=convert_choices_to_enum, convert_choices_to_enum=convert_choices_to_enum,
lookup_field=lookup_field, lookup_field=lookup_field,
optional_fields=optional_fields,
) )
output_fields = fields_for_serializer( output_fields = fields_for_serializer(
serializer, serializer,

View File

@ -19,7 +19,9 @@ def get_graphene_type_from_serializer_field(field):
) )
def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True): def convert_serializer_field(
field, is_input=True, convert_choices_to_enum=True, force_optional=False
):
""" """
Converts a django rest frameworks field to a graphql field Converts a django rest frameworks field to a graphql field
and marks the field as required if we are creating an input type and marks the field as required if we are creating an input type
@ -32,7 +34,10 @@ def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True)
graphql_type = get_graphene_type_from_serializer_field(field) graphql_type = get_graphene_type_from_serializer_field(field)
args = [] args = []
kwargs = {"description": field.help_text, "required": is_input and field.required} kwargs = {
"description": field.help_text,
"required": is_input and field.required and not force_optional,
}
# if it is a tuple or a list it means that we are returning # if it is a tuple or a list it means that we are returning
# the graphql type and the child type # the graphql type and the child type

View File

@ -3,7 +3,7 @@ import datetime
from py.test import raises from py.test import raises
from rest_framework import serializers from rest_framework import serializers
from graphene import Field, ResolveInfo from graphene import Field, ResolveInfo, NonNull, String
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from ...types import DjangoObjectType from ...types import DjangoObjectType
@ -98,6 +98,25 @@ def test_exclude_fields():
assert "created" not in MyMutation.Input._meta.fields assert "created" not in MyMutation.Input._meta.fields
def test_model_serializer_required_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
assert "cool_name" in MyMutation.Input._meta.fields
assert MyMutation.Input._meta.fields["cool_name"].type == NonNull(String)
def test_model_serializer_optional_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
optional_fields = ("cool_name",)
assert "cool_name" in MyMutation.Input._meta.fields
assert MyMutation.Input._meta.fields["cool_name"].type == String
def test_write_only_field(): def test_write_only_field():
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer): class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
password = serializers.CharField(write_only=True) password = serializers.CharField(write_only=True)

View File

@ -45,6 +45,7 @@ DEFAULTS = {
# This sets headerEditorEnabled GraphiQL option, for details go to # This sets headerEditorEnabled GraphiQL option, for details go to
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options # https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
"GRAPHIQL_HEADER_EDITOR_ENABLED": True, "GRAPHIQL_HEADER_EDITOR_ENABLED": True,
"ATOMIC_MUTATIONS": False,
} }
if settings.DEBUG: if settings.DEBUG:

View File

@ -0,0 +1,16 @@
from django import forms
from django.core.exceptions import ValidationError
from .models import Pet
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = "__all__"
def clean_age(self):
age = self.cleaned_data["age"]
if age >= 99:
raise ValidationError("Too old")
return age

View File

@ -6,6 +6,10 @@ from django.utils.translation import ugettext_lazy as _
CHOICES = ((1, "this"), (2, _("that"))) CHOICES = ((1, "this"), (2, _("that")))
class Person(models.Model):
name = models.CharField(max_length=30)
class Pet(models.Model): class Pet(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
age = models.PositiveIntegerField() age = models.PositiveIntegerField()

View File

@ -0,0 +1,18 @@
from graphene import Field
from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation
from .forms import PetForm
from .types import PetType
class PetFormMutation(DjangoFormMutation):
class Meta:
form_class = PetForm
class PetMutation(DjangoModelFormMutation):
pet = Field(PetType)
class Meta:
form_class = PetForm

View File

@ -1,6 +1,8 @@
import graphene import graphene
from graphene import ObjectType, Schema from graphene import ObjectType, Schema
from .mutations import PetFormMutation, PetMutation
class QueryRoot(ObjectType): class QueryRoot(ObjectType):
@ -19,6 +21,8 @@ class QueryRoot(ObjectType):
class MutationRoot(ObjectType): class MutationRoot(ObjectType):
pet_form_mutation = PetFormMutation.Field()
pet_mutation = PetMutation.Field()
write_test = graphene.Field(QueryRoot) write_test = graphene.Field(QueryRoot)
def resolve_write_test(self, info): def resolve_write_test(self, info):

View File

@ -241,6 +241,10 @@ def test_should_float_convert_float():
assert_conversion(models.FloatField, graphene.Float) assert_conversion(models.FloatField, graphene.Float)
def test_should_float_convert_decimal():
assert_conversion(models.DecimalField, graphene.Decimal)
def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist():
registry = Registry() registry = Registry()
dynamic_field = convert_django_field(Reporter._meta.local_many_to_many[0], registry) dynamic_field = convert_django_field(Reporter._meta.local_many_to_many[0], registry)

View File

@ -75,6 +75,39 @@ class TestDjangoListField:
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
} }
def test_list_field_queryset_is_not_cached(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name",)
class Query(ObjectType):
reporters = DjangoListField(Reporter)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": []}
ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
result = schema.execute(query)
assert not result.errors
assert result.data == {
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
}
def test_override_resolver(self): def test_override_resolver(self):
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:

View File

@ -1213,6 +1213,103 @@ def test_should_have_next_page(graphene_settings):
} }
class TestBackwardPagination:
def setup_schema(self, graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
Reporter.objects.bulk_create(reporters)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
return schema
def do_queries(self, schema):
# Simply last 3
query_last = """
query {
allReporters(last: 3) {
edges {
node {
firstName
}
}
}
}
"""
result = schema.execute(query_last)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 3", "First 4", "First 5"]
# Use a combination of first and last
query_first_and_last = """
query {
allReporters(first: 4, last: 3) {
edges {
node {
firstName
}
}
}
}
"""
result = schema.execute(query_first_and_last)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 1", "First 2", "First 3"]
# Use a combination of first and last and after
query_first_last_and_after = """
query queryAfter($after: String) {
allReporters(first: 4, last: 3, after: $after) {
edges {
node {
firstName
}
}
}
}
"""
after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute(
query_first_last_and_after, variable_values=dict(after=after)
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
assert [
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 2", "First 3", "First 4"]
def test_should_query(self, graphene_settings):
"""
Backward pagination should work as expected
"""
schema = self.setup_schema(graphene_settings, max_limit=100)
self.do_queries(schema)
def test_should_query_with_low_max_limit(self, graphene_settings):
"""
When doing backward pagination (using last) in combination with a max limit higher than the number of objects
we should really retrieve the last ones.
"""
schema = self.setup_schema(graphene_settings, max_limit=4)
self.do_queries(schema)
def test_should_preserve_prefetch_related(django_assert_num_queries): def test_should_preserve_prefetch_related(django_assert_num_queries):
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:

View File

@ -51,7 +51,9 @@ def test_graphql_test_case_op_name(post_mock):
pass pass
tc = TestClass() tc = TestClass()
tc._pre_setup()
tc.setUpClass() tc.setUpClass()
tc.query("query { }", op_name="QueryName") tc.query("query { }", op_name="QueryName")
body = json.loads(post_mock.call_args.args[1]) body = json.loads(post_mock.call_args.args[1])
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request # `operationName` field from https://graphql.org/learn/serving-over-http/#post-request

View File

@ -2,6 +2,14 @@ import json
import pytest import pytest
from mock import patch
from django.db import connection
from graphene_django.settings import graphene_settings
from .models import Pet
try: try:
from urllib import urlencode from urllib import urlencode
except ImportError: except ImportError:
@ -558,3 +566,265 @@ def test_passes_request_into_context_request(client):
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == {"data": {"request": "testing"}} assert response_json(response) == {"data": {"request": "testing"}}
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_form_mutation_multiple_creation_invalid_atomic_request(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": True, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_atomic_mutation_1(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", True)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_atomic_mutation_2(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_form_mutation_multiple_creation_invalid_non_atomic(client):
query = """
mutation PetMutations {
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
errors {
field
messages
}
}
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petFormMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petFormMutation2"]["errors"] == []
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_model_form_mutation_multiple_creation_invalid_atomic_request(client):
query = """
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petMutation1"]["pet"] is None
assert content["data"]["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 0
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_model_form_mutation_multiple_creation_invalid_non_atomic(client):
query = """
mutation PetMutations {
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
pet {
name
age
}
errors {
field
messages
}
}
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
pet {
name
age
}
errors {
field
messages
}
}
}
"""
response = client.post(url_string(query=query))
content = response_json(response)
assert "errors" not in content
assert content["data"]["petMutation1"]["pet"] is None
assert content["data"]["petMutation1"]["errors"] == [
{"field": "age", "messages": ["Too old"]}
]
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
assert Pet.objects.count() == 1
pet = Pet.objects.get()
assert pet.name == "Enzo"
assert pet.age == 0
@patch("graphene_django.utils.utils.transaction.set_rollback")
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
)
def test_query_errors_atomic_request(set_rollback_mock, client):
client.get(url_string(query="force error"))
set_rollback_mock.assert_called_once_with(True)
@patch("graphene_django.utils.utils.transaction.set_rollback")
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
@patch.dict(
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
)
def test_query_errors_non_atomic(set_rollback_mock, client):
client.get(url_string(query="force error"))
set_rollback_mock.assert_not_called()

View File

@ -0,0 +1,9 @@
from graphene_django.types import DjangoObjectType
from .models import Pet
class PetType(DjangoObjectType):
class Meta:
model = Pet
fields = "__all__"

View File

@ -1,5 +1,5 @@
import re import re
from unidecode import unidecode from text_unidecode import unidecode
def to_const(string): def to_const(string):

View File

@ -1,6 +1,7 @@
import json import json
import warnings
from django.test import TestCase, Client from django.test import Client, TestCase
DEFAULT_GRAPHQL_URL = "/graphql/" DEFAULT_GRAPHQL_URL = "/graphql/"
@ -68,12 +69,6 @@ class GraphQLTestCase(TestCase):
# URL to graphql endpoint # URL to graphql endpoint
GRAPHQL_URL = DEFAULT_GRAPHQL_URL GRAPHQL_URL = DEFAULT_GRAPHQL_URL
@classmethod
def setUpClass(cls):
super(GraphQLTestCase, cls).setUpClass()
cls._client = Client()
def query(self, query, op_name=None, input_data=None, variables=None, headers=None): def query(self, query, op_name=None, input_data=None, variables=None, headers=None):
""" """
Args: Args:
@ -99,19 +94,41 @@ class GraphQLTestCase(TestCase):
input_data=input_data, input_data=input_data,
variables=variables, variables=variables,
headers=headers, headers=headers,
client=self._client, client=self.client,
graphql_url=self.GRAPHQL_URL, graphql_url=self.GRAPHQL_URL,
) )
@property
def _client(self):
pass
@_client.getter
def _client(self):
warnings.warn(
"Using `_client` is deprecated in favour of `client`.",
PendingDeprecationWarning,
stacklevel=2,
)
return self.client
@_client.setter
def _client(self, client):
warnings.warn(
"Using `_client` is deprecated in favour of `client`.",
PendingDeprecationWarning,
stacklevel=2,
)
self.client = client
def assertResponseNoErrors(self, resp, msg=None): def assertResponseNoErrors(self, resp, msg=None):
""" """
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`, Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,
the call was fine. the call was fine.
:resp HttpResponse: Response :resp HttpResponse: Response
""" """
self.assertEqual(resp.status_code, 200)
content = json.loads(resp.content) content = json.loads(resp.content)
self.assertNotIn("errors", list(content.keys()), msg) self.assertEqual(resp.status_code, 200, msg or content)
self.assertNotIn("errors", list(content.keys()), msg or content)
def assertResponseHasErrors(self, resp, msg=None): def assertResponseHasErrors(self, resp, msg=None):
""" """
@ -119,4 +136,4 @@ class GraphQLTestCase(TestCase):
:resp HttpResponse: Response :resp HttpResponse: Response
""" """
content = json.loads(resp.content) content = json.loads(resp.content)
self.assertIn("errors", list(content.keys()), msg) self.assertIn("errors", list(content.keys()), msg or content)

View File

@ -0,0 +1,45 @@
import pytest
from .. import GraphQLTestCase
from ...tests.test_types import with_local_registry
from django.test import Client
@with_local_registry
def test_graphql_test_case_deprecated_client_getter():
"""
`GraphQLTestCase._client`' getter should raise pending deprecation warning.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
with pytest.warns(PendingDeprecationWarning):
tc._client
@with_local_registry
def test_graphql_test_case_deprecated_client_setter():
"""
`GraphQLTestCase._client`' setter should raise pending deprecation warning.
"""
class TestClass(GraphQLTestCase):
GRAPHQL_SCHEMA = True
def runTest(self):
pass
tc = TestClass()
tc._pre_setup()
tc.setUpClass()
with pytest.warns(PendingDeprecationWarning):
tc._client = Client()

View File

@ -1,7 +1,7 @@
import inspect import inspect
import six import six
from django.db import models from django.db import connection, models, transaction
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import Promise from django.utils.functional import Promise
@ -100,3 +100,9 @@ def import_single_dispatch():
) )
return singledispatch return singledispatch
def set_rollback():
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)

View File

@ -3,6 +3,7 @@ import json
import re import re
import six import six
from django.db import connection, transaction
from django.http import HttpResponse, HttpResponseNotAllowed from django.http import HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBadRequest from django.http.response import HttpResponseBadRequest
from django.shortcuts import render from django.shortcuts import render
@ -17,6 +18,9 @@ from graphql.execution import ExecutionResult
from graphql.type.schema import GraphQLSchema from graphql.type.schema import GraphQLSchema
from graphql.execution.middleware import MiddlewareManager from graphql.execution.middleware import MiddlewareManager
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.utils.utils import set_rollback
from .settings import graphene_settings from .settings import graphene_settings
@ -203,11 +207,15 @@ class GraphQLView(View):
request, data, query, variables, operation_name, show_graphiql request, data, query, variables, operation_name, show_graphiql
) )
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
set_rollback()
status_code = 200 status_code = 200
if execution_result: if execution_result:
response = {} response = {}
if execution_result.errors: if execution_result.errors:
set_rollback()
response["errors"] = [ response["errors"] = [
self.format_error(e) for e in execution_result.errors self.format_error(e) for e in execution_result.errors
] ]
@ -312,14 +320,27 @@ class GraphQLView(View):
# executor is not a valid argument in all backends # executor is not a valid argument in all backends
extra_options["executor"] = self.executor extra_options["executor"] = self.executor
return document.execute( options = {
root_value=self.get_root_value(request), "root_value": self.get_root_value(request),
variable_values=variables, "variable_values": variables,
operation_name=operation_name, "operation_name": operation_name,
context_value=self.get_context(request), "context_value": self.get_context(request),
middleware=self.get_middleware(request), "middleware": self.get_middleware(request),
**extra_options }
) options.update(extra_options)
operation_type = document.get_operation_type(operation_name)
if operation_type == "mutation" and (
graphene_settings.ATOMIC_MUTATIONS is True
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
):
with transaction.atomic():
result = document.execute(**options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
transaction.set_rollback(True)
return result
return document.execute(**options)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)

View File

@ -66,7 +66,7 @@ setup(
"Django>=1.11", "Django>=1.11",
"singledispatch>=3.4.0.3", "singledispatch>=3.4.0.3",
"promise>=2.1", "promise>=2.1",
"unidecode>=1.1.1,<2", "text-unidecode",
], ],
setup_requires=["pytest-runner"], setup_requires=["pytest-runner"],
tests_require=tests_require, tests_require=tests_require,