Merge branch 'main' into fix/default-graphql-url

This commit is contained in:
Firas K 2022-09-24 16:04:56 +03:00 committed by GitHub
commit 7fe12cfb58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 2371 additions and 463 deletions

View File

@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.8
uses: actions/setup-python@v1
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.9
- name: Build wheel and source tarball
run: |
pip install wheel

View File

@ -7,11 +7,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.8
uses: actions/setup-python@v1
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@ -8,13 +8,15 @@ jobs:
strategy:
max-parallel: 4
matrix:
django: ["2.2", "3.0", "3.1"]
python-version: ["3.6", "3.7", "3.8"]
django: ["3.2", "4.0", "4.1"]
python-version: ["3.8", "3.9", "3.10"]
include:
- django: "3.2"
python-version: "3.7"
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies

View File

@ -1,22 +1,21 @@
.PHONY: help
help:
@echo "Please use \`make <target>' where <target> is one of"
@grep -E '^\.PHONY: [a-zA-Z_-]+ .*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = "(: |##)"}; {printf "\033[36m%-30s\033[0m %s\n", $$2, $$3}'
.PHONY: dev-setup ## Install development dependencies
dev-setup:
pip install -e ".[dev]"
.PHONY: install-dev
install-dev: dev-setup # Alias install-dev -> dev-setup
.PHONY: tests
.PHONY: tests ## Run unit tests
tests:
py.test graphene_django --cov=graphene_django -vv
.PHONY: test
test: tests # Alias test -> tests
.PHONY: format
.PHONY: format ## Format code
format:
black --exclude "/migrations/" graphene_django examples setup.py
.PHONY: lint
.PHONY: lint ## Lint code
lint:
flake8 graphene_django examples

View File

@ -55,7 +55,7 @@ from graphene_django.views import GraphQLView
urlpatterns = [
# ...
path('graphql', GraphQLView.as_view(graphiql=True)),
path('graphql/', GraphQLView.as_view(graphiql=True)),
]
```

View File

@ -48,6 +48,31 @@ conversely you can use ``exclude`` meta attribute.
exclude = ('published', 'owner')
interfaces = (relay.Node, )
Another pattern is to have a resolve method act as a gatekeeper, returning None
or raising an exception if the client isn't allowed to see the data.
.. code:: python
from graphene import relay
from graphene_django.types import DjangoObjectType
from .models import Post
class PostNode(DjangoObjectType):
class Meta:
model = Post
fields = ('title', 'content', 'owner')
interfaces = (relay.Node, )
def resolve_owner(self, info):
user = info.context.user
if user.is_anonymous:
raise PermissionDenied("Please login")
if not user.is_staff:
return None
return self.owner
Queryset Filtering On Lists
---------------------------
@ -173,7 +198,7 @@ For Django 2.2 and above:
urlpatterns = [
# some other urls
path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
path('graphql/', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
]
.. _LoginRequiredMixin: https://docs.djangoproject.com/en/dev/topics/auth/default/#the-loginrequired-mixin

View File

@ -4,7 +4,7 @@ Django Debug Middleware
You can debug your GraphQL queries in a similar way to
`django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__,
but outputting in the results in GraphQL response as fields, instead of
the graphical HTML interface.
the graphical HTML interface. Exceptions with their stack traces are also exposed.
For that, you will need to add the plugin in your graphene schema.
@ -63,6 +63,10 @@ the GraphQL request, like:
sql {
rawSql
}
exceptions {
message
stack
}
}
}

View File

@ -80,4 +80,6 @@ published and have a title:
DjangoConnectionField
---------------------
*TODO*
``DjangoConnectionField`` acts similarly to ``DjangoListField`` but returns a
paginated connection following the `relay spec <https://relay.dev/graphql/connections.htm>`__
The field supports the following arguments: `first`, `last`, `offset`, `after` & `before`.

View File

@ -2,8 +2,8 @@ Filtering
=========
Graphene integrates with
`django-filter <https://django-filter.readthedocs.io/en/master/>`__ to provide filtering of results.
See the `usage documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
`django-filter <https://django-filter.readthedocs.io/en/main/>`__ to provide filtering of results.
See the `usage documentation <https://django-filter.readthedocs.io/en/main/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``.
@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows:
# You'll need to install django-filter
pip install django-filter>=2
After installing ``django-filter`` you'll need to add the application in the ``settings.py`` file:
.. code:: python
@ -26,7 +26,7 @@ After installing ``django-filter`` you'll need to add the application in the ``s
]
Note: The techniques below are demoed in the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__.
app <https://github.com/graphql-python/graphene-django/tree/main/examples/cookbook>`__.
Filterable fields
-----------------
@ -34,7 +34,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
documentation <https://django-filter.readthedocs.io/en/main/guide/usage.html#the-filter>`__
for full details on the range of options available.
For example:
@ -192,7 +192,7 @@ in unison with the ``filter_fields`` parameter:
all_animals = DjangoFilterConnectionField(AnimalNode)
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/master/guide/usage.html#request-based-filtering>`__
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/main/guide/usage.html#request-based-filtering>`__
in a ``django_filters.FilterSet`` instance. You can use this to customize your
filters to be context-dependent. We could modify the ``AnimalFilter`` above to
pre-filter animals owned by the authenticated user (set in ``context.user``).
@ -258,3 +258,86 @@ with this set up, you can now order the users under group:
}
}
}
PostgreSQL `ArrayField`
-----------------------
Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters:
.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
from graphene_django.filter import ArrayFilter
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
with this set up, you can now filter events by tags:
.. code::
query {
events(tags_Overlap: ["concert", "festival"]) {
name
}
}
`TypedFilter`
-------------
Sometimes the automatic detection of the filter input type is not satisfactory for what you are trying to achieve.
You can then explicitly specify the input type you want for your filter by using a `TypedFilter`:
.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
import graphene
from graphene_django.filter import TypedFilter
class Event(models.Model):
name = models.CharField(max_length=50)
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
only_first = TypedFilter(input_type=graphene.Boolean, method="only_first_filter")
def only_first_filter(self, queryset, _name, value):
if value:
return queryset[:1]
else:
return queryset
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet

View File

@ -151,7 +151,7 @@ For example the following ``Model`` and ``DjangoObjectType``:
Results in the following GraphQL schema definition:
.. code::
.. code:: graphql
type Pet {
id: ID!
@ -178,7 +178,7 @@ You can disable this automatic conversion by setting
fields = ("id", "kind",)
convert_choices_to_enum = False
.. code::
.. code:: graphql
type Pet {
id: ID!
@ -313,7 +313,7 @@ Additionally, Resolvers will receive **any arguments declared in the field defin
bar=graphene.Int()
)
def resolve_question(root, info, foo, bar):
def resolve_question(root, info, foo=None, bar=None):
# If `foo` or `bar` are declared in the GraphQL query they will be here, else None.
return Question.objects.filter(foo=foo, bar=bar).first()
@ -336,12 +336,12 @@ of Django's ``HTTPRequest`` in your resolve methods, such as checking for authen
class Query(graphene.ObjectType):
questions = graphene.List(QuestionType)
def resolve_questions(root, info):
# See if a user is authenticated
if info.context.user.is_authenticated():
return Question.objects.all()
else:
return Question.objects.none()
def resolve_questions(root, info):
# See if a user is authenticated
if info.context.user.is_authenticated():
return Question.objects.all()
else:
return Question.objects.none()
DjangoObjectTypes
@ -418,29 +418,29 @@ the core graphene pages for more information on customizing the Relay experience
You can now execute queries like:
.. code:: python
.. code:: graphql
{
questions (first: 2, after: "YXJyYXljb25uZWN0aW9uOjEwNQ==") {
pageInfo {
startCursor
endCursor
hasNextPage
hasPreviousPage
startCursor
endCursor
hasNextPage
hasPreviousPage
}
edges {
cursor
node {
id
question_text
}
cursor
node {
id
question_text
}
}
}
}
Which returns:
.. code:: python
.. code:: json
{
"data": {

View File

@ -28,7 +28,7 @@ Usage:
}
}
''',
op_name='myModel'
operation_name='myModel'
)
content = json.loads(response.content)
@ -49,7 +49,7 @@ Usage:
}
}
''',
op_name='myModel',
operation_name='myModel',
variables={'id': 1}
)
@ -73,7 +73,42 @@ Usage:
}
}
''',
op_name='myMutation',
operation_name='myMutation',
input_data={'my_field': 'foo', 'other_field': 'bar'}
)
# This validates the status code and if you get errors
self.assertResponseNoErrors(response)
# Add some more asserts if you like
...
For testing mutations that are executed within a transaction you should subclass `GraphQLTransactionTestCase`
Usage:
.. code:: python
import json
from graphene_django.utils.testing import GraphQLTransactionTestCase
class MyFancyTransactionTestCase(GraphQLTransactionTestCase):
def test_some_mutation_that_executes_within_a_transaction(self):
response = self.query(
'''
mutation myMutation($input: MyMutationInput!) {
myMutation(input: $input) {
my-model {
id
name
}
}
}
''',
operation_name='myMutation',
input_data={'my_field': 'foo', 'other_field': 'bar'}
)
@ -113,7 +148,7 @@ To use pytest define a simple fixture using the query helper below
}
}
''',
op_name='myModel'
operation_name='myModel'
)
content = json.loads(response.content)

