mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-07-11 08:42:32 +03:00
Merge remote-tracking branch 'up/v2' into fix-blank-field-enum
This commit is contained in:
commit
bf29f3a74c
11
README.md
11
README.md
|
@ -3,13 +3,13 @@
|
|||
|
||||
A [Django](https://www.djangoproject.com/) integration for [Graphene](http://graphene-python.org/).
|
||||
|
||||
[![travis][travis-image]][travis-url]
|
||||
[![build][build-image]][build-url]
|
||||
[![pypi][pypi-image]][pypi-url]
|
||||
[![Anaconda-Server Badge][conda-image]][conda-url]
|
||||
[![coveralls][coveralls-image]][coveralls-url]
|
||||
|
||||
[travis-image]: https://travis-ci.org/graphql-python/graphene-django.svg?branch=master&style=flat
|
||||
[travis-url]: https://travis-ci.org/graphql-python/graphene-django
|
||||
[build-image]: https://github.com/graphql-python/graphene-django/workflows/Tests/badge.svg
|
||||
[build-url]: https://github.com/graphql-python/graphene-django/actions
|
||||
[pypi-image]: https://img.shields.io/pypi/v/graphene-django.svg?style=flat
|
||||
[pypi-url]: https://pypi.org/project/graphene-django/
|
||||
[coveralls-image]: https://coveralls.io/repos/github/graphql-python/graphene-django/badge.svg?branch=master
|
||||
|
@ -110,6 +110,11 @@ To learn more check out the following [examples](examples/):
|
|||
* **Relay Schema**: [Starwars Relay example](examples/starwars)
|
||||
|
||||
|
||||
## GraphQL testing clients
|
||||
- [Firecamp](https://firecamp.io/graphql)
|
||||
- [GraphiQL](https://github.com/graphql/graphiql)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
|
|
|
@ -114,8 +114,8 @@ Contributing
|
|||
See `CONTRIBUTING.md <CONTRIBUTING.md>`__.
|
||||
|
||||
.. |Graphene Logo| image:: http://graphene-python.org/favicon.png
|
||||
.. |Build Status| image:: https://travis-ci.org/graphql-python/graphene-django.svg?branch=master
|
||||
:target: https://travis-ci.org/graphql-python/graphene-django
|
||||
.. |Build Status| image:: https://github.com/graphql-python/graphene-django/workflows/Tests/badge.svg
|
||||
:target: https://github.com/graphql-python/graphene-django/actions
|
||||
.. |PyPI version| image:: https://badge.fury.io/py/graphene-django.svg
|
||||
:target: https://badge.fury.io/py/graphene-django
|
||||
.. |Coverage Status| image:: https://coveralls.io/repos/graphql-python/graphene-django/badge.svg?branch=master&service=github
|
||||
|
|
|
@ -3,7 +3,7 @@ Django Debug Middleware
|
|||
|
||||
You can debug your GraphQL queries in a similar way to
|
||||
`django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__,
|
||||
but outputing in the results in GraphQL response as fields, instead of
|
||||
but outputting in the results in GraphQL response as fields, instead of
|
||||
the graphical HTML interface.
|
||||
|
||||
For that, you will need to add the plugin in your graphene schema.
|
||||
|
@ -43,7 +43,7 @@ And in your ``settings.py``:
|
|||
Querying
|
||||
--------
|
||||
|
||||
You can query it for outputing all the sql transactions that happened in
|
||||
You can query it for outputting all the sql transactions that happened in
|
||||
the GraphQL request, like:
|
||||
|
||||
.. code::
|
||||
|
|
|
@ -36,7 +36,8 @@ Simple example
|
|||
# The class attributes define the response of the mutation
|
||||
question = graphene.Field(QuestionType)
|
||||
|
||||
def mutate(self, info, text, id):
|
||||
@classmethod
|
||||
def mutate(cls, root, info, text, id):
|
||||
question = Question.objects.get(pk=id)
|
||||
question.text = text
|
||||
question.save()
|
||||
|
@ -229,3 +230,121 @@ This argument is also sent back to the client with the mutation result
|
|||
(you do not have to do anything). For services that manage
|
||||
a pool of many GraphQL requests in bulk, the ``clientIDMutation``
|
||||
allows you to match up a specific mutation with the response.
|
||||
|
||||
|
||||
|
||||
Django Database Transactions
|
||||
----------------------------
|
||||
|
||||
Django gives you a few ways to control how database transactions are managed.
|
||||
|
||||
Tying transactions to HTTP requests
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A common way to handle transactions in Django is to wrap each request in a transaction.
|
||||
Set ``ATOMIC_REQUESTS`` settings to ``True`` in the configuration of each database for
|
||||
which you want to enable this behavior.
|
||||
|
||||
It works like this. Before calling ``GraphQLView`` Django starts a transaction. If the
|
||||
response is produced without problems, Django commits the transaction. If the view, a
|
||||
``DjangoFormMutation`` or a ``DjangoModelFormMutation`` produces an exception, Django
|
||||
rolls back the transaction.
|
||||
|
||||
.. warning::
|
||||
|
||||
While the simplicity of this transaction model is appealing, it also makes it
|
||||
inefficient when traffic increases. Opening a transaction for every request has some
|
||||
overhead. The impact on performance depends on the query patterns of your application
|
||||
and on how well your database handles locking.
|
||||
|
||||
Check the next section for a better solution.
|
||||
|
||||
Tying transactions to mutations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
A mutation can contain multiple fields, just like a query. There's one important
|
||||
distinction between queries and mutations, other than the name:
|
||||
|
||||
..
|
||||
|
||||
`While query fields are executed in parallel, mutation fields run in series, one
|
||||
after the other.`
|
||||
|
||||
This means that if we send two ``incrementCredits`` mutations in one request, the first
|
||||
is guaranteed to finish before the second begins, ensuring that we don't end up with a
|
||||
race condition with ourselves.
|
||||
|
||||
On the other hand, if the first ``incrementCredits`` runs successfully but the second
|
||||
one does not, the operation cannot be retried as it is. That's why is a good idea to
|
||||
run all mutation fields in a transaction, to guarantee all occur or nothing occurs.
|
||||
|
||||
To enable this behavior for all databases set the graphene ``ATOMIC_MUTATIONS`` settings
|
||||
to ``True`` in your settings file:
|
||||
|
||||
.. code:: python
|
||||
|
||||
GRAPHENE = {
|
||||
# ...
|
||||
"ATOMIC_MUTATIONS": True,
|
||||
}
|
||||
|
||||
On the contrary, if you want to enable this behavior for a specific database, set
|
||||
``ATOMIC_MUTATIONS`` to ``True`` in your database settings:
|
||||
|
||||
.. code:: python
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
# ...
|
||||
"ATOMIC_MUTATIONS": True,
|
||||
},
|
||||
# ...
|
||||
}
|
||||
|
||||
Now, given the following example mutation:
|
||||
|
||||
.. code::
|
||||
|
||||
mutation IncreaseCreditsTwice {
|
||||
|
||||
increaseCredits1: increaseCredits(input: { amount: 10 }) {
|
||||
balance
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
increaseCredits2: increaseCredits(input: { amount: -1 }) {
|
||||
balance
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
The server is going to return something like:
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"data": {
|
||||
"increaseCredits1": {
|
||||
"balance": 10.0,
|
||||
"errors": []
|
||||
},
|
||||
"increaseCredits2": {
|
||||
"balance": null,
|
||||
"errors": [
|
||||
{
|
||||
"field": "amount",
|
||||
"message": "Amount should be a positive number"
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
But the balance will remain the same.
|
||||
|
|
|
@ -287,7 +287,7 @@ Where "foo" is the name of the field declared in the ``Query`` object.
|
|||
class Query(graphene.ObjectType):
|
||||
foo = graphene.List(QuestionType)
|
||||
|
||||
def resolve_foo(root, info):
|
||||
def resolve_foo(root, info, **kwargs):
|
||||
id = kwargs.get("id")
|
||||
return Question.objects.get(id)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from .fields import DjangoConnectionField, DjangoListField
|
||||
from .types import DjangoObjectType
|
||||
|
||||
__version__ = "2.13.0"
|
||||
__version__ = "2.15.0"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
|
|
1
graphene_django/constants.py
Normal file
1
graphene_django/constants.py
Normal file
|
@ -0,0 +1 @@
|
|||
MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors"
|
|
@ -21,6 +21,7 @@ from graphene import (
|
|||
NonNull,
|
||||
String,
|
||||
Time,
|
||||
Decimal,
|
||||
)
|
||||
from graphene.types.resolver import get_default_resolver
|
||||
from graphene.types.json import JSONString
|
||||
|
@ -185,6 +186,10 @@ def convert_field_to_boolean(field, registry=None):
|
|||
|
||||
|
||||
@convert_django_field.register(models.DecimalField)
|
||||
def convert_field_to_decimal(field, registry=None):
|
||||
return Decimal(description=field.help_text, required=not field.null)
|
||||
|
||||
|
||||
@convert_django_field.register(models.FloatField)
|
||||
@convert_django_field.register(models.DurationField)
|
||||
def convert_field_to_float(field, registry=None):
|
||||
|
|
|
@ -43,16 +43,16 @@ class DjangoListField(Field):
|
|||
def model(self):
|
||||
return self._underlying_type._meta.model
|
||||
|
||||
def get_default_queryset(self):
|
||||
return self.model._default_manager.get_queryset()
|
||||
def get_manager(self):
|
||||
return self.model._default_manager
|
||||
|
||||
@staticmethod
|
||||
def list_resolver(
|
||||
django_object_type, resolver, default_queryset, root, info, **args
|
||||
django_object_type, resolver, default_manager, root, info, **args
|
||||
):
|
||||
queryset = maybe_queryset(resolver(root, info, **args))
|
||||
if queryset is None:
|
||||
queryset = default_queryset
|
||||
queryset = maybe_queryset(default_manager)
|
||||
|
||||
if isinstance(queryset, QuerySet):
|
||||
# Pass queryset to the DjangoObjectType get_queryset method
|
||||
|
@ -66,10 +66,7 @@ class DjangoListField(Field):
|
|||
_type = _type.of_type
|
||||
django_object_type = _type.of_type.of_type
|
||||
return partial(
|
||||
self.list_resolver,
|
||||
django_object_type,
|
||||
parent_resolver,
|
||||
self.get_default_queryset(),
|
||||
self.list_resolver, django_object_type, parent_resolver, self.get_manager(),
|
||||
)
|
||||
|
||||
|
||||
|
@ -147,14 +144,11 @@ class DjangoConnectionField(ConnectionField):
|
|||
|
||||
if isinstance(iterable, QuerySet):
|
||||
list_length = iterable.count()
|
||||
list_slice_length = (
|
||||
min(max_limit, list_length) if max_limit is not None else list_length
|
||||
)
|
||||
else:
|
||||
list_length = len(iterable)
|
||||
list_slice_length = (
|
||||
min(max_limit, list_length) if max_limit is not None else list_length
|
||||
)
|
||||
list_slice_length = (
|
||||
min(max_limit, list_length) if max_limit is not None else list_length
|
||||
)
|
||||
|
||||
# If after is higher than list_length, connection_from_list_slice
|
||||
# would try to do a negative slicing which makes django throw an
|
||||
|
@ -162,7 +156,11 @@ class DjangoConnectionField(ConnectionField):
|
|||
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
|
||||
|
||||
if max_limit is not None and "first" not in args:
|
||||
args["first"] = max_limit
|
||||
if "last" in args:
|
||||
args["first"] = list_length
|
||||
list_slice_length = list_length
|
||||
else:
|
||||
args["first"] = max_limit
|
||||
|
||||
connection = connection_from_list_slice(
|
||||
iterable[after:],
|
||||
|
|
|
@ -9,7 +9,7 @@ if not DJANGO_FILTER_INSTALLED:
|
|||
)
|
||||
else:
|
||||
from .fields import DjangoFilterConnectionField
|
||||
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
|
||||
__all__ = [
|
||||
"DjangoFilterConnectionField",
|
||||
|
|
|
@ -3,6 +3,7 @@ from functools import partial
|
|||
|
||||
from django.core.exceptions import ValidationError
|
||||
from graphene.types.argument import to_arguments
|
||||
from graphene.utils.str_converters import to_snake_case
|
||||
from ..fields import DjangoConnectionField
|
||||
from .utils import get_filtering_args_from_filterset, get_filterset_class
|
||||
|
||||
|
@ -21,6 +22,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
self._fields = fields
|
||||
self._provided_filterset_class = filterset_class
|
||||
self._filterset_class = None
|
||||
self._filtering_args = None
|
||||
self._extra_filter_meta = extra_filter_meta
|
||||
self._base_args = None
|
||||
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
|
||||
|
@ -50,18 +52,31 @@ class DjangoFilterConnectionField(DjangoConnectionField):
|
|||
|
||||
@property
|
||||
def filtering_args(self):
|
||||
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
|
||||
if not self._filtering_args:
|
||||
self._filtering_args = get_filtering_args_from_filterset(
|
||||
self.filterset_class, self.node_type
|
||||
)
|
||||
return self._filtering_args
|
||||
|
||||
@classmethod
|
||||
def resolve_queryset(
|
||||
cls, connection, iterable, info, args, filtering_args, filterset_class
|
||||
):
|
||||
def filter_kwargs():
|
||||
kwargs = {}
|
||||
for k, v in args.items():
|
||||
if k in filtering_args:
|
||||
if k == "order_by":
|
||||
v = to_snake_case(v)
|
||||
kwargs[k] = v
|
||||
return kwargs
|
||||
|
||||
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
|
||||
connection, iterable, info, args
|
||||
)
|
||||
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
|
||||
|
||||
filterset = filterset_class(
|
||||
data=filter_kwargs, queryset=qs, request=info.context
|
||||
data=filter_kwargs(), queryset=qs, request=info.context
|
||||
)
|
||||
if filterset.form.is_valid():
|
||||
return filterset.qs
|
||||
|
|
75
graphene_django/filter/filters.py
Normal file
75
graphene_django/filter/filters.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.forms import Field
|
||||
|
||||
from django_filters import Filter, MultipleChoiceFilter
|
||||
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
class GlobalIDFilter(Filter):
|
||||
"""
|
||||
Filter for Relay global ID.
|
||||
"""
|
||||
|
||||
field_class = GlobalIDFormField
|
||||
|
||||
def filter(self, qs, value):
|
||||
""" Convert the filter value to a primary key before filtering """
|
||||
_id = None
|
||||
if value is not None:
|
||||
_, _id = from_global_id(value)
|
||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||
field_class = GlobalIDMultipleChoiceField
|
||||
|
||||
def filter(self, qs, value):
|
||||
gids = [from_global_id(v)[1] for v in value]
|
||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
||||
|
||||
|
||||
class InFilter(Filter):
|
||||
"""
|
||||
Filter for a list of value using the `__in` Django filter.
|
||||
"""
|
||||
|
||||
def filter(self, qs, value):
|
||||
"""
|
||||
Override the default filter class to check first weather the list is
|
||||
empty or not.
|
||||
This needs to be done as in this case we expect to get an empty output
|
||||
(if not an exclude filter) but django_filter consider an empty list
|
||||
to be an empty input value (see `EMPTY_VALUES`) meaning that
|
||||
the filter does not need to be applied (hence returning the original
|
||||
queryset).
|
||||
"""
|
||||
if value is not None and len(value) == 0:
|
||||
if self.exclude:
|
||||
return qs
|
||||
else:
|
||||
return qs.none()
|
||||
else:
|
||||
return super(InFilter, self).filter(qs, value)
|
||||
|
||||
|
||||
def validate_range(value):
|
||||
"""
|
||||
Validator for range filter input: the list of value must be of length 2.
|
||||
Note that validators are only run if the value is not empty.
|
||||
"""
|
||||
if len(value) != 2:
|
||||
raise ValidationError(
|
||||
"Invalid range specified: it needs to contain 2 values.", code="invalid"
|
||||
)
|
||||
|
||||
|
||||
class RangeField(Field):
|
||||
default_validators = [validate_range]
|
||||
empty_values = [None]
|
||||
|
||||
|
||||
class RangeFilter(Filter):
|
||||
field_class = RangeField
|
|
@ -1,32 +1,11 @@
|
|||
import itertools
|
||||
|
||||
from django.db import models
|
||||
from django_filters import Filter, MultipleChoiceFilter, VERSION
|
||||
from django_filters import VERSION
|
||||
from django_filters.filterset import BaseFilterSet, FilterSet
|
||||
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
|
||||
|
||||
from graphql_relay.node.node import from_global_id
|
||||
|
||||
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
|
||||
|
||||
|
||||
class GlobalIDFilter(Filter):
|
||||
field_class = GlobalIDFormField
|
||||
|
||||
def filter(self, qs, value):
|
||||
""" Convert the filter value to a primary key before filtering """
|
||||
_id = None
|
||||
if value is not None:
|
||||
_, _id = from_global_id(value)
|
||||
return super(GlobalIDFilter, self).filter(qs, _id)
|
||||
|
||||
|
||||
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
|
||||
field_class = GlobalIDMultipleChoiceField
|
||||
|
||||
def filter(self, qs, value):
|
||||
gids = [from_global_id(v)[1] for v in value]
|
||||
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
|
||||
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
|
||||
|
||||
|
||||
GRAPHENE_FILTER_SET_OVERRIDES = {
|
||||
|
|
|
@ -21,7 +21,7 @@ class ReporterFilter(django_filters.FilterSet):
|
|||
model = Reporter
|
||||
fields = ["first_name", "last_name", "email", "pets"]
|
||||
|
||||
order_by = OrderingFilter(fields=("pub_date",))
|
||||
order_by = OrderingFilter(fields=("first_name",))
|
||||
|
||||
|
||||
class PetFilter(django_filters.FilterSet):
|
||||
|
|
|
@ -713,6 +713,73 @@ def test_should_query_filter_node_limit():
|
|||
assert result.data == expected
|
||||
|
||||
|
||||
def test_order_by():
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(ObjectType):
|
||||
all_reporters = DjangoFilterConnectionField(
|
||||
ReporterType, filterset_class=ReporterFilter
|
||||
)
|
||||
|
||||
Reporter.objects.create(first_name="b")
|
||||
Reporter.objects.create(first_name="a")
|
||||
|
||||
schema = Schema(query=Query)
|
||||
query = """
|
||||
query NodeFilteringQuery {
|
||||
allReporters(orderBy: "-firstName") {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected = {
|
||||
"allReporters": {
|
||||
"edges": [{"node": {"firstName": "b"}}, {"node": {"firstName": "a"}}]
|
||||
}
|
||||
}
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
query = """
|
||||
query NodeFilteringQuery {
|
||||
allReporters(orderBy: "-first_name") {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
query = """
|
||||
query NodeFilteringQuery {
|
||||
allReporters(orderBy: "-firtsnaMe") {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert result.errors
|
||||
|
||||
|
||||
def test_order_by_is_perserved():
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
|
|
183
graphene_django/filter/tests/test_in_filter.py
Normal file
183
graphene_django/filter/tests/test_in_filter.py
Normal file
|
@ -0,0 +1,183 @@
|
|||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
from django_filters import rest_framework as filters
|
||||
from graphene import ObjectType, Schema
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Pet, Person
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PetNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
interfaces = (Node,)
|
||||
filter_fields = {
|
||||
"name": ["exact", "in"],
|
||||
"age": ["exact", "in", "range"],
|
||||
}
|
||||
|
||||
|
||||
class PersonFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = Person
|
||||
fields = {}
|
||||
|
||||
names = filters.BaseInFilter(method="filter_names")
|
||||
|
||||
def filter_names(self, qs, name, value):
|
||||
return qs.filter(name__in=value)
|
||||
|
||||
|
||||
class PersonNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Person
|
||||
interfaces = (Node,)
|
||||
filterset_class = PersonFilterSet
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
pets = DjangoFilterConnectionField(PetNode)
|
||||
people = DjangoFilterConnectionField(PersonNode)
|
||||
|
||||
|
||||
def test_string_in_filter():
|
||||
"""
|
||||
Test in filter on a string field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=3)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (name_In: ["Brutus", "Jojo, the rabbit"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Brutus"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
|
||||
def test_string_in_filter_with_filterset_class():
|
||||
"""Test in filter on a string field with a custom filterset class."""
|
||||
Person.objects.create(name="John")
|
||||
Person.objects.create(name="Michael")
|
||||
Person.objects.create(name="Angela")
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
people (names: ["John", "Michael"]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["people"]["edges"] == [
|
||||
{"node": {"name": "John"}},
|
||||
{"node": {"name": "Michael"}},
|
||||
]
|
||||
|
||||
|
||||
def test_int_in_filter():
|
||||
"""
|
||||
Test in filter on an integer field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=3)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_In: [3]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_In: [3, 12]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Brutus"}},
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Jojo, the rabbit"}},
|
||||
]
|
||||
|
||||
|
||||
def test_in_filter_with_empty_list():
|
||||
"""
|
||||
Check that using a in filter with an empty list provided as input returns no objects.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (name_In: []) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert len(result.data["pets"]["edges"]) == 0
|
115
graphene_django/filter/tests/test_range_filter.py
Normal file
115
graphene_django/filter/tests/test_range_filter.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import ast
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from django_filters import FilterSet
|
||||
from django_filters import rest_framework as filters
|
||||
from graphene import ObjectType, Schema
|
||||
from graphene.relay import Node
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.models import Pet
|
||||
from graphene_django.utils import DJANGO_FILTER_INSTALLED
|
||||
|
||||
pytestmark = []
|
||||
|
||||
if DJANGO_FILTER_INSTALLED:
|
||||
from graphene_django.filter import DjangoFilterConnectionField
|
||||
else:
|
||||
pytestmark.append(
|
||||
pytest.mark.skipif(
|
||||
True, reason="django_filters not installed or not compatible"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PetNode(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
interfaces = (Node,)
|
||||
filter_fields = {
|
||||
"name": ["exact", "in"],
|
||||
"age": ["exact", "in", "range"],
|
||||
}
|
||||
|
||||
|
||||
class Query(ObjectType):
|
||||
pets = DjangoFilterConnectionField(PetNode)
|
||||
|
||||
|
||||
def test_int_range_filter():
|
||||
"""
|
||||
Test range filter on an integer field.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
pets (age_Range: [4, 9]) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data["pets"]["edges"] == [
|
||||
{"node": {"name": "Mimi"}},
|
||||
{"node": {"name": "Picotin"}},
|
||||
]
|
||||
|
||||
|
||||
def test_range_filter_with_invalid_input():
|
||||
"""
|
||||
Test range filter used with invalid inputs raise an error.
|
||||
"""
|
||||
Pet.objects.create(name="Brutus", age=12)
|
||||
Pet.objects.create(name="Mimi", age=8)
|
||||
Pet.objects.create(name="Jojo, the rabbit", age=3)
|
||||
Pet.objects.create(name="Picotin", age=5)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query ($rangeValue: [Int]) {
|
||||
pets (age_Range: $rangeValue) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
expected_error = json.dumps(
|
||||
{
|
||||
"age__range": [
|
||||
{
|
||||
"message": "Invalid range specified: it needs to contain 2 values.",
|
||||
"code": "invalid",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Empty list
|
||||
result = schema.execute(query, variables={"rangeValue": []})
|
||||
assert len(result.errors) == 1
|
||||
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
||||
|
||||
# Only one item in the list
|
||||
result = schema.execute(query, variables={"rangeValue": [1]})
|
||||
assert len(result.errors) == 1
|
||||
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
||||
|
||||
# More than 2 items in the list
|
||||
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
|
||||
assert len(result.errors) == 1
|
||||
assert ast.literal_eval(result.errors[0].message)[0] == expected_error
|
|
@ -1,7 +1,12 @@
|
|||
import six
|
||||
|
||||
from graphene import List
|
||||
|
||||
from django_filters.utils import get_model_field
|
||||
from django_filters.filters import Filter, BaseCSVFilter
|
||||
|
||||
from .filterset import custom_filterset_factory, setup_filterset
|
||||
from .filters import InFilter, RangeFilter
|
||||
|
||||
|
||||
def get_filtering_args_from_filterset(filterset_class, type):
|
||||
|
@ -15,23 +20,34 @@ def get_filtering_args_from_filterset(filterset_class, type):
|
|||
model = filterset_class._meta.model
|
||||
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
||||
form_field = None
|
||||
filter_type = filter_field.lookup_expr
|
||||
|
||||
if name in filterset_class.declared_filters:
|
||||
# Get the filter field from the explicitly declared filter
|
||||
form_field = filter_field.field
|
||||
field = convert_form_field(form_field)
|
||||
else:
|
||||
# Get the filter field with no explicit type declaration
|
||||
model_field = get_model_field(model, filter_field.field_name)
|
||||
filter_type = filter_field.lookup_expr
|
||||
if filter_type != "isnull" and hasattr(model_field, "formfield"):
|
||||
form_field = model_field.formfield(
|
||||
required=filter_field.extra.get("required", False)
|
||||
)
|
||||
|
||||
# Fallback to field defined on filter if we can't get it from the
|
||||
# model field
|
||||
if not form_field:
|
||||
form_field = filter_field.field
|
||||
# Fallback to field defined on filter if we can't get it from the
|
||||
# model field
|
||||
if not form_field:
|
||||
form_field = filter_field.field
|
||||
|
||||
field_type = convert_form_field(form_field).Argument()
|
||||
field = convert_form_field(form_field)
|
||||
|
||||
if filter_type in ["in", "range"]:
|
||||
# Replace CSV filters (`in`, `range`) argument type to be a list of
|
||||
# the same type as the field. See comments in
|
||||
# `replace_csv_filters` method for more details.
|
||||
field = List(field.get_type())
|
||||
|
||||
field_type = field.Argument()
|
||||
field_type.description = filter_field.label
|
||||
args[name] = field_type
|
||||
|
||||
|
@ -39,9 +55,50 @@ def get_filtering_args_from_filterset(filterset_class, type):
|
|||
|
||||
|
||||
def get_filterset_class(filterset_class, **meta):
|
||||
"""Get the class to be used as the FilterSet"""
|
||||
"""
|
||||
Get the class to be used as the FilterSet.
|
||||
"""
|
||||
if filterset_class:
|
||||
# If were given a FilterSet class, then set it up and
|
||||
# return it
|
||||
return setup_filterset(filterset_class)
|
||||
return custom_filterset_factory(**meta)
|
||||
# If were given a FilterSet class, then set it up.
|
||||
graphene_filterset_class = setup_filterset(filterset_class)
|
||||
else:
|
||||
# Otherwise create one.
|
||||
graphene_filterset_class = custom_filterset_factory(**meta)
|
||||
|
||||
replace_csv_filters(graphene_filterset_class)
|
||||
return graphene_filterset_class
|
||||
|
||||
|
||||
def replace_csv_filters(filterset_class):
|
||||
"""
|
||||
Replace the "in" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
|
||||
but regular Filter objects that simply use the input value as filter argument on the queryset.
|
||||
|
||||
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
|
||||
can actually have a list as input and have a proper type verification of each value in the list.
|
||||
|
||||
See issue https://github.com/graphql-python/graphene-django/issues/1068.
|
||||
"""
|
||||
for name, filter_field in six.iteritems(filterset_class.base_filters):
|
||||
filter_type = filter_field.lookup_expr
|
||||
if filter_type == "in":
|
||||
assert isinstance(filter_field, BaseCSVFilter)
|
||||
filterset_class.base_filters[name] = InFilter(
|
||||
field_name=filter_field.field_name,
|
||||
lookup_expr=filter_field.lookup_expr,
|
||||
label=filter_field.label,
|
||||
method=filter_field.method,
|
||||
exclude=filter_field.exclude,
|
||||
**filter_field.extra
|
||||
)
|
||||
|
||||
if filter_type == "range":
|
||||
assert isinstance(filter_field, BaseCSVFilter)
|
||||
filterset_class.base_filters[name] = RangeFilter(
|
||||
field_name=filter_field.field_name,
|
||||
lookup_expr=filter_field.lookup_expr,
|
||||
label=filter_field.label,
|
||||
method=filter_field.method,
|
||||
exclude=filter_field.exclude,
|
||||
**filter_field.extra
|
||||
)
|
||||
|
|
|
@ -63,6 +63,11 @@ def convert_form_field_to_list(field):
|
|||
return List(ID, required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.MultipleChoiceField)
|
||||
def convert_form_field_to_string_list(field):
|
||||
return List(String, required=field.required)
|
||||
|
||||
|
||||
@convert_form_field.register(forms.DateField)
|
||||
def convert_form_field_to_date(field):
|
||||
return Date(description=field.help_text, required=field.required)
|
||||
|
|
|
@ -11,8 +11,13 @@ from graphene.types.mutation import MutationOptions
|
|||
# InputObjectType,
|
||||
# )
|
||||
from graphene.types.utils import yank_fields_from_attrs
|
||||
from graphene_django.constants import MUTATION_ERRORS_FLAG
|
||||
from graphene_django.registry import get_global_registry
|
||||
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import connection
|
||||
|
||||
from ..types import ErrorType
|
||||
from .converter import convert_form_field
|
||||
|
||||
|
@ -46,6 +51,7 @@ class BaseDjangoFormMutation(ClientIDMutation):
|
|||
return cls.perform_mutate(form, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(form.errors)
|
||||
_set_errors_flag_to_context(info)
|
||||
|
||||
return cls(errors=errors, **form.data)
|
||||
|
||||
|
@ -170,6 +176,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
|
|||
return cls.perform_mutate(form, info)
|
||||
else:
|
||||
errors = ErrorType.from_errors(form.errors)
|
||||
_set_errors_flag_to_context(info)
|
||||
|
||||
return cls(errors=errors)
|
||||
|
||||
|
@ -178,3 +185,9 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
|
|||
obj = form.save()
|
||||
kwargs = {cls._meta.return_field_name: obj}
|
||||
return cls(errors=[], **kwargs)
|
||||
|
||||
|
||||
def _set_errors_flag_to_context(info):
|
||||
# This is not ideal but necessary to keep the response errors empty
|
||||
if info and info.context:
|
||||
setattr(info.context, MUTATION_ERRORS_FLAG, True)
|
||||
|
|
|
@ -101,7 +101,14 @@ def test_should_decimal_convert_float():
|
|||
assert_conversion(forms.DecimalField, Float)
|
||||
|
||||
|
||||
def test_should_multiple_choice_convert_connectionorlist():
|
||||
def test_should_multiple_choice_convert_list():
|
||||
field = forms.MultipleChoiceField()
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, List)
|
||||
assert graphene_type.of_type == String
|
||||
|
||||
|
||||
def test_should_model_multiple_choice_convert_connectionorlist():
|
||||
field = forms.ModelMultipleChoiceField(queryset=None)
|
||||
graphene_type = convert_form_field(field)
|
||||
assert isinstance(graphene_type, List)
|
||||
|
|
|
@ -5,21 +5,13 @@ from py.test import raises
|
|||
|
||||
from graphene import Field, ObjectType, Schema, String
|
||||
from graphene_django import DjangoObjectType
|
||||
from graphene_django.tests.forms import PetForm
|
||||
from graphene_django.tests.models import Pet
|
||||
from graphene_django.tests.mutations import PetMutation
|
||||
|
||||
from ..mutation import DjangoFormMutation, DjangoModelFormMutation
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def pet_type():
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
return PetType
|
||||
|
||||
|
||||
class MyForm(forms.Form):
|
||||
text = forms.CharField()
|
||||
|
||||
|
@ -33,18 +25,6 @@ class MyForm(forms.Form):
|
|||
pass
|
||||
|
||||
|
||||
class PetForm(forms.ModelForm):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
def clean_age(self):
|
||||
age = self.cleaned_data["age"]
|
||||
if age >= 99:
|
||||
raise ValidationError("Too old")
|
||||
return age
|
||||
|
||||
|
||||
def test_needs_form_class():
|
||||
with raises(Exception) as exc:
|
||||
|
||||
|
@ -70,11 +50,18 @@ def test_has_input_fields():
|
|||
assert "text" in MyMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_mutation_error_camelcased(pet_type, graphene_settings):
|
||||
def test_mutation_error_camelcased(graphene_settings):
|
||||
class ExtraPetForm(PetForm):
|
||||
test_field = forms.CharField(required=True)
|
||||
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = ExtraPetForm
|
||||
|
||||
|
@ -146,21 +133,13 @@ def test_form_valid_input():
|
|||
assert result.data["myMutation"]["text"] == "VALID_INPUT"
|
||||
|
||||
|
||||
def test_default_meta_fields(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_default_meta_fields():
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "pet"
|
||||
assert "pet" in PetMutation._meta.fields
|
||||
|
||||
|
||||
def test_default_input_meta_fields(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_default_input_meta_fields():
|
||||
assert PetMutation._meta.model is Pet
|
||||
assert PetMutation._meta.return_field_name == "pet"
|
||||
assert "name" in PetMutation.Input._meta.fields
|
||||
|
@ -168,8 +147,15 @@ def test_default_input_meta_fields(pet_type):
|
|||
assert "id" in PetMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_exclude_fields_input_meta_fields(pet_type):
|
||||
def test_exclude_fields_input_meta_fields():
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
exclude_fields = ["id"]
|
||||
|
@ -182,8 +168,15 @@ def test_exclude_fields_input_meta_fields(pet_type):
|
|||
assert "id" not in PetMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_custom_return_field_name(pet_type):
|
||||
def test_custom_return_field_name():
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
model = Pet
|
||||
|
@ -194,13 +187,7 @@ def test_custom_return_field_name(pet_type):
|
|||
assert "animal" in PetMutation._meta.fields
|
||||
|
||||
|
||||
def test_model_form_mutation_mutate_existing(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(pet_type)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_model_form_mutation_mutate_existing():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
|
@ -229,13 +216,7 @@ def test_model_form_mutation_mutate_existing(pet_type):
|
|||
assert pet.name == "Mia"
|
||||
|
||||
|
||||
def test_model_form_mutation_creates_new(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(pet_type)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_model_form_mutation_creates_new():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
|
@ -265,13 +246,7 @@ def test_model_form_mutation_creates_new(pet_type):
|
|||
assert pet.age == 10
|
||||
|
||||
|
||||
def test_model_form_mutation_invalid_input(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(pet_type)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_model_form_mutation_invalid_input():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
|
@ -301,11 +276,7 @@ def test_model_form_mutation_invalid_input(pet_type):
|
|||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
def test_model_form_mutation_mutate_invalid_form(pet_type):
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
def test_model_form_mutation_mutate_invalid_form():
|
||||
result = PetMutation.mutate_and_get_payload(None, None)
|
||||
|
||||
# A pet was not created
|
||||
|
@ -317,3 +288,98 @@ def test_model_form_mutation_mutate_invalid_form(pet_type):
|
|||
assert result.errors[1].messages == ["This field is required."]
|
||||
assert "age" in fields_w_error
|
||||
assert "name" in fields_w_error
|
||||
|
||||
|
||||
def test_model_form_mutation_multiple_creation_valid():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
"""
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 10 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
assert result.data["petMutation1"]["pet"] == {"name": "Mia", "age": 10}
|
||||
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 2
|
||||
|
||||
pet1 = Pet.objects.first()
|
||||
assert pet1.name == "Mia"
|
||||
assert pet1.age == 10
|
||||
|
||||
pet2 = Pet.objects.last()
|
||||
assert pet2.name == "Enzo"
|
||||
assert pet2.age == 0
|
||||
|
||||
|
||||
def test_model_form_mutation_multiple_creation_invalid():
|
||||
class Mutation(ObjectType):
|
||||
pet_mutation = PetMutation.Field()
|
||||
|
||||
schema = Schema(query=MockQuery, mutation=Mutation)
|
||||
|
||||
result = schema.execute(
|
||||
"""
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
assert result.errors is None
|
||||
|
||||
assert result.data["petMutation1"]["pet"] is None
|
||||
assert result.data["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert result.data["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
||||
|
|
|
@ -60,8 +60,10 @@ class Command(CommandArguments):
|
|||
|
||||
def get_schema(self, schema, out, indent):
|
||||
schema_dict = {"data": schema.introspect()}
|
||||
if out == "-":
|
||||
if out == "-" or out == "-.json":
|
||||
self.stdout.write(json.dumps(schema_dict, indent=indent, sort_keys=True))
|
||||
elif out == "-.graphql":
|
||||
self.stdout.write(print_schema(schema))
|
||||
else:
|
||||
# Determine format
|
||||
_, file_extension = os.path.splitext(out)
|
||||
|
|
|
@ -18,6 +18,7 @@ class SerializerMutationOptions(MutationOptions):
|
|||
model_class = None
|
||||
model_operations = ["create", "update"]
|
||||
serializer_class = None
|
||||
optional_fields = ()
|
||||
|
||||
|
||||
def fields_for_serializer(
|
||||
|
@ -27,6 +28,7 @@ def fields_for_serializer(
|
|||
is_input=False,
|
||||
convert_choices_to_enum=True,
|
||||
lookup_field=None,
|
||||
optional_fields=(),
|
||||
):
|
||||
fields = OrderedDict()
|
||||
for name, field in serializer.fields.items():
|
||||
|
@ -44,9 +46,13 @@ def fields_for_serializer(
|
|||
|
||||
if is_not_in_only or is_excluded:
|
||||
continue
|
||||
is_optional = name in optional_fields
|
||||
|
||||
fields[name] = convert_serializer_field(
|
||||
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
|
||||
field,
|
||||
is_input=is_input,
|
||||
convert_choices_to_enum=convert_choices_to_enum,
|
||||
force_optional=is_optional,
|
||||
)
|
||||
return fields
|
||||
|
||||
|
@ -70,6 +76,7 @@ class SerializerMutation(ClientIDMutation):
|
|||
exclude_fields=(),
|
||||
convert_choices_to_enum=True,
|
||||
_meta=None,
|
||||
optional_fields=(),
|
||||
**options
|
||||
):
|
||||
|
||||
|
@ -95,6 +102,7 @@ class SerializerMutation(ClientIDMutation):
|
|||
is_input=True,
|
||||
convert_choices_to_enum=convert_choices_to_enum,
|
||||
lookup_field=lookup_field,
|
||||
optional_fields=optional_fields,
|
||||
)
|
||||
output_fields = fields_for_serializer(
|
||||
serializer,
|
||||
|
|
|
@ -19,7 +19,9 @@ def get_graphene_type_from_serializer_field(field):
|
|||
)
|
||||
|
||||
|
||||
def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True):
|
||||
def convert_serializer_field(
|
||||
field, is_input=True, convert_choices_to_enum=True, force_optional=False
|
||||
):
|
||||
"""
|
||||
Converts a django rest frameworks field to a graphql field
|
||||
and marks the field as required if we are creating an input type
|
||||
|
@ -32,7 +34,10 @@ def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True)
|
|||
graphql_type = get_graphene_type_from_serializer_field(field)
|
||||
|
||||
args = []
|
||||
kwargs = {"description": field.help_text, "required": is_input and field.required}
|
||||
kwargs = {
|
||||
"description": field.help_text,
|
||||
"required": is_input and field.required and not force_optional,
|
||||
}
|
||||
|
||||
# if it is a tuple or a list it means that we are returning
|
||||
# the graphql type and the child type
|
||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
|||
from py.test import raises
|
||||
from rest_framework import serializers
|
||||
|
||||
from graphene import Field, ResolveInfo
|
||||
from graphene import Field, ResolveInfo, NonNull, String
|
||||
from graphene.types.inputobjecttype import InputObjectType
|
||||
|
||||
from ...types import DjangoObjectType
|
||||
|
@ -98,6 +98,25 @@ def test_exclude_fields():
|
|||
assert "created" not in MyMutation.Input._meta.fields
|
||||
|
||||
|
||||
def test_model_serializer_required_fields():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
|
||||
assert "cool_name" in MyMutation.Input._meta.fields
|
||||
assert MyMutation.Input._meta.fields["cool_name"].type == NonNull(String)
|
||||
|
||||
|
||||
def test_model_serializer_optional_fields():
|
||||
class MyMutation(SerializerMutation):
|
||||
class Meta:
|
||||
serializer_class = MyModelSerializer
|
||||
optional_fields = ("cool_name",)
|
||||
|
||||
assert "cool_name" in MyMutation.Input._meta.fields
|
||||
assert MyMutation.Input._meta.fields["cool_name"].type == String
|
||||
|
||||
|
||||
def test_write_only_field():
|
||||
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
|
||||
password = serializers.CharField(write_only=True)
|
||||
|
|
|
@ -45,6 +45,7 @@ DEFAULTS = {
|
|||
# This sets headerEditorEnabled GraphiQL option, for details go to
|
||||
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
|
||||
"GRAPHIQL_HEADER_EDITOR_ENABLED": True,
|
||||
"ATOMIC_MUTATIONS": False,
|
||||
}
|
||||
|
||||
if settings.DEBUG:
|
||||
|
|
16
graphene_django/tests/forms.py
Normal file
16
graphene_django/tests/forms.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from django import forms
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from .models import Pet
|
||||
|
||||
|
||||
class PetForm(forms.ModelForm):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
||||
|
||||
def clean_age(self):
|
||||
age = self.cleaned_data["age"]
|
||||
if age >= 99:
|
||||
raise ValidationError("Too old")
|
||||
return age
|
|
@ -6,6 +6,10 @@ from django.utils.translation import ugettext_lazy as _
|
|||
CHOICES = ((1, "this"), (2, _("that")))
|
||||
|
||||
|
||||
class Person(models.Model):
|
||||
name = models.CharField(max_length=30)
|
||||
|
||||
|
||||
class Pet(models.Model):
|
||||
name = models.CharField(max_length=30)
|
||||
age = models.PositiveIntegerField()
|
||||
|
|
18
graphene_django/tests/mutations.py
Normal file
18
graphene_django/tests/mutations.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from graphene import Field
|
||||
|
||||
from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation
|
||||
|
||||
from .forms import PetForm
|
||||
from .types import PetType
|
||||
|
||||
|
||||
class PetFormMutation(DjangoFormMutation):
|
||||
class Meta:
|
||||
form_class = PetForm
|
||||
|
||||
|
||||
class PetMutation(DjangoModelFormMutation):
|
||||
pet = Field(PetType)
|
||||
|
||||
class Meta:
|
||||
form_class = PetForm
|
|
@ -1,6 +1,8 @@
|
|||
import graphene
|
||||
from graphene import ObjectType, Schema
|
||||
|
||||
from .mutations import PetFormMutation, PetMutation
|
||||
|
||||
|
||||
class QueryRoot(ObjectType):
|
||||
|
||||
|
@ -19,6 +21,8 @@ class QueryRoot(ObjectType):
|
|||
|
||||
|
||||
class MutationRoot(ObjectType):
|
||||
pet_form_mutation = PetFormMutation.Field()
|
||||
pet_mutation = PetMutation.Field()
|
||||
write_test = graphene.Field(QueryRoot)
|
||||
|
||||
def resolve_write_test(self, info):
|
||||
|
|
|
@ -241,6 +241,10 @@ def test_should_float_convert_float():
|
|||
assert_conversion(models.FloatField, graphene.Float)
|
||||
|
||||
|
||||
def test_should_float_convert_decimal():
|
||||
assert_conversion(models.DecimalField, graphene.Decimal)
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist():
|
||||
registry = Registry()
|
||||
dynamic_field = convert_django_field(Reporter._meta.local_many_to_many[0], registry)
|
||||
|
|
|
@ -75,6 +75,39 @@ class TestDjangoListField:
|
|||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
|
||||
def test_list_field_queryset_is_not_cached(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
model = ReporterModel
|
||||
fields = ("first_name",)
|
||||
|
||||
class Query(ObjectType):
|
||||
reporters = DjangoListField(Reporter)
|
||||
|
||||
schema = Schema(query=Query)
|
||||
|
||||
query = """
|
||||
query {
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == {"reporters": []}
|
||||
|
||||
ReporterModel.objects.create(first_name="Tara", last_name="West")
|
||||
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert not result.errors
|
||||
assert result.data == {
|
||||
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
|
||||
}
|
||||
|
||||
def test_override_resolver(self):
|
||||
class Reporter(DjangoObjectType):
|
||||
class Meta:
|
||||
|
|
|
@ -1213,6 +1213,103 @@ def test_should_have_next_page(graphene_settings):
|
|||
}
|
||||
|
||||
|
||||
class TestBackwardPagination:
|
||||
def setup_schema(self, graphene_settings, max_limit):
|
||||
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
|
||||
reporters = [Reporter(**kwargs) for kwargs in REPORTERS]
|
||||
Reporter.objects.bulk_create(reporters)
|
||||
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
interfaces = (Node,)
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
all_reporters = DjangoConnectionField(ReporterType)
|
||||
|
||||
schema = graphene.Schema(query=Query)
|
||||
return schema
|
||||
|
||||
def do_queries(self, schema):
|
||||
# Simply last 3
|
||||
query_last = """
|
||||
query {
|
||||
allReporters(last: 3) {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query_last)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 3
|
||||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 3", "First 4", "First 5"]
|
||||
|
||||
# Use a combination of first and last
|
||||
query_first_and_last = """
|
||||
query {
|
||||
allReporters(first: 4, last: 3) {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query_first_and_last)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 3
|
||||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 1", "First 2", "First 3"]
|
||||
|
||||
# Use a combination of first and last and after
|
||||
query_first_last_and_after = """
|
||||
query queryAfter($after: String) {
|
||||
allReporters(first: 4, last: 3, after: $after) {
|
||||
edges {
|
||||
node {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
after = base64.b64encode(b"arrayconnection:0").decode()
|
||||
result = schema.execute(
|
||||
query_first_last_and_after, variable_values=dict(after=after)
|
||||
)
|
||||
assert not result.errors
|
||||
assert len(result.data["allReporters"]["edges"]) == 3
|
||||
assert [
|
||||
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
|
||||
] == ["First 2", "First 3", "First 4"]
|
||||
|
||||
def test_should_query(self, graphene_settings):
|
||||
"""
|
||||
Backward pagination should work as expected
|
||||
"""
|
||||
schema = self.setup_schema(graphene_settings, max_limit=100)
|
||||
self.do_queries(schema)
|
||||
|
||||
def test_should_query_with_low_max_limit(self, graphene_settings):
|
||||
"""
|
||||
When doing backward pagination (using last) in combination with a max limit higher than the number of objects
|
||||
we should really retrieve the last ones.
|
||||
"""
|
||||
schema = self.setup_schema(graphene_settings, max_limit=4)
|
||||
self.do_queries(schema)
|
||||
|
||||
|
||||
def test_should_preserve_prefetch_related(django_assert_num_queries):
|
||||
class ReporterType(DjangoObjectType):
|
||||
class Meta:
|
||||
|
|
|
@ -51,7 +51,9 @@ def test_graphql_test_case_op_name(post_mock):
|
|||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
|
||||
tc.query("query { }", op_name="QueryName")
|
||||
body = json.loads(post_mock.call_args.args[1])
|
||||
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
|
||||
|
|
|
@ -2,6 +2,14 @@ import json
|
|||
|
||||
import pytest
|
||||
|
||||
from mock import patch
|
||||
|
||||
from django.db import connection
|
||||
|
||||
from graphene_django.settings import graphene_settings
|
||||
|
||||
from .models import Pet
|
||||
|
||||
try:
|
||||
from urllib import urlencode
|
||||
except ImportError:
|
||||
|
@ -558,3 +566,265 @@ def test_passes_request_into_context_request(client):
|
|||
|
||||
assert response.status_code == 200
|
||||
assert response_json(response) == {"data": {"request": "testing"}}
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_request(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": True, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_mutation_1(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", True)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_atomic_mutation_2(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_form_mutation_multiple_creation_invalid_non_atomic(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petFormMutation1: petFormMutation(input: { name: "Mia", age: 99 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petFormMutation2: petFormMutation(input: { name: "Enzo", age: 0 }) {
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petFormMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petFormMutation2"]["errors"] == []
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_model_form_mutation_multiple_creation_invalid_atomic_request(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petMutation1"]["pet"] is None
|
||||
assert content["data"]["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 0
|
||||
|
||||
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_model_form_mutation_multiple_creation_invalid_non_atomic(client):
|
||||
query = """
|
||||
mutation PetMutations {
|
||||
petMutation1: petMutation(input: { name: "Mia", age: 99 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
petMutation2: petMutation(input: { name: "Enzo", age: 0 }) {
|
||||
pet {
|
||||
name
|
||||
age
|
||||
}
|
||||
errors {
|
||||
field
|
||||
messages
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = client.post(url_string(query=query))
|
||||
content = response_json(response)
|
||||
|
||||
assert "errors" not in content
|
||||
|
||||
assert content["data"]["petMutation1"]["pet"] is None
|
||||
assert content["data"]["petMutation1"]["errors"] == [
|
||||
{"field": "age", "messages": ["Too old"]}
|
||||
]
|
||||
|
||||
assert content["data"]["petMutation2"]["pet"] == {"name": "Enzo", "age": 0}
|
||||
|
||||
assert Pet.objects.count() == 1
|
||||
|
||||
pet = Pet.objects.get()
|
||||
assert pet.name == "Enzo"
|
||||
assert pet.age == 0
|
||||
|
||||
|
||||
@patch("graphene_django.utils.utils.transaction.set_rollback")
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": True}
|
||||
)
|
||||
def test_query_errors_atomic_request(set_rollback_mock, client):
|
||||
client.get(url_string(query="force error"))
|
||||
set_rollback_mock.assert_called_once_with(True)
|
||||
|
||||
|
||||
@patch("graphene_django.utils.utils.transaction.set_rollback")
|
||||
@patch("graphene_django.settings.graphene_settings.ATOMIC_MUTATIONS", False)
|
||||
@patch.dict(
|
||||
connection.settings_dict, {"ATOMIC_MUTATIONS": False, "ATOMIC_REQUESTS": False}
|
||||
)
|
||||
def test_query_errors_non_atomic(set_rollback_mock, client):
|
||||
client.get(url_string(query="force error"))
|
||||
set_rollback_mock.assert_not_called()
|
||||
|
|
9
graphene_django/tests/types.py
Normal file
9
graphene_django/tests/types.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from graphene_django.types import DjangoObjectType
|
||||
|
||||
from .models import Pet
|
||||
|
||||
|
||||
class PetType(DjangoObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
fields = "__all__"
|
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from unidecode import unidecode
|
||||
from text_unidecode import unidecode
|
||||
|
||||
|
||||
def to_const(string):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import warnings
|
||||
|
||||
from django.test import TestCase, Client
|
||||
from django.test import Client, TestCase
|
||||
|
||||
DEFAULT_GRAPHQL_URL = "/graphql/"
|
||||
|
||||
|
@ -68,12 +69,6 @@ class GraphQLTestCase(TestCase):
|
|||
# URL to graphql endpoint
|
||||
GRAPHQL_URL = DEFAULT_GRAPHQL_URL
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(GraphQLTestCase, cls).setUpClass()
|
||||
|
||||
cls._client = Client()
|
||||
|
||||
def query(self, query, op_name=None, input_data=None, variables=None, headers=None):
|
||||
"""
|
||||
Args:
|
||||
|
@ -99,19 +94,41 @@ class GraphQLTestCase(TestCase):
|
|||
input_data=input_data,
|
||||
variables=variables,
|
||||
headers=headers,
|
||||
client=self._client,
|
||||
client=self.client,
|
||||
graphql_url=self.GRAPHQL_URL,
|
||||
)
|
||||
|
||||
@property
|
||||
def _client(self):
|
||||
pass
|
||||
|
||||
@_client.getter
|
||||
def _client(self):
|
||||
warnings.warn(
|
||||
"Using `_client` is deprecated in favour of `client`.",
|
||||
PendingDeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.client
|
||||
|
||||
@_client.setter
|
||||
def _client(self, client):
|
||||
warnings.warn(
|
||||
"Using `_client` is deprecated in favour of `client`.",
|
||||
PendingDeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.client = client
|
||||
|
||||
def assertResponseNoErrors(self, resp, msg=None):
|
||||
"""
|
||||
Assert that the call went through correctly. 200 means the syntax is ok, if there are no `errors`,
|
||||
the call was fine.
|
||||
:resp HttpResponse: Response
|
||||
"""
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
content = json.loads(resp.content)
|
||||
self.assertNotIn("errors", list(content.keys()), msg)
|
||||
self.assertEqual(resp.status_code, 200, msg or content)
|
||||
self.assertNotIn("errors", list(content.keys()), msg or content)
|
||||
|
||||
def assertResponseHasErrors(self, resp, msg=None):
|
||||
"""
|
||||
|
@ -119,4 +136,4 @@ class GraphQLTestCase(TestCase):
|
|||
:resp HttpResponse: Response
|
||||
"""
|
||||
content = json.loads(resp.content)
|
||||
self.assertIn("errors", list(content.keys()), msg)
|
||||
self.assertIn("errors", list(content.keys()), msg or content)
|
||||
|
|
45
graphene_django/utils/tests/test_testing.py
Normal file
45
graphene_django/utils/tests/test_testing.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
|
||||
from .. import GraphQLTestCase
|
||||
from ...tests.test_types import with_local_registry
|
||||
from django.test import Client
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_graphql_test_case_deprecated_client_getter():
|
||||
"""
|
||||
`GraphQLTestCase._client`' getter should raise pending deprecation warning.
|
||||
"""
|
||||
|
||||
class TestClass(GraphQLTestCase):
|
||||
GRAPHQL_SCHEMA = True
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
|
||||
with pytest.warns(PendingDeprecationWarning):
|
||||
tc._client
|
||||
|
||||
|
||||
@with_local_registry
|
||||
def test_graphql_test_case_deprecated_client_setter():
|
||||
"""
|
||||
`GraphQLTestCase._client`' setter should raise pending deprecation warning.
|
||||
"""
|
||||
|
||||
class TestClass(GraphQLTestCase):
|
||||
GRAPHQL_SCHEMA = True
|
||||
|
||||
def runTest(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
tc._pre_setup()
|
||||
tc.setUpClass()
|
||||
|
||||
with pytest.warns(PendingDeprecationWarning):
|
||||
tc._client = Client()
|
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
|
||||
import six
|
||||
from django.db import models
|
||||
from django.db import connection, models, transaction
|
||||
from django.db.models.manager import Manager
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.functional import Promise
|
||||
|
@ -100,3 +100,9 @@ def import_single_dispatch():
|
|||
)
|
||||
|
||||
return singledispatch
|
||||
|
||||
|
||||
def set_rollback():
|
||||
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
|
||||
if atomic_requests and connection.in_atomic_block:
|
||||
transaction.set_rollback(True)
|
||||
|
|
|
@ -3,6 +3,7 @@ import json
|
|||
import re
|
||||
|
||||
import six
|
||||
from django.db import connection, transaction
|
||||
from django.http import HttpResponse, HttpResponseNotAllowed
|
||||
from django.http.response import HttpResponseBadRequest
|
||||
from django.shortcuts import render
|
||||
|
@ -17,6 +18,9 @@ from graphql.execution import ExecutionResult
|
|||
from graphql.type.schema import GraphQLSchema
|
||||
from graphql.execution.middleware import MiddlewareManager
|
||||
|
||||
from graphene_django.constants import MUTATION_ERRORS_FLAG
|
||||
from graphene_django.utils.utils import set_rollback
|
||||
|
||||
from .settings import graphene_settings
|
||||
|
||||
|
||||
|
@ -203,11 +207,15 @@ class GraphQLView(View):
|
|||
request, data, query, variables, operation_name, show_graphiql
|
||||
)
|
||||
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
set_rollback()
|
||||
|
||||
status_code = 200
|
||||
if execution_result:
|
||||
response = {}
|
||||
|
||||
if execution_result.errors:
|
||||
set_rollback()
|
||||
response["errors"] = [
|
||||
self.format_error(e) for e in execution_result.errors
|
||||
]
|
||||
|
@ -312,14 +320,27 @@ class GraphQLView(View):
|
|||
# executor is not a valid argument in all backends
|
||||
extra_options["executor"] = self.executor
|
||||
|
||||
return document.execute(
|
||||
root_value=self.get_root_value(request),
|
||||
variable_values=variables,
|
||||
operation_name=operation_name,
|
||||
context_value=self.get_context(request),
|
||||
middleware=self.get_middleware(request),
|
||||
**extra_options
|
||||
)
|
||||
options = {
|
||||
"root_value": self.get_root_value(request),
|
||||
"variable_values": variables,
|
||||
"operation_name": operation_name,
|
||||
"context_value": self.get_context(request),
|
||||
"middleware": self.get_middleware(request),
|
||||
}
|
||||
options.update(extra_options)
|
||||
|
||||
operation_type = document.get_operation_type(operation_name)
|
||||
if operation_type == "mutation" and (
|
||||
graphene_settings.ATOMIC_MUTATIONS is True
|
||||
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
|
||||
):
|
||||
with transaction.atomic():
|
||||
result = document.execute(**options)
|
||||
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
|
||||
transaction.set_rollback(True)
|
||||
return result
|
||||
|
||||
return document.execute(**options)
|
||||
except Exception as e:
|
||||
return ExecutionResult(errors=[e], invalid=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user