View File

@ -35,6 +35,7 @@ Now sync your database for the first time:
.. code:: bash
cd ..
python manage.py migrate
Let's create a few simple models...
@ -77,6 +78,18 @@ Add ingredients as INSTALLED_APPS:
"cookbook.ingredients",
]
Make sure the app name in ``cookbook.ingredients.apps.IngredientsConfig`` is set to ``cookbook.ingredients``.
.. code:: python
# cookbook/ingredients/apps.py
from django.apps import AppConfig
class IngredientsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'cookbook.ingredients'
Don't forget to create & run migrations:

View File

@ -70,7 +70,7 @@ Let's get started with these models:
class Ingredient(models.Model):
name = models.CharField(max_length=100)
notes = models.TextField()
category = models.ForeignKey(Category, related_name='ingredients')
category = models.ForeignKey(Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self):
return self.name
@ -151,7 +151,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, )
class Query(graphene.ObjectType):
class Query(ObjectType):
category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode)
@ -281,7 +281,7 @@ from the command line.
$ python ./manage.py runserver
Performing system checks...
Django version 1.11, using settings 'cookbook.settings'
Django version 3.1.7, using settings 'cookbook.settings'
Starting development server at http://127.0.0.1:8000/
Quit the server with CONTROL-C.

View File

@ -1,4 +1,4 @@
graphene>=2.1,<3
graphene-django>=2.1,<3
graphql-core>=2.1,<3
django==3.0.7
django==3.1.14

View File

@ -1,5 +1,5 @@
graphene>=2.1,<3
graphene-django>=2.1,<3
graphql-core>=2.1,<3
django==3.0.7
django==3.1.14
django-filter>=2

View File

@ -1,7 +1,7 @@
from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType
__version__ = "3.0.0b7"
__version__ = "3.0.0b8"
__all__ = [
"__version__",

View File

@ -1,5 +1,6 @@
class MissingType(object):
pass
def __init__(self, *args, **kwargs):
pass
try:

View File

@ -1,5 +1,5 @@
from collections import OrderedDict
from functools import singledispatch, partial, wraps
from functools import singledispatch, wraps
from django.db import models
from django.utils.encoding import force_str
@ -23,8 +23,8 @@ from graphene import (
Time,
Decimal,
)
from graphene.types.resolver import get_default_resolver
from graphene.types.json import JSONString
from graphene.types.scalars import BigInt
from graphene.utils.str_converters import to_camel_case
from graphql import GraphQLError, assert_valid_name
from graphql.pyutils import register_description
@ -36,7 +36,7 @@ from .utils.str_converters import to_const
class BlankValueField(Field):
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver
# create custom resolver
@ -167,11 +167,19 @@ def convert_field_to_string(field, registry=None):
)
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.AutoField)
def convert_field_to_id(field, registry=None):
return ID(description=get_django_field_description(field), required=not field.null)
if hasattr(models, "SmallAutoField"):
@convert_django_field.register(models.SmallAutoField)
def convert_field_small_to_id(field, registry=None):
return convert_field_to_id(field, registry)
@convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None):
return UUID(
@ -179,10 +187,14 @@ def convert_field_to_uuid(field, registry=None):
)
@convert_django_field.register(models.BigIntegerField)
def convert_big_int_field(field, registry=None):
return BigInt(description=field.help_text, required=not field.null)
@convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.IntegerField)
def convert_field_to_int(field, registry=None):
return Int(description=get_django_field_description(field), required=not field.null)
@ -198,7 +210,9 @@ def convert_field_to_boolean(field, registry=None):
@convert_django_field.register(models.DecimalField)
def convert_field_to_decimal(field, registry=None):
return Decimal(description=field.help_text, required=not field.null)
return Decimal(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.FloatField)
@ -239,10 +253,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
if not _type:
return
# We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance
null = getattr(field, "null", True)
return Field(_type, required=not null)
return Field(_type, required=not field.null)
return Dynamic(dynamic_type)
@ -297,7 +308,24 @@ def convert_field_to_djangomodel(field, registry=None):
if not _type:
return
return Field(
class CustomField(Field):
def wrap_resolve(self, parent_resolver):
"""
Implements a custom resolver which go through the `get_node` method to insure that
it goes through the `get_queryset` method of the DjangoObjectType.
"""
resolver = super().wrap_resolve(parent_resolver)
def custom_resolver(root, info, **args):
fk_obj = resolver(root, info, **args)
if fk_obj is None:
return None
else:
return _type.get_node(info, fk_obj.pk)
return custom_resolver
return CustomField(
_type,
description=get_django_field_description(field),
required=not field.null,

View File

@ -0,0 +1,17 @@
import traceback
from django.utils.encoding import force_str
from .types import DjangoDebugException
def wrap_exception(exception):
return DjangoDebugException(
message=force_str(exception),
exc_type=force_str(type(exception)),
stack="".join(
traceback.format_exception(
exception, value=exception, tb=exception.__traceback__
)
),
)

View File

@ -0,0 +1,10 @@
from graphene import ObjectType, String
class DjangoDebugException(ObjectType):
class Meta:
description = "Represents a single exception raised."
exc_type = String(required=True, description="The class of the exception")
message = String(required=True, description="The message of the exception")
stack = String(required=True, description="The stack trace")

View File

@ -3,6 +3,7 @@ from django.db import connections
from promise import Promise
from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception
from .types import DjangoDebug
@ -10,8 +11,8 @@ class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
self.object = DjangoDebug(sql=[], exceptions=[])
self.enable_instrumentation()
self.object = DjangoDebug(sql=[])
def get_debug_promise(self):
if not self.debug_promise:
@ -19,6 +20,11 @@ class DjangoDebugContext(object):
self.promises = []
return self.debug_promise.then(self.on_resolve_all_promises).get()
def on_resolve_error(self, value):
if hasattr(self, "object"):
self.object.exceptions.append(wrap_exception(value))
return Promise.reject(value)
def on_resolve_all_promises(self, values):
if self.promises:
self.debug_promise = None
@ -57,6 +63,9 @@ class DjangoDebugMiddleware(object):
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, info, **args)
try:
promise = next(root, info, **args)
except Exception as e:
return context.django_debug.on_resolve_error(e)
context.django_debug.add_promise(promise)
return promise

View File

@ -272,3 +272,42 @@ def test_should_query_connectionfilter(graphene_settings, max_limit):
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query
def test_should_query_stack_trace():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_reporter(self, info, **args):
raise Exception("caught stack trace")
query = """
query ReporterQuery {
reporter {
lastName
}
_debug {
exceptions {
message
stack
}
}
}
"""
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert result.errors
assert len(result.data["_debug"]["exceptions"])
debug_exception = result.data["_debug"]["exceptions"][0]
assert debug_exception["stack"].count("\n") > 1
assert "test_query.py" in debug_exception["stack"]
assert debug_exception["message"] == "caught stack trace"

View File

@ -1,6 +1,7 @@
from graphene import List, ObjectType
from .sql.types import DjangoDebugSQL
from .exception.types import DjangoDebugException
class DjangoDebug(ObjectType):
@ -8,3 +9,6 @@ class DjangoDebug(ObjectType):
description = "Debugging information for the current query."
sql = List(DjangoDebugSQL, description="Executed SQL queries for this API query.")
exceptions = List(
DjangoDebugException, description="Raise exceptions for this API query."
)

View File

@ -1,12 +1,14 @@
from functools import partial
from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import (
from graphql_relay import (
connection_from_array_slice,
cursor_to_offset,
get_offset_with_default,
offset_to_cursor,
)
from promise import Promise
from graphene import Int, NonNull
@ -61,12 +63,16 @@ class DjangoListField(Field):
return queryset
def wrap_resolve(self, parent_resolver):
resolver = super(DjangoListField, self).wrap_resolve(parent_resolver)
_type = self.type
if isinstance(_type, NonNull):
_type = _type.of_type
django_object_type = _type.of_type.of_type
return partial(
self.list_resolver, django_object_type, parent_resolver, self.get_manager(),
self.list_resolver,
django_object_type,
resolver,
self.get_manager(),
)
@ -143,36 +149,40 @@ class DjangoConnectionField(ConnectionField):
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
list_length = iterable.count()
array_length = iterable.count()
else:
list_length = len(iterable)
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
array_length = len(iterable)
# If after is higher than list_length, connection_from_list_slice
# If after is higher than array_length, connection_from_array_slice
# would try to do a negative slicing which makes django throw an
# AssertionError
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)
slice_start = min(
get_offset_with_default(args.get("after"), -1) + 1,
array_length,
)
array_slice_length = array_length - slice_start
if max_limit is not None and args.get("first", None) == None:
if args.get("last", None) != None:
after = list_length - args["last"]
else:
args["first"] = max_limit
# Impose the maximum limit via the `first` field if neither first or last are already provided
# (note that if any of them is provided they must be under max_limit otherwise an error is raised).
if (
max_limit is not None
and args.get("first", None) is None
and args.get("last", None) is None
):
args["first"] = max_limit
connection = connection_from_array_slice(
iterable[after:],
iterable[slice_start:],
args,
slice_start=after,
array_length=list_length,
array_slice_length=list_slice_length,
slice_start=slice_start,
array_length=array_length,
array_slice_length=array_slice_length,
connection_type=partial(connection_adapter, connection),
edge_type=connection.Edge,
page_info_type=page_info_adapter,
)
connection.iterable = iterable
connection.length = list_length
connection.length = array_length
return connection
@classmethod

View File

@ -9,10 +9,21 @@ if not DJANGO_FILTER_INSTALLED:
)
else:
from .fields import DjangoFilterConnectionField
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .filters import (
ArrayFilter,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter,
ListFilter,
RangeFilter,
TypedFilter,
)
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
]

View File

@ -2,16 +2,35 @@ from collections import OrderedDict
from functools import partial
from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
self,
type,
type_,
fields=None,
order_by=None,
extra_filter_meta=None,
@ -25,7 +44,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self._filtering_args = None
self._extra_filter_meta = extra_filter_meta
self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs)
super(DjangoFilterConnectionField, self).__init__(type_, *args, **kwargs)
@property
def args(self):
@ -43,8 +62,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
filterset_class = self._provided_filterset_class or (
self.node_type._meta.filterset_class
filterset_class = (
self._provided_filterset_class or self.node_type._meta.filterset_class
)
self._filterset_class = get_filterset_class(filterset_class, **meta)
@ -68,7 +87,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if k in filtering_args:
if k == "order_by" and v is not None:
v = to_snake_case(v)
kwargs[k] = v
kwargs[k] = convert_enum(v)
return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset(
@ -78,7 +97,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context
)
if filterset.form.is_valid():
if filterset.is_valid():
return filterset.qs
raise ValidationError(filterset.form.errors.as_json())

View File

@ -1,75 +0,0 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
class InFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super().filter(qs, value)
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(Filter):
field_class = RangeField

View File

@ -0,0 +1,25 @@
import warnings
from ...utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .array_filter import ArrayFilter
from .global_id_filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .list_filter import ListFilter
from .range_filter import RangeFilter
from .typed_filter import TypedFilter
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
]

View File

@ -0,0 +1,27 @@
from django_filters.constants import EMPTY_VALUES
from .typed_filter import TypedFilter
class ArrayFilter(TypedFilter):
"""
Filter made for PostgreSQL ArrayField.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get the filter applied with
an empty list since it's a valid value but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value in EMPTY_VALUES and value != []:
return qs
if self.distinct:
qs = qs.distinct()
lookup = "%s__%s" % (self.field_name, self.lookup_expr)
qs = self.get_method(qs)(**{lookup: value})
return qs

View File

@ -0,0 +1,28 @@
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
"""Convert the filter value to a primary key before filtering"""
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)

View File

@ -0,0 +1,26 @@
from .typed_filter import TypedFilter
class ListFilter(TypedFilter):
"""
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super(ListFilter, self).filter(qs, value)

View File

@ -0,0 +1,24 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from .typed_filter import TypedFilter
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(TypedFilter):
field_class = RangeField

View File

@ -0,0 +1,27 @@
from django_filters import Filter
from graphene.types.utils import get_type
class TypedFilter(Filter):
"""
Filter class for which the input GraphQL type can explicitly be provided.
If it is not provided, when building the schema, it will try to guess
it from the field.
"""
def __init__(self, input_type=None, *args, **kwargs):
self._input_type = input_type
super(TypedFilter, self).__init__(*args, **kwargs)
@property
def input_type(self):
input_type = get_type(self._input_type)
if input_type is not None:
if not callable(getattr(input_type, "get_type", None)):
raise ValueError(
"Wrong `input_type` for {}: it only accepts graphene types, got {}".format(
self.__class__.__name__, input_type
)
)
return input_type

View File

@ -1,7 +1,6 @@
import itertools
from django.db import models
from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
@ -19,8 +18,8 @@ GRAPHENE_FILTER_SET_OVERRIDES = {
class GrapheneFilterSetMixin(BaseFilterSet):
""" A django_filters.filterset.BaseFilterSet with default filter overrides
to handle global IDs """
"""A django_filters.filterset.BaseFilterSet with default filter overrides
to handle global IDs"""
FILTER_DEFAULTS = dict(
itertools.chain(
@ -30,8 +29,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
"""Wrap a provided filterset in Graphene-specific functionality"""
return type(
"Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
@ -40,8 +38,7 @@ def setup_filterset(filterset_class):
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
"""Create a filterset for the given model using the provided meta data"""
meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta)
filterset = type(

View File

@ -9,6 +9,7 @@ import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter
from ...compat import ArrayField
@ -27,58 +28,71 @@ else:
STORE = {"events": []}
@pytest.fixture
def Event():
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
return Event
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField())
@pytest.fixture
def EventFilterSet(Event):
from django.contrib.postgres.forms import SimpleArrayField
class ArrayFilter(filters.Filter):
base_field_class = SimpleArrayField
def EventFilterSet():
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact"],
"name": ["exact", "contains"],
}
# Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
random_field__contains = ArrayFilter(
field_name="random_field", lookup_expr="contains"
)
random_field__overlap = ArrayFilter(
field_name="random_field", lookup_expr="overlap"
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
return EventFilterSet
@pytest.fixture
def EventType(Event, EventFilterSet):
def EventType(EventFilterSet):
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
return EventType
@pytest.fixture
def Query(Event, EventType):
def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.
"""
class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs):
events = [
Event(name="Live Show", tags=["concert", "music", "rock"],),
Event(name="Musical", tags=["movie", "music"],),
Event(name="Ballet", tags=["concert", "dance"],),
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]
STORE["events"] = events
@ -105,6 +119,13 @@ def Query(Event, EventType):
STORE["events"],
)
)
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
)
)
def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs)
@ -121,7 +142,9 @@ def Query(Event, EventType):
m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__
m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
return m_queryset

View File

@ -10,7 +10,7 @@ class ArticleFilter(django_filters.FilterSet):
fields = {
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
"reporter": ["exact", "in"],
}
order_by = OrderingFilter(fields=("pub_date",))

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_multiple(Query):
def test_array_field_contains_multiple(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -32,9 +32,9 @@ def test_string_contains_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_one(Query):
def test_array_field_contains_one(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -59,9 +59,9 @@ def test_string_contains_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_none(Query):
def test_array_field_contains_empty_list(Query):
"""
Test contains filter on a string field.
Test contains filter on a array field of string.
"""
schema = Schema(query=Query)
@ -79,4 +79,9 @@ def test_string_contains_none(Query):
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
{"node": {"name": "Speech"}},
]

View File

@ -0,0 +1,127 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["movie", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_filter_schema_type(Query):
"""
Check that the type in the filter is an array field like on the object type.
"""
schema = Schema(query=Query)
schema_str = str(schema)
assert (
'''type EventType implements Node {
"""The ID of the object"""
id: ID!
name: String!
tags: [String!]!
tagIds: [Int!]!
randomField: [Boolean!]!
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[f"{filter_field}: {gql_type}" for filter_field, gql_type in filters.items()]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
)

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_multiple(Query):
def test_array_field_overlap_multiple(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
@ -34,9 +34,9 @@ def test_string_overlap_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_one(Query):
def test_array_field_overlap_one(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)
@ -61,9 +61,9 @@ def test_string_overlap_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_none(Query):
def test_array_field_overlap_empty_list(Query):
"""
Test overlap filter on a string field.
Test overlap filter on a array field of string.
"""
schema = Schema(query=Query)

View File

@ -0,0 +1,163 @@
import pytest
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType, DjangoConnectionField
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"],
}
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
@pytest.fixture
def reporter_article_data():
john = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
jane = Reporter.objects.create(
first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2
)
Article.objects.create(
headline="Article Node 1", reporter=john, editor=john, lang="es"
)
Article.objects.create(
headline="Article Node 2", reporter=john, editor=john, lang="en"
)
Article.objects.create(
headline="Article Node 3", reporter=jane, editor=jane, lang="en"
)
def test_filter_enum_on_connection(schema, reporter_article_data):
"""
Check that we can filter with enums on a connection.
"""
query = """
query {
allArticles(lang: ES) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_on_foreign_key_enum_field(schema, reporter_article_data):
"""
Check that we can filter with enums on a field from a foreign key.
"""
query = """
query {
allArticles(reporter_AChoice: A_1) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
{"node": {"headline": "Article Node 2"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_enum_field_schema_type(schema):
"""
Check that the type in the filter is an enum like on the object type.
"""
schema_str = str(schema)
assert (
'''type ArticleType implements Node {
"""The ID of the object"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: ReporterType!
editor: ReporterType!
"""Language"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[f"{filter_field}: {gql_type}" for filter_field, gql_type in filters.items()]
)
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

View File

@ -5,18 +5,18 @@ import pytest
from django.db.models import TextField, Value
from django.db.models.functions import Concat
from graphene import Argument, Boolean, Field, Float, ObjectType, Schema, String
from graphene import Argument, Boolean, Decimal, Field, ObjectType, Schema, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.tests.models import Article, Person, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
import django_filters
from django_filters import FilterSet, NumberFilter
from django_filters import FilterSet, NumberFilter, OrderingFilter
from graphene_django.filter import (
GlobalIDFilter,
@ -90,6 +90,7 @@ def test_filter_explicit_filterset_arguments():
"pub_date__gt",
"pub_date__lt",
"reporter",
"reporter__in",
)
@ -400,7 +401,7 @@ def test_filterset_descriptions():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args["max_time"]
assert isinstance(max_time, Argument)
assert max_time.type == Float
assert max_time.type == Decimal
assert max_time.description == "The maximum time"
@ -696,7 +697,7 @@ def test_should_query_filter_node_limit():
node {
id
firstName
articles(lang: "es") {
articles(lang: ES) {
edges {
node {
id
@ -738,6 +739,7 @@ def test_order_by():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
@ -1006,7 +1008,7 @@ def test_integer_field_filter_type():
assert str(schema) == dedent(
"""\
type Query {
pets(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null, age: Int = null): PetTypeConnection
pets(offset: Int, before: String, after: String, first: Int, last: Int, age: Int): PetTypeConnection
}
type PetTypeConnection {
@ -1054,8 +1056,7 @@ def test_integer_field_filter_type():
interface Node {
\"""The ID of the object\"""
id: ID!
}
"""
}"""
)
@ -1075,7 +1076,7 @@ def test_other_filter_types():
assert str(schema) == dedent(
"""\
type Query {
pets(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null, age: Int = null, age_Isnull: Boolean = null, age_Lt: Int = null): PetTypeConnection
pets(offset: Int, before: String, after: String, first: Int, last: Int, age: Int, age_Isnull: Boolean, age_Lt: Int): PetTypeConnection
}
type PetTypeConnection {
@ -1123,8 +1124,7 @@ def test_other_filter_types():
interface Node {
\"""The ID of the object\"""
id: ID!
}
"""
}"""
)
@ -1143,7 +1143,7 @@ def test_filter_filterset_based_on_mixin():
return filters
def filter_email_in(cls, queryset, name, value):
def filter_email_in(self, queryset, name, value):
return queryset.filter(**{name: [value]})
class NewArticleFilter(ArticleFilterMixin, ArticleFilter):
@ -1224,7 +1224,81 @@ def test_filter_filterset_based_on_mixin():
}
}
result = schema.execute(query, variable_values={"email": reporter_1.email},)
result = schema.execute(query, variable_values={"email": reporter_1.email})
assert not result.errors
assert result.data == expected
def test_filter_string_contains():
class PersonType(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
fields = "__all__"
filter_fields = {"name": ["exact", "in", "contains", "icontains"]}
class Query(ObjectType):
people = DjangoFilterConnectionField(PersonType)
schema = Schema(query=Query)
Person.objects.bulk_create(
[
Person(name="Jack"),
Person(name="Joe"),
Person(name="Jane"),
Person(name="Peter"),
Person(name="Bob"),
]
)
query = """query nameContain($filter: String) {
people(name_Contains: $filter) {
edges {
node {
name
}
}
}
}"""
result = schema.execute(query, variables={"filter": "Ja"})
assert not result.errors
assert result.data == {
"people": {
"edges": [
{"node": {"name": "Jack"}},
{"node": {"name": "Jane"}},
]
}
}
result = schema.execute(query, variables={"filter": "o"})
assert not result.errors
assert result.data == {
"people": {
"edges": [
{"node": {"name": "Joe"}},
{"node": {"name": "Bob"}},
]
}
}
def test_only_custom_filters():
class ReporterFilter(FilterSet):
class Meta:
model = Reporter
fields = []
some_filter = OrderingFilter(fields=("name",))
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
filterset_class = ReporterFilter
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, "some_filter")

View File

@ -1,3 +1,5 @@
from datetime import datetime
import pytest
from django_filters import FilterSet
@ -5,7 +7,8 @@ from django_filters import rest_framework as filters
from graphene import ObjectType, Schema
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person
from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
from graphene_django.filter.tests.filters import ArticleFilter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
@ -20,40 +23,77 @@ else:
)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
@pytest.fixture
def query():
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"id": ["exact", "in"],
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
}
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
# choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]}
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter
class FilmNode(DjangoObjectType):
class Meta:
model = Film
interfaces = (Node,)
fields = "__all__"
# choice filter not using enum
filter_fields = {
"genre": ["exact", "in"],
}
convert_choices_to_enum = False
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {"name": ["in"]}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
"""
This custom filter take a string as input with comma separated values.
Note that the value here is already a list as it has been transformed by the BaseInFilter class.
"""
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
articles = DjangoFilterConnectionField(ArticleNode)
films = DjangoFilterConnectionField(FilmNode)
reporters = DjangoFilterConnectionField(ReporterNode)
return Query
class PersonFilterSet(FilterSet):
class Meta:
model = Person
fields = {}
names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value):
return qs.filter(name__in=value)
class PersonNode(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
filterset_class = PersonFilterSet
class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode)
def test_string_in_filter():
def test_string_in_filter(query):
"""
Test in filter on a string field.
"""
@ -61,7 +101,7 @@ def test_string_in_filter():
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -82,17 +122,19 @@ def test_string_in_filter():
]
def test_string_in_filter_with_filterset_class():
"""Test in filter on a string field with a custom filterset class."""
def test_string_in_filter_with_otjer_filter(query):
"""
Test in filter on a string field which has also a custom filter doing a similar operation.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
people (names: ["John", "Michael"]) {
people (name_In: ["John", "Michael"]) {
edges {
node {
name
@ -109,7 +151,36 @@ def test_string_in_filter_with_filterset_class():
]
def test_int_in_filter():
def test_string_in_filter_with_declared_filter(query):
"""
Test in filter on a string field with a custom filterset class.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=query)
query = """
query {
people (names: "John,Michael") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter(query):
"""
Test in filter on an integer field.
"""
@ -117,7 +188,7 @@ def test_int_in_filter():
Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -157,7 +228,7 @@ def test_int_in_filter():
]
def test_in_filter_with_empty_list():
def test_in_filter_with_empty_list(query):
"""
Check that using a in filter with an empty list provided as input returns no objects.
"""
@ -165,7 +236,7 @@ def test_in_filter_with_empty_list():
Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query)
schema = Schema(query=query)
query = """
query {
@ -181,3 +252,197 @@ def test_in_filter_with_empty_list():
result = schema.execute(query)
assert not result.errors
assert len(result.data["pets"]["edges"]) == 0
def test_choice_in_filter_without_enum(query):
"""
Test in filter o an choice field not using an enum (Film.genre).
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
documentary_film = Film.objects.create(genre="do")
documentary_film.reporters.add(john_doe)
action_film = Film.objects.create(genre="ac")
action_film.reporters.add(john_doe)
other_film = Film.objects.create(genre="ot")
other_film.reporters.add(john_doe)
other_film.reporters.add(jean_bon)
schema = Schema(query=query)
query = """
query {
films (genre_In: ["do", "ac"]) {
edges {
node {
genre
reporters {
edges {
node {
lastName
}
}
}
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["films"]["edges"] == [
{
"node": {
"genre": "do",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
{
"node": {
"genre": "ac",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
]
def test_fk_id_in_filter(query):
"""
Test in filter on an foreign key relationship.
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
sara_croche = Reporter.objects.create(
first_name="Sara", last_name="Croche", email="sara@croche.com"
)
Article.objects.create(
headline="A",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=john_doe,
editor=john_doe,
)
Article.objects.create(
headline="B",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=jean_bon,
editor=jean_bon,
)
Article.objects.create(
headline="C",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=sara_croche,
editor=sara_croche,
)
schema = Schema(query=query)
query = """
query {
articles (reporter_In: [%s, %s]) {
edges {
node {
headline
reporter {
lastName
}
}
}
}
}
""" % (
john_doe.id,
jean_bon.id,
)
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A", "reporter": {"lastName": "Doe"}}},
{"node": {"headline": "B", "reporter": {"lastName": "Bon"}}},
]
def test_enum_in_filter(query):
"""
Test in filter on a choice field using an enum (Reporter.reporter_type).
"""
Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1
)
Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None
)
schema = Schema(query=query)
query = """
query {
reporters (reporterType_In: [A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2, A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]

View File

@ -25,6 +25,7 @@ class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"name": ["exact", "in"],
"age": ["exact", "in", "range"],
@ -101,14 +102,14 @@ def test_range_filter_with_invalid_input():
# Empty list
result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error
# Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error
# More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']"
assert result.errors[0].message == expected_error

View File

@ -0,0 +1,151 @@
import pytest
from django_filters import FilterSet
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import (
DjangoFilterConnectionField,
TypedFilter,
ListFilter,
)
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ArticleFilterSet(FilterSet):
class Meta:
model = Article
fields = {
"lang": ["exact", "in"],
}
lang__contains = TypedFilter(
field_name="lang", lookup_expr="icontains", input_type=graphene.String
)
lang__in_str = ListFilter(
field_name="lang",
lookup_expr="in",
input_type=graphene.List(graphene.String),
)
first_n = TypedFilter(input_type=graphene.Int, method="first_n_filter")
only_first = TypedFilter(
input_type=graphene.Boolean, method="only_first_filter"
)
def first_n_filter(self, queryset, _name, value):
return queryset[:value]
def only_first_filter(self, queryset, _name, value):
if value:
return queryset[:1]
else:
return queryset
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilterSet
class Query(graphene.ObjectType):
articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
def test_typed_filter_schema(schema):
"""
Check that the type provided in the filter is reflected in the schema.
"""
schema_str = str(schema)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"lang_Contains": "String",
"lang_InStr": "[String]",
"firstN": "Int",
"onlyFirst": "Boolean",
}
all_articles_filters = (
schema_str.split(" articles(")[1]
.split("): ArticleTypeConnection\n")[0]
.split(", ")
)
for filter_field, gql_type in filters.items():
assert "{}: {}".format(filter_field, gql_type) in all_articles_filters
def test_typed_filters_work(schema):
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "C"}},
]
query = "query { articles (firstN: 2) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
]

View File

@ -1,53 +1,103 @@
import graphene
from django_filters.utils import get_model_field
from django import forms
from django_filters.utils import get_model_field, get_field_parts
from django_filters.filters import Filter, BaseCSVFilter
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter
from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
def get_field_type(registry, model, field_name):
"""
Try to get a model field corresponding Graphql type from the DjangoObjectType.
"""
object_type = registry.get_type_for_model(model)
if object_type:
object_type_field = object_type._meta.fields.get(field_name)
if object_type_field:
field_type = object_type_field.type
if isinstance(field_type, graphene.NonNull):
field_type = field_type.of_type
return field_type
return None
def get_filtering_args_from_filterset(filterset_class, type):
""" Inspect a FilterSet and produce the arguments to pass to
a Graphene Field. These arguments will be available to
filter against in the GraphQL
"""
Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
These arguments will be available to filter against in the GraphQL API.
"""
from ..forms.converter import convert_form_field
args = {}
model = filterset_class._meta.model
registry = type._meta.registry
for name, filter_field in filterset_class.base_filters.items():
form_field = None
filter_type = filter_field.lookup_expr
required = filter_field.extra.get("required", False)
field_type = None
form_field = None
if name in filterset_class.declared_filters:
# Get the filter field from the explicitly declared filter
form_field = filter_field.field
field = convert_form_field(form_field)
if (
isinstance(filter_field, TypedFilter)
and filter_field.input_type is not None
):
# First check if the filter input type has been explicitely given
field_type = filter_field.input_type
else:
# Get the filter field with no explicit type declaration
model_field = get_model_field(model, filter_field.field_name)
if filter_type != "isnull" and hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required=filter_field.extra.get("required", False)
)
if name not in filterset_class.declared_filters or isinstance(
filter_field, TypedFilter
):
# Get the filter field for filters that are no explicitly declared.
if filter_type == "isnull":
field = graphene.Boolean(required=required)
else:
model_field = get_model_field(model, filter_field.field_name)
# Fallback to field defined on filter if we can't get it from the
# model field
if not form_field:
form_field = filter_field.field
# Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field:
form_field = filter_field.field
field = convert_form_field(form_field)
# First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
field_type = get_field_type(
registry, model_field.related_model, "id"
)
else:
field_type = get_field_type(
registry, model_field.model, model_field.name
)
if filter_type in {"in", "range", "contains", "overlap"}:
# Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of
# the same type as the field. See comments in
# `replace_csv_filters` method for more details.
field = graphene.List(field.get_type())
if not field_type:
# Fallback on converting the form field either because:
# - it's an explicitly declared filters
# - we did not manage to get the type from the model type
form_field = form_field or filter_field.field
field_type = convert_form_field(form_field).get_type()
field_type = field.Argument()
field_type.description = str(filter_field.label) if filter_field.label else None
args[name] = field_type
if isinstance(filter_field, ListFilter) or isinstance(
filter_field, RangeFilter
):
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in `replace_csv_filters` method for more details.
field_type = graphene.List(field_type)
args[name] = graphene.Argument(
field_type,
description=filter_field.label,
required=required,
)
return args
@ -69,18 +119,26 @@ def get_filterset_class(filterset_class, **meta):
def replace_csv_filters(filterset_class):
"""
Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but regular Filter objects that simply use the input value as filter argument on the queryset.
Replace the "in" and "range" filters (that are not explicitly declared)
to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but our custom InFilter/RangeFilter filter class that use the input
value as filter argument on the queryset.
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we
can actually have a list as input and have a proper type verification of each value in the list.
This is because those BaseCSVFilter are expecting a string as input with
comma separated values.
But with GraphQl we can actually have a list as input and have a proper
type verification of each value in the list.
See issue https://github.com/graphql-python/graphene-django/issues/1068.
"""
for name, filter_field in list(filterset_class.base_filters.items()):
# Do not touch any declared filters
if name in filterset_class.declared_filters:
continue
filter_type = filter_field.lookup_expr
if filter_type in {"in", "contains", "overlap"}:
filterset_class.base_filters[name] = InFilter(
if filter_type == "in":
filterset_class.base_filters[name] = ListFilter(
field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr,
label=filter_field.label,
@ -88,7 +146,6 @@ def replace_csv_filters(filterset_class):
exclude=filter_field.exclude,
**filter_field.extra
)
elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name,

View File

@ -3,7 +3,19 @@ from functools import singledispatch
from django import forms
from django.core.exceptions import ImproperlyConfigured
from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from graphene import (
ID,
Boolean,
Decimal,
Float,
Int,
List,
String,
UUID,
Date,
DateTime,
Time,
)
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
@ -57,12 +69,18 @@ def convert_form_field_to_nullboolean(field):
return Boolean(description=get_form_field_description(field))
@convert_form_field.register(forms.DecimalField)
@convert_form_field.register(forms.FloatField)
def convert_form_field_to_float(field):
return Float(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.DecimalField)
def convert_form_field_to_decimal(field):
return Decimal(
description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.MultipleChoiceField)
def convert_form_field_to_string_list(field):
return List(

View File

@ -14,10 +14,6 @@ from graphene.types.utils import yank_fields_from_attrs
from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.registry import get_global_registry
from django.core.exceptions import ValidationError
from django.db import connection
from ..types import ErrorType
from .converter import convert_form_field
@ -105,7 +101,10 @@ class DjangoFormMutation(BaseDjangoFormMutation):
@classmethod
def perform_mutate(cls, form, info):
form.save()
if hasattr(form, "save"):
# `save` method won't exist on plain Django forms, but this mutation can
# in theory be used with `ModelForm`s as well and we do want to save them.
form.save()
return cls(errors=[], **form.cleaned_data)
@ -118,7 +117,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
errors = graphene.List(graphene.NonNull(ErrorType), required=True)
@classmethod
def __init_subclass_with_meta__(

View File

@ -6,6 +6,7 @@ from graphene import (
String,
Int,
Boolean,
Decimal,
Float,
ID,
UUID,
@ -97,8 +98,8 @@ def test_should_float_convert_float():
assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, Float)
def test_should_decimal_convert_decimal():
assert_conversion(forms.DecimalField, Decimal)
def test_should_multiple_choice_convert_list():

View File

@ -1,3 +1 @@
import graphene
from ..types import ErrorType # noqa Import ErrorType for backwards compatability

View File

@ -48,14 +48,14 @@ class CommandArguments(BaseCommand):
class Command(CommandArguments):
help = "Dump Graphene schema as a JSON or GraphQL file"
can_import_settings = True
requires_system_checks = False
requires_system_checks = []
def save_json_file(self, out, schema_dict, indent):
with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent, sort_keys=True)
def save_graphql_file(self, out, schema):
with open(out, "w") as outfile:
with open(out, "w", encoding="utf-8") as outfile:
outfile.write(print_schema(schema.graphql_schema))
def get_schema(self, schema, out, indent):

View File

@ -110,11 +110,15 @@ def convert_serializer_field_to_bool(field):
@get_graphene_type_from_serializer_field.register(serializers.FloatField)
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
def convert_serializer_field_to_float(field):
return graphene.Float
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
def convert_serializer_field_to_decimal(field):
return graphene.Decimal
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime

View File

@ -133,9 +133,9 @@ def test_should_float_convert_float():
assert_conversion(serializers.FloatField, graphene.Float)
def test_should_decimal_convert_float():
def test_should_decimal_convert_decimal():
assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
serializers.DecimalField, graphene.Decimal, max_digits=4, decimal_places=2
)

View File

@ -16,10 +16,7 @@ from __future__ import unicode_literals
from django.conf import settings
from django.test.signals import setting_changed
try:
import importlib # Available in Python 3.1+
except ImportError:
from django.utils import importlib # Will be removed in Django 1.9
import importlib # Available in Python 3.1+
# Copied shamelessly from Django REST Framework

View File

@ -10,14 +10,6 @@
history,
location,
) {
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ("; " + document.cookie).split("; csrftoken=");
if (cookies.length == 2) {
csrftoken = cookies.pop().split(";").shift();
} else {
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
}
// Collect the URL parameters
var parameters = {};
@ -68,9 +60,19 @@
var headers = opts.headers || {};
headers['Accept'] = headers['Accept'] || 'application/json';
headers['Content-Type'] = headers['Content-Type'] || 'application/json';
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ("; " + document.cookie).split("; csrftoken=");
if (cookies.length == 2) {
csrftoken = cookies.pop().split(";").shift();
} else {
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
}
if (csrftoken) {
headers['X-CSRFToken'] = csrftoken
}
return fetch(fetchURL, {
method: "post",
headers: headers,
@ -123,8 +125,8 @@
if (operationType === "subscription") {
return {
subscribe: function (observer) {
subscriptionClient.request(graphQLParams).subscribe(observer);
activeSubscription = subscriptionClient;
return subscriptionClient.request(graphQLParams, opts).subscribe(observer);
},
};
} else {

View File

@ -13,6 +13,9 @@ class Person(models.Model):
class Pet(models.Model):
name = models.CharField(max_length=30)
age = models.PositiveIntegerField()
owner = models.ForeignKey(
"Person", on_delete=models.CASCADE, null=True, blank=True, related_name="pets"
)
class FilmDetails(models.Model):
@ -26,7 +29,7 @@ class Film(models.Model):
genre = models.CharField(
max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")],
choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
@ -91,8 +94,8 @@ class CNNReporter(Reporter):
class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField()
pub_date_time = models.DateTimeField()
pub_date = models.DateField(auto_now_add=True)
pub_date_time = models.DateTimeField(auto_now_add=True)
reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles"
)

View File

@ -53,6 +53,5 @@ def test_generate_graphql_file_on_call_graphql_schema():
"""\
type Query {
hi: String
}
"""
}"""
)

View File

@ -10,6 +10,7 @@ from graphene import NonNull
from graphene.relay import ConnectionField, Node
from graphene.types.datetime import Date, DateTime, Time
from graphene.types.json import JSONString
from graphene.types.scalars import BigInt
from ..compat import (
ArrayField,
@ -111,6 +112,15 @@ def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True)
def test_should_big_auto_convert_id():
assert_conversion(models.BigAutoField, graphene.ID, primary_key=True)
def test_should_small_auto_convert_id():
if hasattr(models, "SmallAutoField"):
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
def test_should_uuid_convert_id():
assert_conversion(models.UUIDField, graphene.UUID)
@ -131,8 +141,8 @@ def test_should_small_integer_convert_int():
assert_conversion(models.SmallIntegerField, graphene.Int)
def test_should_big_integer_convert_int():
assert_conversion(models.BigIntegerField, graphene.Int)
def test_should_big_integer_convert_big_int():
assert_conversion(models.BigIntegerField, BigInt)
def test_should_integer_convert_int():

View File

@ -408,3 +408,95 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []},
]
}
def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return ReporterModel.objects.all()
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}

View File

@ -0,0 +1,361 @@
import pytest
import graphene
from graphene.relay import Node
from graphql_relay import to_global_id
from ..fields import DjangoConnectionField
from ..types import DjangoObjectType
from .models import Article, Reporter
class TestShouldCallGetQuerySetOnForeignKey:
"""
Check that the get_queryset method is called in both forward and reversed direction
of a foreignkey on types.
(see issue #1111)
"""
@pytest.fixture(autouse=True)
def setup_schema(self):
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("admin"):
return queryset
raise Exception("Not authorized to access reporters.")
class ArticleType(DjangoObjectType):
class Meta:
model = Article
@classmethod
def get_queryset(cls, queryset, info):
return queryset.exclude(headline__startswith="Draft")
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType, id=graphene.ID(required=True))
article = graphene.Field(ArticleType, id=graphene.ID(required=True))
def resolve_reporter(self, info, id):
return (
ReporterType.get_queryset(Reporter.objects, info)
.filter(id=id)
.last()
)
def resolve_article(self, info, id):
return (
ArticleType.get_queryset(Article.objects, info).filter(id=id).last()
)
self.schema = graphene.Schema(query=Query)
self.reporter = Reporter.objects.create(first_name="Jane", last_name="Doe")
self.articles = [
Article.objects.create(
headline="A fantastic article",
reporter=self.reporter,
editor=self.reporter,
),
Article.objects.create(
headline="Draft: My next best seller",
reporter=self.reporter,
editor=self.reporter,
),
]
def test_get_queryset_called_on_field(self):
# If a user tries to access an article it is fine as long as it's not a draft one
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
}
}
"""
# Non-draft
result = self.schema.execute(query, variables={"id": self.articles[0].id})
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
}
# Draft
result = self.schema.execute(query, variables={"id": self.articles[1].id})
assert not result.errors
assert result.data["article"] is None
# If a non admin user tries to access a reporter they should get our authorization error
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(query, variables={"id": self.reporter.id})
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.reporter.id},
context_value={"admin": True},
)
assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(query, variables={"id": self.articles[0].id})
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.articles[0].id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
headline
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.reporter.id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": [{"headline": "A fantastic article"}],
}
class TestShouldCallGetQuerySetOnForeignKeyNode:
"""
Check that the get_queryset method is called in both forward and reversed direction
of a foreignkey on types using a node interface.
(see issue #1111)
"""
@pytest.fixture(autouse=True)
def setup_schema(self):
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("admin"):
return queryset
raise Exception("Not authorized to access reporters.")
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
@classmethod
def get_queryset(cls, queryset, info):
return queryset.exclude(headline__startswith="Draft")
class Query(graphene.ObjectType):
reporter = Node.Field(ReporterType)
article = Node.Field(ArticleType)
self.schema = graphene.Schema(query=Query)
self.reporter = Reporter.objects.create(first_name="Jane", last_name="Doe")
self.articles = [
Article.objects.create(
headline="A fantastic article",
reporter=self.reporter,
editor=self.reporter,
),
Article.objects.create(
headline="Draft: My next best seller",
reporter=self.reporter,
editor=self.reporter,
),
]
def test_get_queryset_called_on_node(self):
# If a user tries to access an article it is fine as long as it's not a draft one
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
}
}
"""
# Non-draft
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
}
# Draft
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[1].id)}
)
assert not result.errors
assert result.data["article"] is None
# If a non admin user tries to access a reporter they should get our authorization error
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query, variables={"id": to_global_id("ReporterType", self.reporter.id)}
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ReporterType", self.reporter.id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ArticleType", self.articles[0].id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
edges {
node {
headline
}
}
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ReporterType", self.reporter.id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": {"edges": [{"node": {"headline": "A fantastic article"}}]},
}

View File

@ -15,7 +15,7 @@ from ..compat import IntegerRangeField, MissingType
from ..fields import DjangoConnectionField
from ..types import DjangoObjectType
from ..utils import DJANGO_FILTER_INSTALLED
from .models import Article, CNNReporter, Film, FilmDetails, Reporter
from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter
def test_should_query_only_fields():
@ -251,8 +251,8 @@ def test_should_node():
def test_should_query_onetoone_fields():
film = Film(id=1)
film_details = FilmDetails(id=1, film=film)
film = Film.objects.create(id=1)
film_details = FilmDetails.objects.create(id=1, film=film)
class FilmNode(DjangoObjectType):
class Meta:
@ -421,6 +421,7 @@ def test_should_query_node_filtering():
interfaces = (Node,)
fields = "__all__"
filter_fields = ("lang",)
convert_choices_to_enum = False
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -546,6 +547,7 @@ def test_should_query_node_multiple_filtering():
interfaces = (Node,)
fields = "__all__"
filter_fields = ("lang", "headline")
convert_choices_to_enum = False
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1241,6 +1243,7 @@ def test_should_have_next_page(graphene_settings):
}
@pytest.mark.parametrize("max_limit", [100, 4])
class TestBackwardPagination:
def setup_schema(self, graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
@ -1251,6 +1254,7 @@ class TestBackwardPagination:
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1258,8 +1262,8 @@ class TestBackwardPagination:
schema = graphene.Schema(query=Query)
return schema
def do_queries(self, schema):
# Simply last 3
def test_query_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_last = """
query {
allReporters(last: 3) {
@ -1279,7 +1283,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 3", "First 4", "First 5"]
# Use a combination of first and last
def test_query_first_and_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_and_last = """
query {
allReporters(first: 4, last: 3) {
@ -1299,7 +1304,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 1", "First 2", "First 3"]
# Use a combination of first and last and after
def test_query_first_last_and_after(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """
query queryAfter($after: String) {
allReporters(first: 4, last: 3, after: $after) {
@ -1314,7 +1320,8 @@ class TestBackwardPagination:
after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute(
query_first_last_and_after, variable_values=dict(after=after)
query_first_last_and_after,
variable_values=dict(after=after),
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3
@ -1322,20 +1329,35 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 2", "First 3", "First 4"]
def test_should_query(self, graphene_settings):
def test_query_last_and_before(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """
query queryAfter($before: String) {
allReporters(last: 1, before: $before) {
edges {
node {
firstName
}
}
}
}
"""
Backward pagination should work as expected
"""
schema = self.setup_schema(graphene_settings, max_limit=100)
self.do_queries(schema)
def test_should_query_with_low_max_limit(self, graphene_settings):
"""
When doing backward pagination (using last) in combination with a max limit higher than the number of objects
we should really retrieve the last ones.
"""
schema = self.setup_schema(graphene_settings, max_limit=4)
self.do_queries(schema)
result = schema.execute(
query_first_last_and_after,
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
before = base64.b64encode(b"arrayconnection:5").decode()
result = schema.execute(
query_first_last_and_after,
variable_values=dict(before=before),
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
def test_should_preserve_prefetch_related(django_assert_num_queries):
@ -1455,6 +1477,7 @@ def test_connection_should_enable_offset_filtering():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1476,7 +1499,11 @@ def test_connection_should_enable_offset_filtering():
result = schema.execute(query)
assert not result.errors
expected = {
"allReporters": {"edges": [{"node": {"firstName": "Some", "lastName": "Guy"}},]}
"allReporters": {
"edges": [
{"node": {"firstName": "Some", "lastName": "Guy"}},
]
}
}
assert result.data == expected
@ -1494,6 +1521,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1516,7 +1544,9 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
assert not result.errors
expected = {
"allReporters": {
"edges": [{"node": {"firstName": "Some", "lastName": "Lady"}},]
"edges": [
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
@ -1527,6 +1557,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1561,6 +1592,7 @@ def test_connection_should_allow_offset_filtering_with_after():
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
@ -1583,6 +1615,149 @@ def test_connection_should_allow_offset_filtering_with_after():
result = schema.execute(query, variable_values=dict(after=after))
assert not result.errors
expected = {
"allReporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}},]}
"allReporters": {
"edges": [
{"node": {"firstName": "Jane", "lastName": "Roe"}},
]
}
}
assert result.data == expected
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
query = """
query ReporterPromiseConnectionQuery ($last: Int) {
allReporters(last: $last) {
edges {
node {
firstName
lastName
}
}
}
}
"""
result = schema.execute(query, variable_values=dict(last=2))
assert not result.errors
expected = {"allReporters": {"edges": []}}
assert result.data == expected
Reporter.objects.create(first_name="John", last_name="Doe")
Reporter.objects.create(first_name="Some", last_name="Guy")
Reporter.objects.create(first_name="Jane", last_name="Roe")
Reporter.objects.create(first_name="Some", last_name="Lady")
result = schema.execute(query, variable_values=dict(last=2))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
result = schema.execute(query, variable_values=dict(last=4))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "John", "lastName": "Doe"}},
{"node": {"firstName": "Some", "lastName": "Guy"}},
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
result = schema.execute(query, variable_values=dict(last=20))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "John", "lastName": "Doe"}},
{"node": {"firstName": "Some", "lastName": "Guy"}},
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
def test_should_query_nullable_foreign_key():
class PetType(DjangoObjectType):
class Meta:
model = Pet
class PersonType(DjangoObjectType):
class Meta:
model = Person
class Query(graphene.ObjectType):
pet = graphene.Field(PetType, name=graphene.String(required=True))
person = graphene.Field(PersonType, name=graphene.String(required=True))
def resolve_pet(self, info, name):
return Pet.objects.filter(name=name).first()
def resolve_person(self, info, name):
return Person.objects.filter(name=name).first()
schema = graphene.Schema(query=Query)
person = Person.objects.create(name="Jane")
pets = [
Pet.objects.create(name="Stray dog", age=1),
Pet.objects.create(name="Jane's dog", owner=person, age=1),
]
query_pet = """
query getPet($name: String!) {
pet(name: $name) {
owner {
name
}
}
}
"""
result = schema.execute(query_pet, variables={"name": "Stray dog"})
assert not result.errors
assert result.data["pet"] == {
"owner": None,
}
result = schema.execute(query_pet, variables={"name": "Jane's dog"})
assert not result.errors
assert result.data["pet"] == {
"owner": {"name": "Jane"},
}
query_owner = """
query getOwner($name: String!) {
person(name: $name) {
pets {
name
}
}
}
"""
result = schema.execute(query_owner, variables={"name": "Jane"})
assert not result.errors
assert result.data["person"] == {
"pets": [{"name": "Jane's dog"}],
}

View File

@ -183,7 +183,7 @@ def test_schema_representation():
pets: [Reporter!]!
aChoice: TestsReporterAChoiceChoices
reporterType: TestsReporterReporterTypeChoices
articles(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null): ArticleConnection!
articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection!
}
\"""An enumeration.\"""
@ -244,8 +244,7 @@ def test_schema_representation():
\"""The ID of the object\"""
id: ID!
): Node
}
"""
}"""
)
assert str(schema) == expected
@ -525,8 +524,7 @@ class TestDjangoObjectType:
id: ID!
kind: String!
cuteness: Int!
}
"""
}"""
)
def test_django_objecttype_convert_choices_enum_list(self, PetModel):
@ -560,8 +558,7 @@ class TestDjangoObjectType:
\"""Dog\"""
DOG
}
"""
}"""
)
def test_django_objecttype_convert_choices_enum_empty_list(self, PetModel):
@ -586,8 +583,7 @@ class TestDjangoObjectType:
id: ID!
kind: String!
cuteness: Int!
}
"""
}"""
)
def test_django_objecttype_convert_choices_enum_naming_collisions(
@ -621,8 +617,7 @@ class TestDjangoObjectType:
\"""Dog\"""
DOG
}
"""
}"""
)
def test_django_objecttype_choices_custom_enum_name(
@ -660,8 +655,7 @@ class TestDjangoObjectType:
\"""Dog\"""
DOG
}
"""
}"""
)
@ -671,6 +665,7 @@ def test_django_objecttype_name_connection_propagation():
class Meta:
model = ReporterModel
name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"]
interfaces = (Node,)

View File

@ -83,6 +83,6 @@ def client_query(client):
def test_pytest_fixture_usage(client_query):
response = graphql_query("query { test }")
response = client_query("query { test }")
content = json.loads(response.content)
assert content == {"data": {"test": "Hello World"}}

View File

@ -109,12 +109,10 @@ def test_reports_validation_errors(client):
{
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 9}],
"path": None,
},
{
"message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 21}],
"path": None,
},
]
}
@ -135,8 +133,6 @@ def test_errors_when_missing_operation_name(client):
"errors": [
{
"message": "Must provide operation name if query contains multiple operations.",
"locations": None,
"path": None,
}
]
}
@ -477,7 +473,6 @@ def test_handles_syntax_errors_caught_by_graphql(client):
{
"locations": [{"column": 1, "line": 1}],
"message": "Syntax Error: Unexpected Name 'syntaxerror'.",
"path": None,
}
]
}

View File

@ -1,8 +1,8 @@
from django.conf.urls import url
from django.urls import path
from ..views import GraphQLView
urlpatterns = [
url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r"^graphql", GraphQLView.as_view(graphiql=True)),
path("graphql/batch", GraphQLView.as_view(batch=True)),
path("graphql", GraphQLView.as_view(graphiql=True)),
]

View File

@ -1,4 +1,4 @@
from django.conf.urls import url
from django.urls import path
from ..views import GraphQLView
from .schema_view import schema
@ -10,4 +10,4 @@ class CustomGraphQLView(GraphQLView):
pretty = True
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]
urlpatterns = [path("graphql/inherited/", CustomGraphQLView.as_view())]

View File

@ -1,6 +1,6 @@
from django.conf.urls import url
from django.urls import path
from ..views import GraphQLView
from .schema_view import schema
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]
urlpatterns = [path("graphql", GraphQLView.as_view(schema=schema, pretty=True))]

View File

@ -2,11 +2,8 @@ import warnings
from collections import OrderedDict
from typing import Type
from django.db.models import Model
from django.utils.functional import SimpleLazyObject
import graphene
from graphene import Field
from django.db.models import Model
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
@ -21,7 +18,6 @@ from .utils import (
is_valid_django_model,
)
ALL_FIELDS = "__all__"
@ -108,12 +104,7 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
(
'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.'
).format(
field_name=name,
app_label=model._meta.app_label,
object_name=model._meta.object_name,
type_=type_,
)
).format(field_name=name, type_=type_)
)
else:
if not hasattr(model, name):
@ -131,7 +122,7 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
class DjangoObjectTypeOptions(ObjectTypeOptions):
model = None # type: Model
model = None # type: Type[Model]
registry = None # type: Registry
connection = None # type: Type[Connection]
@ -225,14 +216,14 @@ class DjangoObjectType(ObjectType):
"Creating a DjangoObjectType without either the `fields` "
"or the `exclude` option is deprecated. Add an explicit `fields "
"= '__all__'` option on DjangoObjectType {class_name} to use all "
"fields".format(class_name=cls.__name__,),
"fields".format(class_name=cls.__name__),
DeprecationWarning,
stacklevel=2,
)
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
_as=Field,
_as=graphene.Field,
)
if use_connection is None and interfaces:

View File

@ -1,10 +1,10 @@
import json
import warnings
from django.test import Client, TestCase
from django.test import Client, TestCase, TransactionTestCase
from graphene_django.settings import graphene_settings
DEFAULT_GRAPHQL_URL = "/graphql"
def graphql_query(
query,
@ -19,7 +19,7 @@ def graphql_query(
Args:
query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"),
supply the operation_name. For annon queries ("{ ... }"),
should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``,
@ -28,7 +28,9 @@ def graphql_query(
variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value.
will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
client (django.test.Client) - Test client. Defaults to django.test.Client.
graphql_url (string) - URL to graphql endpoint. Defaults to "/graphql".
@ -61,7 +63,7 @@ def graphql_query(
return resp
class GraphQLTestCase(TestCase):
class GraphQLTestMixin(object):
"""
Based on: https://www.sam.today/blog/testing-graphql-with-graphene-django/
"""
@ -76,7 +78,7 @@ class GraphQLTestCase(TestCase):
Args:
query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"),
supply the operation_name. For annon queries ("{ ... }"),
should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``,
@ -85,7 +87,9 @@ class GraphQLTestCase(TestCase):
variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value.
will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
Returns:
Response object from client
@ -139,3 +143,11 @@ class GraphQLTestCase(TestCase):
"""
content = json.loads(resp.content)
self.assertIn("errors", list(content.keys()), msg or content)
class GraphQLTestCase(GraphQLTestMixin, TestCase):
pass
class GraphQLTransactionTestCase(GraphQLTestMixin, TransactionTestCase):
pass

View File

@ -6,4 +6,4 @@ def test_to_const():
def test_to_const_unicode():
assert to_const(u"Skoða þetta unicode stöff") == "SKODA_THETTA_UNICODE_STOFF"
assert to_const("Skoða þetta unicode stöff") == "SKODA_THETTA_UNICODE_STOFF"

View File

@ -11,7 +11,6 @@ from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View
from graphql import OperationType, get_operation_ast, parse, validate
from graphql.error import GraphQLError
from graphql.error import format_error as format_graphql_error
from graphql.execution import ExecutionResult
from graphene import Schema
@ -58,23 +57,23 @@ class GraphQLView(View):
graphiql_template = "graphene/graphiql.html"
# Polyfill for window.fetch.
whatwg_fetch_version = "3.2.0"
whatwg_fetch_sri = "sha256-l6HCB9TT2v89oWbDdo2Z3j+PSVypKNLA/nqfzSbM8mo="
whatwg_fetch_version = "3.6.2"
whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c="
# React and ReactDOM.
react_version = "16.13.1"
react_sri = "sha256-yUhvEmYVhZ/GGshIQKArLvySDSh6cdmdcIx0spR3UP4="
react_dom_sri = "sha256-vFt3l+illeNlwThbDUdoPTqF81M8WNSZZZt3HEjsbSU="
react_version = "17.0.2"
react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8="
react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="
# The GraphiQL React app.
graphiql_version = "1.0.3"
graphiql_sri = "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
graphiql_css_sri = "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="
graphiql_version = "1.4.1" # "1.0.3"
graphiql_sri = "sha256-JUMkXBQWZMfJ7fGEsTXalxVA10lzKOS9loXdLjwZKi4=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
graphiql_css_sri = "sha256-Md3vdR7PDzWyo/aGfsFVF4tvS5/eAUWuIsg9QHUusCY=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="
# The websocket transport library for subscriptions.
subscriptions_transport_ws_version = "0.9.17"
subscriptions_transport_ws_version = "0.9.18"
subscriptions_transport_ws_sri = (
"sha256-kCDzver8iRaIQ/SVlfrIwxaBQ/avXf9GQFJRLlErBnk="
"sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw="
)
schema = None
@ -84,6 +83,7 @@ class GraphQLView(View):
pretty = False
batch = False
subscription_path = None
execution_context_class = None
def __init__(
self,
@ -386,7 +386,7 @@ class GraphQLView(View):
@staticmethod
def format_error(error):
if isinstance(error, GraphQLError):
return format_graphql_error(error)
return error.formatted
return {"message": str(error)}

View File

@ -14,22 +14,22 @@ rest_framework_require = ["djangorestframework>=3.6.3"]
tests_require = [
"pytest>=3.6.3",
"pytest>=7.1.3",
"pytest-cov",
"pytest-random-order",
"coveralls",
"mock",
"pytz",
"django-filter>=2",
"pytest-django>=3.3.2",
"django-filter>=22.1",
"pytest-django>=4.5.2",
] + rest_framework_require
dev_requires = [
"black==19.10b0",
"flake8==3.7.9",
"flake8-black==0.1.1",
"flake8-bugbear==20.1.4",
"black==22.8.0",
"flake8==5.0.4",
"flake8-black==0.3.3",
"flake8-bugbear==22.9.11",
] + tests_require
setup(
@ -46,21 +46,23 @@ setup(
"Intended Audience :: Developers",
"Topic :: Software Development :: Libraries",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Django",
"Framework :: Django :: 1.11",
"Framework :: Django :: 2.2",
"Framework :: Django :: 3.0",
"Framework :: Django :: 3.2",
"Framework :: Django :: 4.0",
"Framework :: Django :: 4.1",
],
keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests", "examples", "examples.*"]),
install_requires=[
"graphene>=3.0.0b5,<4",
"graphene>=3.0,<4",
"graphql-core>=3.1.0,<4",
"Django>=2.2",
"graphql-relay>=3.1.1,<4",
"Django>=3.2",
"promise>=2.1",
"text-unidecode",
],

30
tox.ini
View File

@ -1,20 +1,22 @@
[tox]
envlist =
py{36,37,38}-django{22,30,31,master},
py{37,38,39,310}-django32,
py{38,39,310}-django{40,41,main},
black,flake8
[gh-actions]
python =
3.6: py36
3.7: py37
3.8: py38
3.9: py39
3.10: py310
[gh-actions:env]
DJANGO =
2.2: django22
3.0: django30
3.1: django31
master: djangomaster
3.2: django32
4.0: django40
4.1: django41
main: djangomain
[testenv]
passenv = *
@ -24,24 +26,20 @@ setenv =
deps =
-e.[test]
psycopg2-binary
django111: Django>=1.11,<2.0
django111: djangorestframework<3.12
django20: Django>=2.0,<2.1
django21: Django>=2.1,<2.2
django22: Django>=2.2,<3.0
django30: Django>=3.0a1,<3.1
django31: Django>=3.1,<3.2
djangomaster: https://github.com/django/django/archive/master.zip
django32: Django>=3.2,<4.0
django40: Django>=4.0,<4.1
django41: Django>=4.1,<4.2
djangomain: https://github.com/django/django/archive/main.zip
commands = {posargs:py.test --cov=graphene_django graphene_django examples}
[testenv:black]
basepython = python3.8
basepython = python3.9
deps = -e.[dev]
commands =
black --exclude "/migrations/" graphene_django examples setup.py --check
[testenv:flake8]
basepython = python3.8
basepython = python3.9
deps = -e.[dev]
commands =
flake8 graphene_django examples setup.py