Merge branch 'main' into feature/user-error-middle-ware

This commit is contained in:
Firas K 2023-03-01 11:24:15 +03:00 committed by GitHub
commit 1f46804e81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
109 changed files with 3193 additions and 718 deletions

View File

@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v2
- name: Set up Python 3.8 - name: Set up Python 3.9
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: 3.8 python-version: 3.9
- name: Build wheel and source tarball - name: Build wheel and source tarball
run: | run: |
pip install wheel pip install wheel

View File

@ -7,16 +7,16 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v2
- name: Set up Python 3.8 - name: Set up Python 3.9
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: 3.8 python-version: 3.9
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install tox pip install tox
- name: Run lint 💅 - name: Run pre-commit 💅
run: tox run: tox
env: env:
TOXENV: flake8 TOXENV: pre-commit

View File

@ -8,13 +8,15 @@ jobs:
strategy: strategy:
max-parallel: 4 max-parallel: 4
matrix: matrix:
django: ["2.2", "3.0", "3.1"] django: ["3.2", "4.0", "4.1"]
python-version: ["3.6", "3.7", "3.8"] python-version: ["3.8", "3.9", "3.10"]
include:
- django: "3.2"
python-version: "3.7"
steps: steps:
- uses: actions/checkout@v1 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies

30
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,30 @@
default_language_version:
python: python3.9
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
- id: check-merge-conflict
- id: check-json
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
exclude: ^docs/.*$
- id: pretty-format-json
args:
- --autofix
- id: trailing-whitespace
exclude: README.md
- repo: https://github.com/asottile/pyupgrade
rev: v3.2.0
hooks:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
hooks:
- id: flake8

View File

@ -1,22 +1,21 @@
.PHONY: help
help:
@echo "Please use \`make <target>' where <target> is one of"
@grep -E '^\.PHONY: [a-zA-Z_-]+ .*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = "(: |##)"}; {printf "\033[36m%-30s\033[0m %s\n", $$2, $$3}'
.PHONY: dev-setup ## Install development dependencies .PHONY: dev-setup ## Install development dependencies
dev-setup: dev-setup:
pip install -e ".[dev]" pip install -e ".[dev]"
.PHONY: install-dev .PHONY: tests ## Run unit tests
install-dev: dev-setup # Alias install-dev -> dev-setup
.PHONY: tests
tests: tests:
py.test graphene_django --cov=graphene_django -vv py.test graphene_django --cov=graphene_django -vv
.PHONY: test .PHONY: format ## Format code
test: tests # Alias test -> tests
.PHONY: format
format: format:
black --exclude "/migrations/" graphene_django examples setup.py black graphene_django examples setup.py
.PHONY: lint .PHONY: lint ## Lint code
lint: lint:
flake8 graphene_django examples flake8 graphene_django examples

View File

@ -55,7 +55,7 @@ from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
# ... # ...
path('graphql', GraphQLView.as_view(graphiql=True)), path('graphql/', GraphQLView.as_view(graphiql=True)),
] ]
``` ```

View File

@ -198,7 +198,7 @@ For Django 2.2 and above:
urlpatterns = [ urlpatterns = [
# some other urls # some other urls
path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), path('graphql/', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
] ]
.. _LoginRequiredMixin: https://docs.djangoproject.com/en/dev/topics/auth/default/#the-loginrequired-mixin .. _LoginRequiredMixin: https://docs.djangoproject.com/en/dev/topics/auth/default/#the-loginrequired-mixin

View File

@ -4,7 +4,7 @@ Django Debug Middleware
You can debug your GraphQL queries in a similar way to You can debug your GraphQL queries in a similar way to
`django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__, `django-debug-toolbar <https://django-debug-toolbar.readthedocs.org/>`__,
but outputting in the results in GraphQL response as fields, instead of but outputting in the results in GraphQL response as fields, instead of
the graphical HTML interface. the graphical HTML interface. Exceptions with their stack traces are also exposed.
For that, you will need to add the plugin in your graphene schema. For that, you will need to add the plugin in your graphene schema.
@ -63,6 +63,10 @@ the GraphQL request, like:
sql { sql {
rawSql rawSql
} }
exceptions {
message
stack
}
} }
} }

View File

@ -2,8 +2,8 @@ Filtering
========= =========
Graphene integrates with Graphene integrates with
`django-filter <https://django-filter.readthedocs.io/en/master/>`__ to provide filtering of results. `django-filter <https://django-filter.readthedocs.io/en/main/>`__ to provide filtering of results.
See the `usage documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__ See the `usage documentation <https://django-filter.readthedocs.io/en/main/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``. for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``. This filtering is automatically available when implementing a ``relay.Node``.
@ -26,7 +26,7 @@ After installing ``django-filter`` you'll need to add the application in the ``s
] ]
Note: The techniques below are demoed in the `cookbook example Note: The techniques below are demoed in the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__. app <https://github.com/graphql-python/graphene-django/tree/main/examples/cookbook>`__.
Filterable fields Filterable fields
----------------- -----------------
@ -34,7 +34,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering ``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__ documentation <https://django-filter.readthedocs.io/en/main/guide/usage.html#the-filter>`__
for full details on the range of options available. for full details on the range of options available.
For example: For example:
@ -192,7 +192,7 @@ in unison with the ``filter_fields`` parameter:
all_animals = DjangoFilterConnectionField(AnimalNode) all_animals = DjangoFilterConnectionField(AnimalNode)
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/master/guide/usage.html#request-based-filtering>`__ The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/main/guide/usage.html#request-based-filtering>`__
in a ``django_filters.FilterSet`` instance. You can use this to customize your in a ``django_filters.FilterSet`` instance. You can use this to customize your
filters to be context-dependent. We could modify the ``AnimalFilter`` above to filters to be context-dependent. We could modify the ``AnimalFilter`` above to
pre-filter animals owned by the authenticated user (set in ``context.user``). pre-filter animals owned by the authenticated user (set in ``context.user``).
@ -258,3 +258,86 @@ with this set up, you can now order the users under group:
} }
} }
} }
PostgreSQL `ArrayField`
-----------------------
Graphene provides an easy to implement filters on `ArrayField` as they are not natively supported by django_filters:
.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
from graphene_django.filter import ArrayFilter
class Event(models.Model):
name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50))
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet
with this set up, you can now filter events by tags:
.. code::
query {
events(tags_Overlap: ["concert", "festival"]) {
name
}
}
`TypedFilter`
-------------
Sometimes the automatic detection of the filter input type is not satisfactory for what you are trying to achieve.
You can then explicitly specify the input type you want for your filter by using a `TypedFilter`:
.. code:: python
from django.db import models
from django_filters import FilterSet, OrderingFilter
import graphene
from graphene_django.filter import TypedFilter
class Event(models.Model):
name = models.CharField(max_length=50)
class EventFilterSet(FilterSet):
class Meta:
model = Event
fields = {
"name": ["exact", "contains"],
}
only_first = TypedFilter(input_type=graphene.Boolean, method="only_first_filter")
def only_first_filter(self, queryset, _name, value):
if value:
return queryset[:1]
else:
return queryset
class EventType(DjangoObjectType):
class Meta:
model = Event
interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet

View File

@ -151,7 +151,7 @@ For example the following ``Model`` and ``DjangoObjectType``:
Results in the following GraphQL schema definition: Results in the following GraphQL schema definition:
.. code:: .. code:: graphql
type Pet { type Pet {
id: ID! id: ID!
@ -178,7 +178,7 @@ You can disable this automatic conversion by setting
fields = ("id", "kind",) fields = ("id", "kind",)
convert_choices_to_enum = False convert_choices_to_enum = False
.. code:: .. code:: graphql
type Pet { type Pet {
id: ID! id: ID!
@ -313,7 +313,7 @@ Additionally, Resolvers will receive **any arguments declared in the field defin
bar=graphene.Int() bar=graphene.Int()
) )
def resolve_question(root, info, foo, bar): def resolve_question(root, info, foo=None, bar=None):
# If `foo` or `bar` are declared in the GraphQL query they will be here, else None. # If `foo` or `bar` are declared in the GraphQL query they will be here, else None.
return Question.objects.filter(foo=foo, bar=bar).first() return Question.objects.filter(foo=foo, bar=bar).first()
@ -418,7 +418,7 @@ the core graphene pages for more information on customizing the Relay experience
You can now execute queries like: You can now execute queries like:
.. code:: python .. code:: graphql
{ {
questions (first: 2, after: "YXJyYXljb25uZWN0aW9uOjEwNQ==") { questions (first: 2, after: "YXJyYXljb25uZWN0aW9uOjEwNQ==") {
@ -440,7 +440,7 @@ You can now execute queries like:
Which returns: Which returns:
.. code:: python .. code:: json
{ {
"data": { "data": {

View File

@ -8,25 +8,22 @@
class CategoryType(DjangoObjectType): class CategoryType(DjangoObjectType):
class Meta: class Meta:
model = Category model = Category
fields = '__all__' fields = "__all__"
class IngredientType(DjangoObjectType): class IngredientType(DjangoObjectType):
class Meta: class Meta:
model = Ingredient model = Ingredient
fields = '__all__' fields = "__all__"
class Query(object): class Query:
category = graphene.Field(CategoryType, category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String())
id=graphene.Int(),
name=graphene.String())
all_categories = graphene.List(CategoryType) all_categories = graphene.List(CategoryType)
ingredient = graphene.Field(
ingredient = graphene.Field(IngredientType, IngredientType, id=graphene.Int(), name=graphene.String()
id=graphene.Int(), )
name=graphene.String())
all_ingredients = graphene.List(IngredientType) all_ingredients = graphene.List(IngredientType)
def resolve_all_categories(self, info, **kwargs): def resolve_all_categories(self, info, **kwargs):
@ -36,8 +33,8 @@
return Ingredient.objects.all() return Ingredient.objects.all()
def resolve_category(self, info, **kwargs): def resolve_category(self, info, **kwargs):
id = kwargs.get('id') id = kwargs.get("id")
name = kwargs.get('name') name = kwargs.get("name")
if id is not None: if id is not None:
return Category.objects.get(pk=id) return Category.objects.get(pk=id)
@ -48,8 +45,8 @@
return None return None
def resolve_ingredient(self, info, **kwargs): def resolve_ingredient(self, info, **kwargs):
id = kwargs.get('id') id = kwargs.get("id")
name = kwargs.get('name') name = kwargs.get("name")
if id is not None: if id is not None:
return Ingredient.objects.get(pk=id) return Ingredient.objects.get(pk=id)

View File

@ -189,7 +189,7 @@ Default: ``None``
``GRAPHIQL_HEADER_EDITOR_ENABLED`` ``GRAPHIQL_HEADER_EDITOR_ENABLED``
--------------------- ----------------------------------
GraphiQL starting from version 1.0.0 allows setting custom headers in similar fashion to query variables. GraphiQL starting from version 1.0.0 allows setting custom headers in similar fashion to query variables.
@ -207,3 +207,36 @@ Default: ``True``
GRAPHENE = { GRAPHENE = {
'GRAPHIQL_HEADER_EDITOR_ENABLED': True, 'GRAPHIQL_HEADER_EDITOR_ENABLED': True,
} }
``TESTING_ENDPOINT``
--------------------
Define the graphql endpoint url used for the `GraphQLTestCase` class.
Default: ``/graphql``
.. code:: python
GRAPHENE = {
'TESTING_ENDPOINT': '/customEndpoint'
}
``GRAPHIQL_SHOULD_PERSIST_HEADERS``
---------------------
Set to ``True`` if you want to persist GraphiQL headers after refreshing the page.
This setting is passed to ``shouldPersistHeaders`` GraphiQL options, for details refer to GraphiQLDocs_.
.. _GraphiQLDocs: https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
Default: ``False``
.. code:: python
GRAPHENE = {
'GRAPHIQL_SHOULD_PERSIST_HEADERS': False,
}

View File

@ -6,7 +6,8 @@ Using unittest
If you want to unittest your API calls derive your test case from the class `GraphQLTestCase`. If you want to unittest your API calls derive your test case from the class `GraphQLTestCase`.
Your endpoint is set through the `GRAPHQL_URL` attribute on `GraphQLTestCase`. The default endpoint is `GRAPHQL_URL = "/graphql/"`. The default endpoint for testing is `/graphql`. You can override this in the `settings <https://docs.graphene-python.org/projects/django/en/latest/settings/#testing-endpoint>`__.
Usage: Usage:
@ -27,7 +28,7 @@ Usage:
} }
} }
''', ''',
op_name='myModel' operation_name='myModel'
) )
content = json.loads(response.content) content = json.loads(response.content)
@ -48,7 +49,7 @@ Usage:
} }
} }
''', ''',
op_name='myModel', operation_name='myModel',
variables={'id': 1} variables={'id': 1}
) )
@ -72,7 +73,42 @@ Usage:
} }
} }
''', ''',
op_name='myMutation', operation_name='myMutation',
input_data={'my_field': 'foo', 'other_field': 'bar'}
)
# This validates the status code and if you get errors
self.assertResponseNoErrors(response)
# Add some more asserts if you like
...
For testing mutations that are executed within a transaction you should subclass `GraphQLTransactionTestCase`
Usage:
.. code:: python
import json
from graphene_django.utils.testing import GraphQLTransactionTestCase
class MyFancyTransactionTestCase(GraphQLTransactionTestCase):
def test_some_mutation_that_executes_within_a_transaction(self):
response = self.query(
'''
mutation myMutation($input: MyMutationInput!) {
myMutation(input: $input) {
my-model {
id
name
}
}
}
''',
operation_name='myMutation',
input_data={'my_field': 'foo', 'other_field': 'bar'} input_data={'my_field': 'foo', 'other_field': 'bar'}
) )
@ -112,7 +148,7 @@ To use pytest define a simple fixture using the query helper below
} }
} }
''', ''',
op_name='myModel' operation_name='myModel'
) )
content = json.loads(response.content) content = json.loads(response.content)

View File

@ -35,6 +35,7 @@ Now sync your database for the first time:
.. code:: bash .. code:: bash
cd ..
python manage.py migrate python manage.py migrate
Let's create a few simple models... Let's create a few simple models...
@ -77,6 +78,18 @@ Add ingredients as INSTALLED_APPS:
"cookbook.ingredients", "cookbook.ingredients",
] ]
Make sure the app name in ``cookbook.ingredients.apps.IngredientsConfig`` is set to ``cookbook.ingredients``.
.. code:: python
# cookbook/ingredients/apps.py
from django.apps import AppConfig
class IngredientsConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'cookbook.ingredients'
Don't forget to create & run migrations: Don't forget to create & run migrations:

View File

@ -70,7 +70,7 @@ Let's get started with these models:
class Ingredient(models.Model): class Ingredient(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
notes = models.TextField() notes = models.TextField()
category = models.ForeignKey(Category, related_name='ingredients') category = models.ForeignKey(Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self): def __str__(self):
return self.name return self.name
@ -151,7 +151,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, ) interfaces = (relay.Node, )
class Query(graphene.ObjectType): class Query(ObjectType):
category = relay.Node.Field(CategoryNode) category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)
@ -281,7 +281,7 @@ from the command line.
$ python ./manage.py runserver $ python ./manage.py runserver
Performing system checks... Performing system checks...
Django version 1.11, using settings 'cookbook.settings' Django version 3.1.7, using settings 'cookbook.settings'
Starting development server at http://127.0.0.1:8000/ Starting development server at http://127.0.0.1:8000/
Quit the server with CONTROL-C. Quit the server with CONTROL-C.

View File

@ -14,7 +14,7 @@ whole Graphene repository:
```bash ```bash
# Get the example project code # Get the example project code
git clone https://github.com/graphql-python/graphene-django.git git clone https://github.com/graphql-python/graphene-django.git
cd graphene-django/examples/cookbook cd graphene-django/examples/cookbook-plain
``` ```
It is good idea (but not required) to create a virtual environment It is good idea (but not required) to create a virtual environment

View File

@ -1 +1,52 @@
[{"model": "ingredients.category", "pk": 1, "fields": {"name": "Dairy"}}, {"model": "ingredients.category", "pk": 2, "fields": {"name": "Meat"}}, {"model": "ingredients.ingredient", "pk": 1, "fields": {"name": "Eggs", "notes": "Good old eggs", "category": 1}}, {"model": "ingredients.ingredient", "pk": 2, "fields": {"name": "Milk", "notes": "Comes from a cow", "category": 1}}, {"model": "ingredients.ingredient", "pk": 3, "fields": {"name": "Beef", "notes": "Much like milk, this comes from a cow", "category": 2}}, {"model": "ingredients.ingredient", "pk": 4, "fields": {"name": "Chicken", "notes": "Definitely doesn't come from a cow", "category": 2}}] [
{
"fields": {
"name": "Dairy"
},
"model": "ingredients.category",
"pk": 1
},
{
"fields": {
"name": "Meat"
},
"model": "ingredients.category",
"pk": 2
},
{
"fields": {
"category": 1,
"name": "Eggs",
"notes": "Good old eggs"
},
"model": "ingredients.ingredient",
"pk": 1
},
{
"fields": {
"category": 1,
"name": "Milk",
"notes": "Comes from a cow"
},
"model": "ingredients.ingredient",
"pk": 2
},
{
"fields": {
"category": 2,
"name": "Beef",
"notes": "Much like milk, this comes from a cow"
},
"model": "ingredients.ingredient",
"pk": 3
},
{
"fields": {
"category": 2,
"name": "Chicken",
"notes": "Definitely doesn't come from a cow"
},
"model": "ingredients.ingredient",
"pk": 4
}
]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2015-12-04 18:15 # Generated by Django 1.9 on 2015-12-04 18:15
from __future__ import unicode_literals
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
@ -10,24 +8,46 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Category', name="Category",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=100)), "id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='Ingredient', name="Ingredient",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=100)), "id",
('notes', models.TextField()), models.AutoField(
('category', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='ingredients', to='ingredients.Category')), auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
("notes", models.TextField()),
(
"category",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="ingredients",
to="ingredients.Category",
),
),
], ],
), ),
] ]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2016-11-04 00:50 # Generated by Django 1.9 on 2016-11-04 00:50
from __future__ import unicode_literals
from django.db import migrations, models from django.db import migrations, models
@ -8,13 +6,13 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('ingredients', '0001_initial'), ("ingredients", "0001_initial"),
] ]
operations = [ operations = [
migrations.AlterField( migrations.AlterField(
model_name='ingredient', model_name="ingredient",
name='notes', name="notes",
field=models.TextField(blank=True, null=True), field=models.TextField(blank=True, null=True),
), ),
] ]

View File

@ -6,12 +6,12 @@ from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('ingredients', '0002_auto_20161104_0050'), ("ingredients", "0002_auto_20161104_0050"),
] ]
operations = [ operations = [
migrations.AlterModelOptions( migrations.AlterModelOptions(
name='category', name="category",
options={'verbose_name_plural': 'Categories'}, options={"verbose_name_plural": "Categories"},
), ),
] ]

View File

@ -16,7 +16,7 @@ class IngredientType(DjangoObjectType):
fields = "__all__" fields = "__all__"
class Query(object): class Query:
category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String()) category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String())
all_categories = graphene.List(CategoryType) all_categories = graphene.List(CategoryType)

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2015-12-04 18:20 # Generated by Django 1.9 on 2015-12-04 18:20
from __future__ import unicode_literals
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
@ -11,26 +9,62 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
('ingredients', '0001_initial'), ("ingredients", "0001_initial"),
] ]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Recipe', name="Recipe",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('title', models.CharField(max_length=100)), "id",
('instructions', models.TextField()), models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("title", models.CharField(max_length=100)),
("instructions", models.TextField()),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='RecipeIngredient', name="RecipeIngredient",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('amount', models.FloatField()), "id",
('unit', models.CharField(choices=[('kg', 'Kilograms'), ('l', 'Litres'), ('', 'Units')], max_length=20)), models.AutoField(
('ingredient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='used_by', to='ingredients.Ingredient')), auto_created=True,
('recipes', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='amounts', to='recipes.Recipe')), primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("amount", models.FloatField()),
(
"unit",
models.CharField(
choices=[("kg", "Kilograms"), ("l", "Litres"), ("", "Units")],
max_length=20,
),
),
(
"ingredient",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="used_by",
to="ingredients.Ingredient",
),
),
(
"recipes",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="amounts",
to="recipes.Recipe",
),
),
], ],
), ),
] ]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2016-11-04 01:06 # Generated by Django 1.9 on 2016-11-04 01:06
from __future__ import unicode_literals
from django.db import migrations, models from django.db import migrations, models
@ -8,18 +6,26 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('recipes', '0001_initial'), ("recipes", "0001_initial"),
] ]
operations = [ operations = [
migrations.RenameField( migrations.RenameField(
model_name='recipeingredient', model_name="recipeingredient",
old_name='recipes', old_name="recipes",
new_name='recipe', new_name="recipe",
), ),
migrations.AlterField( migrations.AlterField(
model_name='recipeingredient', model_name="recipeingredient",
name='unit', name="unit",
field=models.CharField(choices=[(b'unit', b'Units'), (b'kg', b'Kilograms'), (b'l', b'Litres'), (b'st', b'Shots')], max_length=20), field=models.CharField(
choices=[
(b"unit", b"Units"),
(b"kg", b"Kilograms"),
(b"l", b"Litres"),
(b"st", b"Shots"),
],
max_length=20,
),
), ),
] ]

View File

@ -6,13 +6,21 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('recipes', '0002_auto_20161104_0106'), ("recipes", "0002_auto_20161104_0106"),
] ]
operations = [ operations = [
migrations.AlterField( migrations.AlterField(
model_name='recipeingredient', model_name="recipeingredient",
name='unit', name="unit",
field=models.CharField(choices=[('unit', 'Units'), ('kg', 'Kilograms'), ('l', 'Litres'), ('st', 'Shots')], max_length=20), field=models.CharField(
choices=[
("unit", "Units"),
("kg", "Kilograms"),
("l", "Litres"),
("st", "Shots"),
],
max_length=20,
),
), ),
] ]

View File

@ -16,7 +16,7 @@ class RecipeIngredientType(DjangoObjectType):
fields = "__all__" fields = "__all__"
class Query(object): class Query:
recipe = graphene.Field(RecipeType, id=graphene.Int(), title=graphene.String()) recipe = graphene.Field(RecipeType, id=graphene.Int(), title=graphene.String())
all_recipes = graphene.List(RecipeType) all_recipes = graphene.List(RecipeType)

View File

@ -1,4 +1,4 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==3.0.7 django==3.1.14

View File

@ -1,4 +1,4 @@
Cookbook Example Django Project Cookbook Example (Relay) Django Project
=============================== ===============================
This example project demos integration between Graphene and Django. This example project demos integration between Graphene and Django.
@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql) [http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries! and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema) (See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-relay/#testing-our-graphql-schema)
for some example queries) for some example queries)

View File

@ -1 +1,52 @@
[{"model": "ingredients.category", "pk": 1, "fields": {"name": "Dairy"}}, {"model": "ingredients.category", "pk": 2, "fields": {"name": "Meat"}}, {"model": "ingredients.ingredient", "pk": 1, "fields": {"name": "Eggs", "notes": "Good old eggs", "category": 1}}, {"model": "ingredients.ingredient", "pk": 2, "fields": {"name": "Milk", "notes": "Comes from a cow", "category": 1}}, {"model": "ingredients.ingredient", "pk": 3, "fields": {"name": "Beef", "notes": "Much like milk, this comes from a cow", "category": 2}}, {"model": "ingredients.ingredient", "pk": 4, "fields": {"name": "Chicken", "notes": "Definitely doesn't come from a cow", "category": 2}}] [
{
"fields": {
"name": "Dairy"
},
"model": "ingredients.category",
"pk": 1
},
{
"fields": {
"name": "Meat"
},
"model": "ingredients.category",
"pk": 2
},
{
"fields": {
"category": 1,
"name": "Eggs",
"notes": "Good old eggs"
},
"model": "ingredients.ingredient",
"pk": 1
},
{
"fields": {
"category": 1,
"name": "Milk",
"notes": "Comes from a cow"
},
"model": "ingredients.ingredient",
"pk": 2
},
{
"fields": {
"category": 2,
"name": "Beef",
"notes": "Much like milk, this comes from a cow"
},
"model": "ingredients.ingredient",
"pk": 3
},
{
"fields": {
"category": 2,
"name": "Chicken",
"notes": "Definitely doesn't come from a cow"
},
"model": "ingredients.ingredient",
"pk": 4
}
]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2015-12-04 18:15 # Generated by Django 1.9 on 2015-12-04 18:15
from __future__ import unicode_literals
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
@ -10,24 +8,46 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = []
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Category', name="Category",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=100)), "id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='Ingredient', name="Ingredient",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('name', models.CharField(max_length=100)), "id",
('notes', models.TextField()), models.AutoField(
('category', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='ingredients', to='ingredients.Category')), auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("name", models.CharField(max_length=100)),
("notes", models.TextField()),
(
"category",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="ingredients",
to="ingredients.Category",
),
),
], ],
), ),
] ]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2016-11-04 00:50 # Generated by Django 1.9 on 2016-11-04 00:50
from __future__ import unicode_literals
from django.db import migrations, models from django.db import migrations, models
@ -8,13 +6,13 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('ingredients', '0001_initial'), ("ingredients", "0001_initial"),
] ]
operations = [ operations = [
migrations.AlterField( migrations.AlterField(
model_name='ingredient', model_name="ingredient",
name='notes', name="notes",
field=models.TextField(blank=True, null=True), field=models.TextField(blank=True, null=True),
), ),
] ]

View File

@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType):
} }
class Query(object): class Query:
category = Node.Field(CategoryNode) category = Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2015-12-04 18:20 # Generated by Django 1.9 on 2015-12-04 18:20
from __future__ import unicode_literals
import django.db.models.deletion import django.db.models.deletion
from django.db import migrations, models from django.db import migrations, models
@ -11,26 +9,62 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
('ingredients', '0001_initial'), ("ingredients", "0001_initial"),
] ]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Recipe', name="Recipe",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('title', models.CharField(max_length=100)), "id",
('instructions', models.TextField()), models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("title", models.CharField(max_length=100)),
("instructions", models.TextField()),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='RecipeIngredient', name="RecipeIngredient",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('amount', models.FloatField()), "id",
('unit', models.CharField(choices=[('kg', 'Kilograms'), ('l', 'Litres'), ('', 'Units')], max_length=20)), models.AutoField(
('ingredient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='used_by', to='ingredients.Ingredient')), auto_created=True,
('recipes', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='amounts', to='recipes.Recipe')), primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("amount", models.FloatField()),
(
"unit",
models.CharField(
choices=[("kg", "Kilograms"), ("l", "Litres"), ("", "Units")],
max_length=20,
),
),
(
"ingredient",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="used_by",
to="ingredients.Ingredient",
),
),
(
"recipes",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="amounts",
to="recipes.Recipe",
),
),
], ],
), ),
] ]

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9 on 2016-11-04 01:06 # Generated by Django 1.9 on 2016-11-04 01:06
from __future__ import unicode_literals
from django.db import migrations, models from django.db import migrations, models
@ -8,18 +6,26 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('recipes', '0001_initial'), ("recipes", "0001_initial"),
] ]
operations = [ operations = [
migrations.RenameField( migrations.RenameField(
model_name='recipeingredient', model_name="recipeingredient",
old_name='recipes', old_name="recipes",
new_name='recipe', new_name="recipe",
), ),
migrations.AlterField( migrations.AlterField(
model_name='recipeingredient', model_name="recipeingredient",
name='unit', name="unit",
field=models.CharField(choices=[(b'unit', b'Units'), (b'kg', b'Kilograms'), (b'l', b'Litres'), (b'st', b'Shots')], max_length=20), field=models.CharField(
choices=[
(b"unit", b"Units"),
(b"kg", b"Kilograms"),
(b"l", b"Litres"),
(b"st", b"Shots"),
],
max_length=20,
),
), ),
] ]

View File

@ -25,7 +25,7 @@ class RecipeIngredientNode(DjangoObjectType):
} }
class Query(object): class Query:
recipe = Node.Field(RecipeNode) recipe = Node.Field(RecipeNode)
all_recipes = DjangoFilterConnectionField(RecipeNode) all_recipes = DjangoFilterConnectionField(RecipeNode)

View File

@ -1 +1,302 @@
[{"model": "auth.user", "pk": 1, "fields": {"password": "pbkdf2_sha256$24000$0SgBlSlnbv5c$ijVQipm2aNDlcrTL8Qi3SVNHphTm4HIsDfUi4kn9tog=", "last_login": "2016-11-04T00:46:58Z", "is_superuser": true, "username": "admin", "first_name": "", "last_name": "", "email": "asdf@example.com", "is_staff": true, "is_active": true, "date_joined": "2016-11-03T18:24:40Z", "groups": [], "user_permissions": []}}, {"model": "recipes.recipe", "pk": 1, "fields": {"title": "Cheerios With a Shot of Vermouth", "instructions": "https://xkcd.com/720/"}}, {"model": "recipes.recipe", "pk": 2, "fields": {"title": "Quail Eggs in Whipped Cream and MSG", "instructions": "https://xkcd.com/720/"}}, {"model": "recipes.recipe", "pk": 3, "fields": {"title": "Deep Fried Skittles", "instructions": "https://xkcd.com/720/"}}, {"model": "recipes.recipe", "pk": 4, "fields": {"title": "Newt ala Doritos", "instructions": "https://xkcd.com/720/"}}, {"model": "recipes.recipe", "pk": 5, "fields": {"title": "Fruit Salad", "instructions": "Chop up and add together"}}, {"model": "recipes.recipeingredient", "pk": 1, "fields": {"recipes": 5, "ingredient": 9, "amount": 1.0, "unit": "unit"}}, {"model": "recipes.recipeingredient", "pk": 2, "fields": {"recipes": 5, "ingredient": 10, "amount": 2.0, "unit": "unit"}}, {"model": "recipes.recipeingredient", "pk": 3, "fields": {"recipes": 5, "ingredient": 7, "amount": 3.0, "unit": "unit"}}, {"model": "recipes.recipeingredient", "pk": 4, "fields": {"recipes": 5, "ingredient": 8, "amount": 4.0, "unit": "unit"}}, {"model": "recipes.recipeingredient", "pk": 5, "fields": {"recipes": 4, "ingredient": 5, "amount": 1.0, "unit": "kg"}}, {"model": "recipes.recipeingredient", "pk": 6, "fields": {"recipes": 4, "ingredient": 6, "amount": 2.0, "unit": "l"}}, {"model": "recipes.recipeingredient", "pk": 7, "fields": {"recipes": 3, "ingredient": 4, "amount": 1.0, "unit": "unit"}}, {"model": "recipes.recipeingredient", "pk": 8, "fields": {"recipes": 2, "ingredient": 2, "amount": 1.0, "unit": "kg"}}, {"model": "recipes.recipeingredient", "pk": 9, "fields": {"recipes": 2, "ingredient": 11, "amount": 2.0, "unit": "l"}}, {"model": "recipes.recipeingredient", "pk": 10, "fields": {"recipes": 2, "ingredient": 12, "amount": 3.0, "unit": "st"}}, {"model": "recipes.recipeingredient", "pk": 11, "fields": {"recipes": 1, "ingredient": 1, "amount": 1.0, "unit": "kg"}}, {"model": "recipes.recipeingredient", "pk": 12, "fields": {"recipes": 1, "ingredient": 3, "amount": 1.0, "unit": "st"}}, {"model": "ingredients.category", "pk": 1, "fields": {"name": "fruit"}}, {"model": "ingredients.category", "pk": 3, "fields": {"name": "xkcd"}}, {"model": "ingredients.ingredient", "pk": 1, "fields": {"name": "Cheerios", "notes": "this is a note", "category": 3}}, {"model": "ingredients.ingredient", "pk": 2, "fields": {"name": "Quail Eggs", "notes": "has more notes", "category": 3}}, {"model": "ingredients.ingredient", "pk": 3, "fields": {"name": "Vermouth", "notes": "", "category": 3}}, {"model": "ingredients.ingredient", "pk": 4, "fields": {"name": "Skittles", "notes": "", "category": 3}}, {"model": "ingredients.ingredient", "pk": 5, "fields": {"name": "Newt", "notes": "Braised and Confuesd", "category": 3}}, {"model": "ingredients.ingredient", "pk": 6, "fields": {"name": "Doritos", "notes": "Crushed", "category": 3}}, {"model": "ingredients.ingredient", "pk": 7, "fields": {"name": "Apple", "notes": "", "category": 1}}, {"model": "ingredients.ingredient", "pk": 8, "fields": {"name": "Orange", "notes": "", "category": 1}}, {"model": "ingredients.ingredient", "pk": 9, "fields": {"name": "Banana", "notes": "", "category": 1}}, {"model": "ingredients.ingredient", "pk": 10, "fields": {"name": "Grapes", "notes": "", "category": 1}}, {"model": "ingredients.ingredient", "pk": 11, "fields": {"name": "Whipped Cream", "notes": "", "category": 3}}, {"model": "ingredients.ingredient", "pk": 12, "fields": {"name": "MSG", "notes": "", "category": 3}}] [
{
"fields": {
"date_joined": "2016-11-03T18:24:40Z",
"email": "asdf@example.com",
"first_name": "",
"groups": [],
"is_active": true,
"is_staff": true,
"is_superuser": true,
"last_login": "2016-11-04T00:46:58Z",
"last_name": "",
"password": "pbkdf2_sha256$24000$0SgBlSlnbv5c$ijVQipm2aNDlcrTL8Qi3SVNHphTm4HIsDfUi4kn9tog=",
"user_permissions": [],
"username": "admin"
},
"model": "auth.user",
"pk": 1
},
{
"fields": {
"instructions": "https://xkcd.com/720/",
"title": "Cheerios With a Shot of Vermouth"
},
"model": "recipes.recipe",
"pk": 1
},
{
"fields": {
"instructions": "https://xkcd.com/720/",
"title": "Quail Eggs in Whipped Cream and MSG"
},
"model": "recipes.recipe",
"pk": 2
},
{
"fields": {
"instructions": "https://xkcd.com/720/",
"title": "Deep Fried Skittles"
},
"model": "recipes.recipe",
"pk": 3
},
{
"fields": {
"instructions": "https://xkcd.com/720/",
"title": "Newt ala Doritos"
},
"model": "recipes.recipe",
"pk": 4
},
{
"fields": {
"instructions": "Chop up and add together",
"title": "Fruit Salad"
},
"model": "recipes.recipe",
"pk": 5
},
{
"fields": {
"amount": 1.0,
"ingredient": 9,
"recipes": 5,
"unit": "unit"
},
"model": "recipes.recipeingredient",
"pk": 1
},
{
"fields": {
"amount": 2.0,
"ingredient": 10,
"recipes": 5,
"unit": "unit"
},
"model": "recipes.recipeingredient",
"pk": 2
},
{
"fields": {
"amount": 3.0,
"ingredient": 7,
"recipes": 5,
"unit": "unit"
},
"model": "recipes.recipeingredient",
"pk": 3
},
{
"fields": {
"amount": 4.0,
"ingredient": 8,
"recipes": 5,
"unit": "unit"
},
"model": "recipes.recipeingredient",
"pk": 4
},
{
"fields": {
"amount": 1.0,
"ingredient": 5,
"recipes": 4,
"unit": "kg"
},
"model": "recipes.recipeingredient",
"pk": 5
},
{
"fields": {
"amount": 2.0,
"ingredient": 6,
"recipes": 4,
"unit": "l"
},
"model": "recipes.recipeingredient",
"pk": 6
},
{
"fields": {
"amount": 1.0,
"ingredient": 4,
"recipes": 3,
"unit": "unit"
},
"model": "recipes.recipeingredient",
"pk": 7
},
{
"fields": {
"amount": 1.0,
"ingredient": 2,
"recipes": 2,
"unit": "kg"
},
"model": "recipes.recipeingredient",
"pk": 8
},
{
"fields": {
"amount": 2.0,
"ingredient": 11,
"recipes": 2,
"unit": "l"
},
"model": "recipes.recipeingredient",
"pk": 9
},
{
"fields": {
"amount": 3.0,
"ingredient": 12,
"recipes": 2,
"unit": "st"
},
"model": "recipes.recipeingredient",
"pk": 10
},
{
"fields": {
"amount": 1.0,
"ingredient": 1,
"recipes": 1,
"unit": "kg"
},
"model": "recipes.recipeingredient",
"pk": 11
},
{
"fields": {
"amount": 1.0,
"ingredient": 3,
"recipes": 1,
"unit": "st"
},
"model": "recipes.recipeingredient",
"pk": 12
},
{
"fields": {
"name": "fruit"
},
"model": "ingredients.category",
"pk": 1
},
{
"fields": {
"name": "xkcd"
},
"model": "ingredients.category",
"pk": 3
},
{
"fields": {
"category": 3,
"name": "Cheerios",
"notes": "this is a note"
},
"model": "ingredients.ingredient",
"pk": 1
},
{
"fields": {
"category": 3,
"name": "Quail Eggs",
"notes": "has more notes"
},
"model": "ingredients.ingredient",
"pk": 2
},
{
"fields": {
"category": 3,
"name": "Vermouth",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 3
},
{
"fields": {
"category": 3,
"name": "Skittles",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 4
},
{
"fields": {
"category": 3,
"name": "Newt",
"notes": "Braised and Confuesd"
},
"model": "ingredients.ingredient",
"pk": 5
},
{
"fields": {
"category": 3,
"name": "Doritos",
"notes": "Crushed"
},
"model": "ingredients.ingredient",
"pk": 6
},
{
"fields": {
"category": 1,
"name": "Apple",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 7
},
{
"fields": {
"category": 1,
"name": "Orange",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 8
},
{
"fields": {
"category": 1,
"name": "Banana",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 9
},
{
"fields": {
"category": 1,
"name": "Grapes",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 10
},
{
"fields": {
"category": 3,
"name": "Whipped Cream",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 11
},
{
"fields": {
"category": 3,
"name": "MSG",
"notes": ""
},
"model": "ingredients.ingredient",
"pk": 12
}
]

View File

@ -1,5 +1,5 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==3.0.7 django==3.1.14
django-filter>=2 django-filter>=2

View File

@ -1,5 +1,3 @@
from __future__ import absolute_import
from django.db import models from django.db import models

View File

@ -1,7 +1,7 @@
from .fields import DjangoConnectionField, DjangoListField from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType from .types import DjangoObjectType
__version__ = "3.0.0b7" __version__ = "3.0.0"
__all__ = [ __all__ = [
"__version__", "__version__",

View File

@ -1,4 +1,5 @@
class MissingType(object): class MissingType:
def __init__(self, *args, **kwargs):
pass pass

View File

@ -1,5 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from functools import singledispatch, partial, wraps from functools import singledispatch, wraps
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -23,10 +23,16 @@ from graphene import (
Time, Time,
Decimal, Decimal,
) )
from graphene.types.resolver import get_default_resolver
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.scalars import BigInt
from graphene.utils.str_converters import to_camel_case from graphene.utils.str_converters import to_camel_case
from graphql import GraphQLError, assert_valid_name from graphql import GraphQLError
try:
from graphql import assert_name
except ImportError:
# Support for older versions of graphql
from graphql import assert_valid_name as assert_name
from graphql.pyutils import register_description from graphql.pyutils import register_description
from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField
@ -36,7 +42,7 @@ from .utils.str_converters import to_const
class BlankValueField(Field): class BlankValueField(Field):
def get_resolver(self, parent_resolver): def wrap_resolve(self, parent_resolver):
resolver = self.resolver or parent_resolver resolver = self.resolver or parent_resolver
# create custom resolver # create custom resolver
@ -56,7 +62,7 @@ class BlankValueField(Field):
def convert_choice_name(name): def convert_choice_name(name):
name = to_const(force_str(name)) name = to_const(force_str(name))
try: try:
assert_valid_name(name) assert_name(name)
except GraphQLError: except GraphQLError:
name = "A_%s" % name name = "A_%s" % name
return name return name
@ -68,8 +74,7 @@ def get_choices(choices):
choices = choices.items() choices = choices.items()
for value, help_text in choices: for value, help_text in choices:
if isinstance(help_text, (tuple, list)): if isinstance(help_text, (tuple, list)):
for choice in get_choices(help_text): yield from get_choices(help_text)
yield choice
else: else:
name = convert_choice_name(value) name = convert_choice_name(value)
while name in converted_names: while name in converted_names:
@ -86,7 +91,7 @@ def convert_choices_to_named_enum_with_descriptions(name, choices):
named_choices = [(c[0], c[1]) for c in choices] named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices} named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object): class EnumWithDescriptionsType:
@property @property
def description(self): def description(self):
return str(named_choices_descriptions[self.name]) return str(named_choices_descriptions[self.name])
@ -103,7 +108,7 @@ def generate_enum_name(django_model_meta, field):
) )
name = custom_func(field) name = custom_func(field)
elif graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True: elif graphene_settings.DJANGO_CHOICE_FIELD_ENUM_V2_NAMING is True:
name = to_camel_case("{}_{}".format(django_model_meta.object_name, field.name)) name = to_camel_case(f"{django_model_meta.object_name}_{field.name}")
else: else:
name = "{app_label}{object_name}{field_name}Choices".format( name = "{app_label}{object_name}{field_name}Choices".format(
app_label=to_camel_case(django_model_meta.app_label.title()), app_label=to_camel_case(django_model_meta.app_label.title()),
@ -149,7 +154,9 @@ def get_django_field_description(field):
@singledispatch @singledispatch
def convert_django_field(field, registry=None): def convert_django_field(field, registry=None):
raise Exception( raise Exception(
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__) "Don't know how to convert the Django field {} ({})".format(
field, field.__class__
)
) )
@ -167,11 +174,19 @@ def convert_field_to_string(field, registry=None):
) )
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.AutoField) @convert_django_field.register(models.AutoField)
def convert_field_to_id(field, registry=None): def convert_field_to_id(field, registry=None):
return ID(description=get_django_field_description(field), required=not field.null) return ID(description=get_django_field_description(field), required=not field.null)
if hasattr(models, "SmallAutoField"):
@convert_django_field.register(models.SmallAutoField)
def convert_field_small_to_id(field, registry=None):
return convert_field_to_id(field, registry)
@convert_django_field.register(models.UUIDField) @convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None): def convert_field_to_uuid(field, registry=None):
return UUID( return UUID(
@ -179,10 +194,14 @@ def convert_field_to_uuid(field, registry=None):
) )
@convert_django_field.register(models.BigIntegerField)
def convert_big_int_field(field, registry=None):
return BigInt(description=field.help_text, required=not field.null)
@convert_django_field.register(models.PositiveIntegerField) @convert_django_field.register(models.PositiveIntegerField)
@convert_django_field.register(models.PositiveSmallIntegerField) @convert_django_field.register(models.PositiveSmallIntegerField)
@convert_django_field.register(models.SmallIntegerField) @convert_django_field.register(models.SmallIntegerField)
@convert_django_field.register(models.BigIntegerField)
@convert_django_field.register(models.IntegerField) @convert_django_field.register(models.IntegerField)
def convert_field_to_int(field, registry=None): def convert_field_to_int(field, registry=None):
return Int(description=get_django_field_description(field), required=not field.null) return Int(description=get_django_field_description(field), required=not field.null)
@ -198,7 +217,9 @@ def convert_field_to_boolean(field, registry=None):
@convert_django_field.register(models.DecimalField) @convert_django_field.register(models.DecimalField)
def convert_field_to_decimal(field, registry=None): def convert_field_to_decimal(field, registry=None):
return Decimal(description=field.help_text, required=not field.null) return Decimal(
description=get_django_field_description(field), required=not field.null
)
@convert_django_field.register(models.FloatField) @convert_django_field.register(models.FloatField)
@ -239,10 +260,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
if not _type: if not _type:
return return
# We do this for a bug in Django 1.8, where null attr return Field(_type, required=not field.null)
# is not available in the OneToOneRel instance
null = getattr(field, "null", True)
return Field(_type, required=not null)
return Dynamic(dynamic_type) return Dynamic(dynamic_type)
@ -297,7 +315,26 @@ def convert_field_to_djangomodel(field, registry=None):
if not _type: if not _type:
return return
return Field( class CustomField(Field):
def wrap_resolve(self, parent_resolver):
"""
Implements a custom resolver which go through the `get_node` method to ensure that
it goes through the `get_queryset` method of the DjangoObjectType.
"""
resolver = super().wrap_resolve(parent_resolver)
def custom_resolver(root, info, **args):
fk_obj = resolver(root, info, **args)
if not isinstance(fk_obj, model):
# In case the resolver is a custom one that overwrites
# the default Django resolver
# This happens, for example, when using custom awaitable resolvers.
return fk_obj
return _type.get_node(info, fk_obj.pk)
return custom_resolver
return CustomField(
_type, _type,
description=get_django_field_description(field), description=get_django_field_description(field),
required=not field.null, required=not field.null,

View File

@ -0,0 +1,17 @@
import traceback
from django.utils.encoding import force_str
from .types import DjangoDebugException
def wrap_exception(exception):
return DjangoDebugException(
message=force_str(exception),
exc_type=force_str(type(exception)),
stack="".join(
traceback.format_exception(
exception, value=exception, tb=exception.__traceback__
)
),
)

View File

@ -0,0 +1,10 @@
from graphene import ObjectType, String
class DjangoDebugException(ObjectType):
class Meta:
description = "Represents a single exception raised."
exc_type = String(required=True, description="The class of the exception")
message = String(required=True, description="The message of the exception")
stack = String(required=True, description="The stack trace")

View File

@ -3,32 +3,38 @@ from django.db import connections
from promise import Promise from promise import Promise
from .sql.tracking import unwrap_cursor, wrap_cursor from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception
from .types import DjangoDebug from .types import DjangoDebug
class DjangoDebugContext(object): class DjangoDebugContext:
def __init__(self): def __init__(self):
self.debug_promise = None self.debug_result = None
self.promises = [] self.results = []
self.object = DjangoDebug(sql=[], exceptions=[])
self.enable_instrumentation() self.enable_instrumentation()
self.object = DjangoDebug(sql=[])
def get_debug_promise(self): def get_debug_result(self):
if not self.debug_promise: if not self.debug_result:
self.debug_promise = Promise.all(self.promises) self.debug_result = self.results
self.promises = [] self.results = []
return self.debug_promise.then(self.on_resolve_all_promises).get() return self.on_resolve_all_results()
def on_resolve_all_promises(self, values): def on_resolve_error(self, value):
if self.promises: if hasattr(self, "object"):
self.debug_promise = None self.object.exceptions.append(wrap_exception(value))
return self.get_debug_promise() return value
def on_resolve_all_results(self):
if self.results:
self.debug_result = None
return self.get_debug_result()
self.disable_instrumentation() self.disable_instrumentation()
return self.object return self.object
def add_promise(self, promise): def add_result(self, result):
if self.debug_promise: if self.debug_result:
self.promises.append(promise) self.results.append(result)
def enable_instrumentation(self): def enable_instrumentation(self):
# This is thread-safe because database connections are thread-local. # This is thread-safe because database connections are thread-local.
@ -40,7 +46,7 @@ class DjangoDebugContext(object):
unwrap_cursor(connection) unwrap_cursor(connection)
class DjangoDebugMiddleware(object): class DjangoDebugMiddleware:
def resolve(self, next, root, info, **args): def resolve(self, next, root, info, **args):
context = info.context context = info.context
django_debug = getattr(context, "django_debug", None) django_debug = getattr(context, "django_debug", None)
@ -56,7 +62,10 @@ class DjangoDebugMiddleware(object):
) )
) )
if info.schema.get_type("DjangoDebug") == info.return_type: if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise() return context.django_debug.get_debug_result()
promise = next(root, info, **args) try:
context.django_debug.add_promise(promise) result = next(root, info, **args)
return promise except Exception as e:
return context.django_debug.on_resolve_error(e)
context.django_debug.add_result(result)
return result

View File

@ -1,5 +1,4 @@
# Code obtained from django-debug-toolbar sql panel tracking # Code obtained from django-debug-toolbar sql panel tracking
from __future__ import absolute_import, unicode_literals
import json import json
from threading import local from threading import local
@ -50,7 +49,7 @@ def unwrap_cursor(connection):
del connection._graphene_cursor del connection._graphene_cursor
class ExceptionCursorWrapper(object): class ExceptionCursorWrapper:
""" """
Wraps a cursor and raises an exception on any operation. Wraps a cursor and raises an exception on any operation.
Used in Templates panel. Used in Templates panel.
@ -63,7 +62,7 @@ class ExceptionCursorWrapper(object):
raise SQLQueryTriggered() raise SQLQueryTriggered()
class NormalCursorWrapper(object): class NormalCursorWrapper:
""" """
Wraps a cursor and logs queries. Wraps a cursor and logs queries.
""" """
@ -85,7 +84,7 @@ class NormalCursorWrapper(object):
if not params: if not params:
return params return params
if isinstance(params, dict): if isinstance(params, dict):
return dict((key, self._quote_expr(value)) for key, value in params.items()) return {key: self._quote_expr(value) for key, value in params.items()}
return list(map(self._quote_expr, params)) return list(map(self._quote_expr, params))
def _decode(self, param): def _decode(self, param):

View File

@ -8,7 +8,7 @@ from ..middleware import DjangoDebugMiddleware
from ..types import DjangoDebug from ..types import DjangoDebug
class context(object): class context:
pass pass
@ -272,3 +272,42 @@ def test_should_query_connectionfilter(graphene_settings, max_limit):
assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"] assert "COUNT" in result.data["_debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data["_debug"]["sql"][1]["rawSql"] == query assert result.data["_debug"]["sql"][1]["rawSql"] == query
def test_should_query_stack_trace():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name="_debug")
def resolve_reporter(self, info, **args):
raise Exception("caught stack trace")
query = """
query ReporterQuery {
reporter {
lastName
}
_debug {
exceptions {
message
stack
}
}
}
"""
schema = graphene.Schema(query=Query)
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert result.errors
assert len(result.data["_debug"]["exceptions"])
debug_exception = result.data["_debug"]["exceptions"][0]
assert debug_exception["stack"].count("\n") > 1
assert "test_query.py" in debug_exception["stack"]
assert debug_exception["message"] == "caught stack trace"

View File

@ -1,6 +1,7 @@
from graphene import List, ObjectType from graphene import List, ObjectType
from .sql.types import DjangoDebugSQL from .sql.types import DjangoDebugSQL
from .exception.types import DjangoDebugException
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):
@ -8,3 +9,6 @@ class DjangoDebug(ObjectType):
description = "Debugging information for the current query." description = "Debugging information for the current query."
sql = List(DjangoDebugSQL, description="Executed SQL queries for this API query.") sql = List(DjangoDebugSQL, description="Executed SQL queries for this API query.")
exceptions = List(
DjangoDebugException, description="Raise exceptions for this API query."
)

View File

@ -1,12 +1,14 @@
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import (
from graphql_relay import (
connection_from_array_slice, connection_from_array_slice,
cursor_to_offset, cursor_to_offset,
get_offset_with_default, get_offset_with_default,
offset_to_cursor, offset_to_cursor,
) )
from promise import Promise from promise import Promise
from graphene import Int, NonNull from graphene import Int, NonNull
@ -26,7 +28,7 @@ class DjangoListField(Field):
_type = _type.of_type _type = _type.of_type
# Django would never return a Set of None vvvvvvv # Django would never return a Set of None vvvvvvv
super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs) super().__init__(List(NonNull(_type)), *args, **kwargs)
assert issubclass( assert issubclass(
self._underlying_type, DjangoObjectType self._underlying_type, DjangoObjectType
@ -61,12 +63,16 @@ class DjangoListField(Field):
return queryset return queryset
def wrap_resolve(self, parent_resolver): def wrap_resolve(self, parent_resolver):
resolver = super().wrap_resolve(parent_resolver)
_type = self.type _type = self.type
if isinstance(_type, NonNull): if isinstance(_type, NonNull):
_type = _type.of_type _type = _type.of_type
django_object_type = _type.of_type.of_type django_object_type = _type.of_type.of_type
return partial( return partial(
self.list_resolver, django_object_type, parent_resolver, self.get_manager(), self.list_resolver,
django_object_type,
resolver,
self.get_manager(),
) )
@ -81,7 +87,7 @@ class DjangoConnectionField(ConnectionField):
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
) )
kwargs.setdefault("offset", Int()) kwargs.setdefault("offset", Int())
super(DjangoConnectionField, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@property @property
def type(self): def type(self):
@ -143,36 +149,40 @@ class DjangoConnectionField(ConnectionField):
iterable = maybe_queryset(iterable) iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet): if isinstance(iterable, QuerySet):
list_length = iterable.count() array_length = iterable.count()
else: else:
list_length = len(iterable) array_length = len(iterable)
list_slice_length = (
min(max_limit, list_length) if max_limit is not None else list_length
)
# If after is higher than list_length, connection_from_list_slice # If after is higher than array_length, connection_from_array_slice
# would try to do a negative slicing which makes django throw an # would try to do a negative slicing which makes django throw an
# AssertionError # AssertionError
after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length) slice_start = min(
get_offset_with_default(args.get("after"), -1) + 1,
array_length,
)
array_slice_length = array_length - slice_start
if max_limit is not None and args.get("first", None) == None: # Impose the maximum limit via the `first` field if neither first or last are already provided
if args.get("last", None) != None: # (note that if any of them is provided they must be under max_limit otherwise an error is raised).
after = list_length - args["last"] if (
else: max_limit is not None
and args.get("first", None) is None
and args.get("last", None) is None
):
args["first"] = max_limit args["first"] = max_limit
connection = connection_from_array_slice( connection = connection_from_array_slice(
iterable[after:], iterable[slice_start:],
args, args,
slice_start=after, slice_start=slice_start,
array_length=list_length, array_length=array_length,
array_slice_length=list_slice_length, array_slice_length=array_slice_length,
connection_type=partial(connection_adapter, connection), connection_type=partial(connection_adapter, connection),
edge_type=connection.Edge, edge_type=connection.Edge,
page_info_type=page_info_adapter, page_info_type=page_info_adapter,
) )
connection.iterable = iterable connection.iterable = iterable
connection.length = list_length connection.length = array_length
return connection return connection
@classmethod @classmethod

View File

@ -9,10 +9,21 @@ if not DJANGO_FILTER_INSTALLED:
) )
else: else:
from .fields import DjangoFilterConnectionField from .fields import DjangoFilterConnectionField
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filters import (
ArrayFilter,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter,
ListFilter,
RangeFilter,
TypedFilter,
)
__all__ = [ __all__ = [
"DjangoFilterConnectionField", "DjangoFilterConnectionField",
"GlobalIDFilter", "GlobalIDFilter",
"GlobalIDMultipleChoiceFilter", "GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
] ]

View File

@ -2,16 +2,35 @@ from collections import OrderedDict
from functools import partial from functools import partial
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from graphene.utils.str_converters import to_snake_case from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
def convert_enum(data):
"""
Check if the data is a enum option (or potentially nested list of enum option)
and convert it to its value.
This method is used to pre-process the data for the filters as they can take an
graphene.Enum as argument, but filters (from django_filters) expect a simple value.
"""
if isinstance(data, list):
return [convert_enum(item) for item in data]
if isinstance(type(data), EnumType):
return data.value
else:
return data
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
def __init__( def __init__(
self, self,
type, type_,
fields=None, fields=None,
order_by=None, order_by=None,
extra_filter_meta=None, extra_filter_meta=None,
@ -25,7 +44,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self._filtering_args = None self._filtering_args = None
self._extra_filter_meta = extra_filter_meta self._extra_filter_meta = extra_filter_meta
self._base_args = None self._base_args = None
super(DjangoFilterConnectionField, self).__init__(type, *args, **kwargs) super().__init__(type_, *args, **kwargs)
@property @property
def args(self): def args(self):
@ -43,8 +62,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if self._extra_filter_meta: if self._extra_filter_meta:
meta.update(self._extra_filter_meta) meta.update(self._extra_filter_meta)
filterset_class = self._provided_filterset_class or ( filterset_class = (
self.node_type._meta.filterset_class self._provided_filterset_class or self.node_type._meta.filterset_class
) )
self._filterset_class = get_filterset_class(filterset_class, **meta) self._filterset_class = get_filterset_class(filterset_class, **meta)
@ -68,17 +87,15 @@ class DjangoFilterConnectionField(DjangoConnectionField):
if k in filtering_args: if k in filtering_args:
if k == "order_by" and v is not None: if k == "order_by" and v is not None:
v = to_snake_case(v) v = to_snake_case(v)
kwargs[k] = v kwargs[k] = convert_enum(v)
return kwargs return kwargs
qs = super(DjangoFilterConnectionField, cls).resolve_queryset( qs = super().resolve_queryset(connection, iterable, info, args)
connection, iterable, info, args
)
filterset = filterset_class( filterset = filterset_class(
data=filter_kwargs(), queryset=qs, request=info.context data=filter_kwargs(), queryset=qs, request=info.context
) )
if filterset.form.is_valid(): if filterset.is_valid():
return filterset.qs return filterset.qs
raise ValidationError(filterset.form.errors.as_json()) raise ValidationError(filterset.form.errors.as_json())

View File

@ -1,75 +0,0 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super(GlobalIDMultipleChoiceFilter, self).filter(qs, gids)
class InFilter(Filter):
"""
Filter for a list of value using the `__in` Django filter.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first weather the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super().filter(qs, value)
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(Filter):
field_class = RangeField

View File

@ -0,0 +1,25 @@
import warnings
from ...utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .array_filter import ArrayFilter
from .global_id_filter import GlobalIDFilter, GlobalIDMultipleChoiceFilter
from .list_filter import ListFilter
from .range_filter import RangeFilter
from .typed_filter import TypedFilter
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
"ArrayFilter",
"ListFilter",
"RangeFilter",
"TypedFilter",
]

View File

@ -0,0 +1,27 @@
from django_filters.constants import EMPTY_VALUES
from .typed_filter import TypedFilter
class ArrayFilter(TypedFilter):
"""
Filter made for PostgreSQL ArrayField.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get the filter applied with
an empty list since it's a valid value but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value in EMPTY_VALUES and value != []:
return qs
if self.distinct:
qs = qs.distinct()
lookup = f"{self.field_name}__{self.lookup_expr}"
qs = self.get_method(qs)(**{lookup: value})
return qs

View File

@ -0,0 +1,28 @@
from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField
class GlobalIDFilter(Filter):
"""
Filter for Relay global ID.
"""
field_class = GlobalIDFormField
def filter(self, qs, value):
"""Convert the filter value to a primary key before filtering"""
_id = None
if value is not None:
_, _id = from_global_id(value)
return super().filter(qs, _id)
class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
field_class = GlobalIDMultipleChoiceField
def filter(self, qs, value):
gids = [from_global_id(v)[1] for v in value]
return super().filter(qs, gids)

View File

@ -0,0 +1,26 @@
from .typed_filter import TypedFilter
class ListFilter(TypedFilter):
"""
Filter that takes a list of value as input.
It is for example used for `__in` filters.
"""
def filter(self, qs, value):
"""
Override the default filter class to check first whether the list is
empty or not.
This needs to be done as in this case we expect to get an empty output
(if not an exclude filter) but django_filter consider an empty list
to be an empty input value (see `EMPTY_VALUES`) meaning that
the filter does not need to be applied (hence returning the original
queryset).
"""
if value is not None and len(value) == 0:
if self.exclude:
return qs
else:
return qs.none()
else:
return super().filter(qs, value)

View File

@ -0,0 +1,24 @@
from django.core.exceptions import ValidationError
from django.forms import Field
from .typed_filter import TypedFilter
def validate_range(value):
"""
Validator for range filter input: the list of value must be of length 2.
Note that validators are only run if the value is not empty.
"""
if len(value) != 2:
raise ValidationError(
"Invalid range specified: it needs to contain 2 values.", code="invalid"
)
class RangeField(Field):
default_validators = [validate_range]
empty_values = [None]
class RangeFilter(TypedFilter):
field_class = RangeField

View File

@ -0,0 +1,27 @@
from django_filters import Filter
from graphene.types.utils import get_type
class TypedFilter(Filter):
"""
Filter class for which the input GraphQL type can explicitly be provided.
If it is not provided, when building the schema, it will try to guess
it from the field.
"""
def __init__(self, input_type=None, *args, **kwargs):
self._input_type = input_type
super().__init__(*args, **kwargs)
@property
def input_type(self):
input_type = get_type(self._input_type)
if input_type is not None:
if not callable(getattr(input_type, "get_type", None)):
raise ValueError(
"Wrong `input_type` for {}: it only accepts graphene types, got {}".format(
self.__class__.__name__, input_type
)
)
return input_type

View File

@ -1,7 +1,6 @@
import itertools import itertools
from django.db import models from django.db import models
from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
@ -30,20 +29,18 @@ class GrapheneFilterSetMixin(BaseFilterSet):
def setup_filterset(filterset_class): def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality """Wrap a provided filterset in Graphene-specific functionality"""
"""
return type( return type(
"Graphene{}".format(filterset_class.__name__), f"Graphene{filterset_class.__name__}",
(filterset_class, GrapheneFilterSetMixin), (filterset_class, GrapheneFilterSetMixin),
{}, {},
) )
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta): def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data """Create a filterset for the given model using the provided meta data"""
"""
meta.update({"model": model}) meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta) meta_class = type("Meta", (object,), meta)
filterset = type( filterset = type(
str("%sFilterSet" % model._meta.object_name), str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin), (filterset_base_class, GrapheneFilterSetMixin),

View File

@ -1,4 +1,4 @@
from mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from django.db import models from django.db import models
@ -9,6 +9,7 @@ import graphene
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter
from ...compat import ArrayField from ...compat import ArrayField
@ -27,58 +28,71 @@ else:
STORE = {"events": []} STORE = {"events": []}
@pytest.fixture
def Event():
class Event(models.Model): class Event(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50)) tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField())
return Event random_field = ArrayField(models.BooleanField())
@pytest.fixture @pytest.fixture
def EventFilterSet(Event): def EventFilterSet():
from django.contrib.postgres.forms import SimpleArrayField
class ArrayFilter(filters.Filter):
base_field_class = SimpleArrayField
class EventFilterSet(FilterSet): class EventFilterSet(FilterSet):
class Meta: class Meta:
model = Event model = Event
fields = { fields = {
"name": ["exact"], "name": ["exact", "contains"],
} }
# Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact")
# Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
tags_ids__overlap = ArrayFilter(field_name="tag_ids", lookup_expr="overlap")
tags_ids = ArrayFilter(field_name="tag_ids", lookup_expr="exact")
random_field__contains = ArrayFilter(
field_name="random_field", lookup_expr="contains"
)
random_field__overlap = ArrayFilter(
field_name="random_field", lookup_expr="overlap"
)
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
return EventFilterSet return EventFilterSet
@pytest.fixture @pytest.fixture
def EventType(Event, EventFilterSet): def EventType(EventFilterSet):
class EventType(DjangoObjectType): class EventType(DjangoObjectType):
class Meta: class Meta:
model = Event model = Event
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filterset_class = EventFilterSet filterset_class = EventFilterSet
return EventType return EventType
@pytest.fixture @pytest.fixture
def Query(Event, EventType): def Query(EventType):
"""
Note that we have to use a custom resolver to replicate the arrayfield filter behavior as
we are running unit tests in sqlite which does not have ArrayFields.
"""
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType) events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs): def resolve_events(self, info, **kwargs):
events = [ events = [
Event(name="Live Show", tags=["concert", "music", "rock"],), Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"],), Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"],), Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
] ]
STORE["events"] = events STORE["events"] = events
@ -105,6 +119,13 @@ def Query(Event, EventType):
STORE["events"], STORE["events"],
) )
) )
if "tags__exact" in kwargs:
STORE["events"] = list(
filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"],
)
)
def mock_queryset_filter(*args, **kwargs): def mock_queryset_filter(*args, **kwargs):
filter_events(**kwargs) filter_events(**kwargs)
@ -121,7 +142,9 @@ def Query(Event, EventType):
m_queryset.filter.side_effect = mock_queryset_filter m_queryset.filter.side_effect = mock_queryset_filter
m_queryset.none.side_effect = mock_queryset_none m_queryset.none.side_effect = mock_queryset_none
m_queryset.count.side_effect = mock_queryset_count m_queryset.count.side_effect = mock_queryset_count
m_queryset.__getitem__.side_effect = STORE["events"].__getitem__ m_queryset.__getitem__.side_effect = lambda index: STORE[
"events"
].__getitem__(index)
return m_queryset return m_queryset

View File

@ -10,7 +10,7 @@ class ArticleFilter(django_filters.FilterSet):
fields = { fields = {
"headline": ["exact", "icontains"], "headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"], "pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"], "reporter": ["exact", "in"],
} }
order_by = OrderingFilter(fields=("pub_date",)) order_by = OrderingFilter(fields=("pub_date",))

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_multiple(Query): def test_array_field_contains_multiple(Query):
""" """
Test contains filter on a string field. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)
@ -32,9 +32,9 @@ def test_string_contains_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_one(Query): def test_array_field_contains_one(Query):
""" """
Test contains filter on a string field. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)
@ -59,9 +59,9 @@ def test_string_contains_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_contains_none(Query): def test_array_field_contains_empty_list(Query):
""" """
Test contains filter on a string field. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)
@ -79,4 +79,9 @@ def test_string_contains_none(Query):
""" """
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data["events"]["edges"] == [] assert result.data["events"]["edges"] == [
{"node": {"name": "Live Show"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
{"node": {"name": "Speech"}},
]

View File

@ -0,0 +1,127 @@
import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["concert", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: ["movie", "music"]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(Query):
"""
Test exact filter on a array field of string.
"""
schema = Schema(query=Query)
query = """
query {
events (tags: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_filter_schema_type(Query):
"""
Check that the type in the filter is an array field like on the object type.
"""
schema = Schema(query=Query)
schema_str = str(schema)
assert (
'''type EventType implements Node {
"""The ID of the object"""
id: ID!
name: String!
tags: [String!]!
tagIds: [Int!]!
randomField: [Boolean!]!
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"name": "String",
"name_Contains": "String",
"tags_Contains": "[String!]",
"tags_Overlap": "[String!]",
"tags": "[String!]",
"tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]",
"randomField_Contains": "[Boolean!]",
"randomField_Overlap": "[Boolean!]",
"randomField": "[Boolean!]",
}
filters_str = ", ".join(
[f"{filter_field}: {gql_type}" for filter_field, gql_type in filters.items()]
)
assert (
f"type Query {{\n events({filters_str}): EventTypeConnection\n}}" in schema_str
)

View File

@ -6,9 +6,9 @@ from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_multiple(Query): def test_array_field_overlap_multiple(Query):
""" """
Test overlap filter on a string field. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)
@ -34,9 +34,9 @@ def test_string_overlap_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_one(Query): def test_array_field_overlap_one(Query):
""" """
Test overlap filter on a string field. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)
@ -61,9 +61,9 @@ def test_string_overlap_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_string_overlap_none(Query): def test_array_field_overlap_empty_list(Query):
""" """
Test overlap filter on a string field. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query) schema = Schema(query=Query)

View File

@ -0,0 +1,163 @@
import pytest
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType, DjangoConnectionField
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import DjangoFilterConnectionField
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filter_fields = {
"lang": ["exact", "in"],
"reporter__a_choice": ["exact", "in"],
}
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
@pytest.fixture
def reporter_article_data():
john = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
jane = Reporter.objects.create(
first_name="Jane", last_name="Doe", email="janedoe@example.com", a_choice=2
)
Article.objects.create(
headline="Article Node 1", reporter=john, editor=john, lang="es"
)
Article.objects.create(
headline="Article Node 2", reporter=john, editor=john, lang="en"
)
Article.objects.create(
headline="Article Node 3", reporter=jane, editor=jane, lang="en"
)
def test_filter_enum_on_connection(schema, reporter_article_data):
"""
Check that we can filter with enums on a connection.
"""
query = """
query {
allArticles(lang: ES) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_on_foreign_key_enum_field(schema, reporter_article_data):
"""
Check that we can filter with enums on a field from a foreign key.
"""
query = """
query {
allArticles(reporter_AChoice: A_1) {
edges {
node {
headline
}
}
}
}
"""
expected = {
"allArticles": {
"edges": [
{"node": {"headline": "Article Node 1"}},
{"node": {"headline": "Article Node 2"}},
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_filter_enum_field_schema_type(schema):
"""
Check that the type in the filter is an enum like on the object type.
"""
schema_str = str(schema)
assert (
'''type ArticleType implements Node {
"""The ID of the object"""
id: ID!
headline: String!
pubDate: Date!
pubDateTime: DateTime!
reporter: ReporterType!
editor: ReporterType!
"""Language"""
lang: TestsArticleLangChoices!
importance: TestsArticleImportanceChoices
}'''
in schema_str
)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"reporter_AChoice": "TestsReporterAChoiceChoices",
"reporter_AChoice_In": "[TestsReporterAChoiceChoices]",
}
filters_str = ", ".join(
[f"{filter_field}: {gql_type}" for filter_field, gql_type in filters.items()]
)
assert f" allArticles({filters_str}): ArticleTypeConnection\n" in schema_str

View File

@ -5,18 +5,18 @@ import pytest
from django.db.models import TextField, Value from django.db.models import TextField, Value
from django.db.models.functions import Concat from django.db.models.functions import Concat
from graphene import Argument, Boolean, Field, Float, ObjectType, Schema, String from graphene import Argument, Boolean, Decimal, Field, ObjectType, Schema, String
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Person, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = [] pytestmark = []
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
import django_filters import django_filters
from django_filters import FilterSet, NumberFilter from django_filters import FilterSet, NumberFilter, OrderingFilter
from graphene_django.filter import ( from graphene_django.filter import (
GlobalIDFilter, GlobalIDFilter,
@ -67,7 +67,7 @@ def assert_arguments(field, *arguments):
actual = [name for name in args if name not in ignore and not name.startswith("_")] actual = [name for name in args if name not in ignore and not name.startswith("_")]
assert set(arguments) == set( assert set(arguments) == set(
actual actual
), "Expected arguments ({}) did not match actual ({})".format(arguments, actual) ), f"Expected arguments ({arguments}) did not match actual ({actual})"
def assert_orderable(field): def assert_orderable(field):
@ -90,6 +90,7 @@ def test_filter_explicit_filterset_arguments():
"pub_date__gt", "pub_date__gt",
"pub_date__lt", "pub_date__lt",
"reporter", "reporter",
"reporter__in",
) )
@ -140,7 +141,7 @@ def test_filter_shortcut_filterset_context():
@property @property
def qs(self): def qs(self):
qs = super(ArticleContextFilter, self).qs qs = super().qs
return qs.filter(reporter=self.request.reporter) return qs.filter(reporter=self.request.reporter)
class Query(ObjectType): class Query(ObjectType):
@ -165,7 +166,7 @@ def test_filter_shortcut_filterset_context():
editor=r2, editor=r2,
) )
class context(object): class context:
reporter = r2 reporter = r2
query = """ query = """
@ -400,7 +401,7 @@ def test_filterset_descriptions():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args["max_time"] max_time = field.args["max_time"]
assert isinstance(max_time, Argument) assert isinstance(max_time, Argument)
assert max_time.type == Float assert max_time.type == Decimal
assert max_time.description == "The maximum time" assert max_time.description == "The maximum time"
@ -696,7 +697,7 @@ def test_should_query_filter_node_limit():
node { node {
id id
firstName firstName
articles(lang: "es") { articles(lang: ES) {
edges { edges {
node { node {
id id
@ -738,6 +739,7 @@ def test_order_by():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
@ -1006,7 +1008,7 @@ def test_integer_field_filter_type():
assert str(schema) == dedent( assert str(schema) == dedent(
"""\ """\
type Query { type Query {
pets(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null, age: Int = null): PetTypeConnection pets(offset: Int, before: String, after: String, first: Int, last: Int, age: Int): PetTypeConnection
} }
type PetTypeConnection { type PetTypeConnection {
@ -1054,8 +1056,7 @@ def test_integer_field_filter_type():
interface Node { interface Node {
\"""The ID of the object\""" \"""The ID of the object\"""
id: ID! id: ID!
} }"""
"""
) )
@ -1075,7 +1076,7 @@ def test_other_filter_types():
assert str(schema) == dedent( assert str(schema) == dedent(
"""\ """\
type Query { type Query {
pets(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null, age: Int = null, age_Isnull: Boolean = null, age_Lt: Int = null): PetTypeConnection pets(offset: Int, before: String, after: String, first: Int, last: Int, age: Int, age_Isnull: Boolean, age_Lt: Int): PetTypeConnection
} }
type PetTypeConnection { type PetTypeConnection {
@ -1123,8 +1124,7 @@ def test_other_filter_types():
interface Node { interface Node {
\"""The ID of the object\""" \"""The ID of the object\"""
id: ID! id: ID!
} }"""
"""
) )
@ -1143,7 +1143,7 @@ def test_filter_filterset_based_on_mixin():
return filters return filters
def filter_email_in(cls, queryset, name, value): def filter_email_in(self, queryset, name, value):
return queryset.filter(**{name: [value]}) return queryset.filter(**{name: [value]})
class NewArticleFilter(ArticleFilterMixin, ArticleFilter): class NewArticleFilter(ArticleFilterMixin, ArticleFilter):
@ -1224,7 +1224,81 @@ def test_filter_filterset_based_on_mixin():
} }
} }
result = schema.execute(query, variable_values={"email": reporter_1.email},) result = schema.execute(query, variable_values={"email": reporter_1.email})
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_filter_string_contains():
class PersonType(DjangoObjectType):
class Meta:
model = Person
interfaces = (Node,)
fields = "__all__"
filter_fields = {"name": ["exact", "in", "contains", "icontains"]}
class Query(ObjectType):
people = DjangoFilterConnectionField(PersonType)
schema = Schema(query=Query)
Person.objects.bulk_create(
[
Person(name="Jack"),
Person(name="Joe"),
Person(name="Jane"),
Person(name="Peter"),
Person(name="Bob"),
]
)
query = """query nameContain($filter: String) {
people(name_Contains: $filter) {
edges {
node {
name
}
}
}
}"""
result = schema.execute(query, variables={"filter": "Ja"})
assert not result.errors
assert result.data == {
"people": {
"edges": [
{"node": {"name": "Jack"}},
{"node": {"name": "Jane"}},
]
}
}
result = schema.execute(query, variables={"filter": "o"})
assert not result.errors
assert result.data == {
"people": {
"edges": [
{"node": {"name": "Joe"}},
{"node": {"name": "Bob"}},
]
}
}
def test_only_custom_filters():
class ReporterFilter(FilterSet):
class Meta:
model = Reporter
fields = []
some_filter = OrderingFilter(fields=("name",))
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
filterset_class = ReporterFilter
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, "some_filter")

View File

@ -1,3 +1,5 @@
from datetime import datetime
import pytest import pytest
from django_filters import FilterSet from django_filters import FilterSet
@ -5,7 +7,8 @@ from django_filters import rest_framework as filters
from graphene import ObjectType, Schema from graphene import ObjectType, Schema
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
from graphene_django.filter.tests.filters import ArticleFilter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = [] pytestmark = []
@ -20,40 +23,77 @@ else:
) )
@pytest.fixture
def query():
class PetNode(DjangoObjectType): class PetNode(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = { filter_fields = {
"id": ["exact", "in"],
"name": ["exact", "in"], "name": ["exact", "in"],
"age": ["exact", "in", "range"], "age": ["exact", "in", "range"],
} }
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
# choice filter using enum
filter_fields = {"reporter_type": ["exact", "in"]}
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilter
class FilmNode(DjangoObjectType):
class Meta:
model = Film
interfaces = (Node,)
fields = "__all__"
# choice filter not using enum
filter_fields = {
"genre": ["exact", "in"],
}
convert_choices_to_enum = False
class PersonFilterSet(FilterSet): class PersonFilterSet(FilterSet):
class Meta: class Meta:
model = Person model = Person
fields = {} fields = {"name": ["in"]}
names = filters.BaseInFilter(method="filter_names") names = filters.BaseInFilter(method="filter_names")
def filter_names(self, qs, name, value): def filter_names(self, qs, name, value):
"""
This custom filter take a string as input with comma separated values.
Note that the value here is already a list as it has been transformed by the BaseInFilter class.
"""
return qs.filter(name__in=value) return qs.filter(name__in=value)
class PersonNode(DjangoObjectType): class PersonNode(DjangoObjectType):
class Meta: class Meta:
model = Person model = Person
interfaces = (Node,) interfaces = (Node,)
filterset_class = PersonFilterSet filterset_class = PersonFilterSet
fields = "__all__"
class Query(ObjectType): class Query(ObjectType):
pets = DjangoFilterConnectionField(PetNode) pets = DjangoFilterConnectionField(PetNode)
people = DjangoFilterConnectionField(PersonNode) people = DjangoFilterConnectionField(PersonNode)
articles = DjangoFilterConnectionField(ArticleNode)
films = DjangoFilterConnectionField(FilmNode)
reporters = DjangoFilterConnectionField(ReporterNode)
return Query
def test_string_in_filter(): def test_string_in_filter(query):
""" """
Test in filter on a string field. Test in filter on a string field.
""" """
@ -61,7 +101,7 @@ def test_string_in_filter():
Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query) schema = Schema(query=query)
query = """ query = """
query { query {
@ -82,17 +122,19 @@ def test_string_in_filter():
] ]
def test_string_in_filter_with_filterset_class(): def test_string_in_filter_with_otjer_filter(query):
"""Test in filter on a string field with a custom filterset class.""" """
Test in filter on a string field which has also a custom filter doing a similar operation.
"""
Person.objects.create(name="John") Person.objects.create(name="John")
Person.objects.create(name="Michael") Person.objects.create(name="Michael")
Person.objects.create(name="Angela") Person.objects.create(name="Angela")
schema = Schema(query=Query) schema = Schema(query=query)
query = """ query = """
query { query {
people (names: ["John", "Michael"]) { people (name_In: ["John", "Michael"]) {
edges { edges {
node { node {
name name
@ -109,7 +151,36 @@ def test_string_in_filter_with_filterset_class():
] ]
def test_int_in_filter(): def test_string_in_filter_with_declared_filter(query):
"""
Test in filter on a string field with a custom filterset class.
"""
Person.objects.create(name="John")
Person.objects.create(name="Michael")
Person.objects.create(name="Angela")
schema = Schema(query=query)
query = """
query {
people (names: "John,Michael") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["people"]["edges"] == [
{"node": {"name": "John"}},
{"node": {"name": "Michael"}},
]
def test_int_in_filter(query):
""" """
Test in filter on an integer field. Test in filter on an integer field.
""" """
@ -117,7 +188,7 @@ def test_int_in_filter():
Pet.objects.create(name="Mimi", age=3) Pet.objects.create(name="Mimi", age=3)
Pet.objects.create(name="Jojo, the rabbit", age=3) Pet.objects.create(name="Jojo, the rabbit", age=3)
schema = Schema(query=Query) schema = Schema(query=query)
query = """ query = """
query { query {
@ -157,7 +228,7 @@ def test_int_in_filter():
] ]
def test_in_filter_with_empty_list(): def test_in_filter_with_empty_list(query):
""" """
Check that using a in filter with an empty list provided as input returns no objects. Check that using a in filter with an empty list provided as input returns no objects.
""" """
@ -165,7 +236,7 @@ def test_in_filter_with_empty_list():
Pet.objects.create(name="Mimi", age=8) Pet.objects.create(name="Mimi", age=8)
Pet.objects.create(name="Picotin", age=5) Pet.objects.create(name="Picotin", age=5)
schema = Schema(query=Query) schema = Schema(query=query)
query = """ query = """
query { query {
@ -181,3 +252,197 @@ def test_in_filter_with_empty_list():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert len(result.data["pets"]["edges"]) == 0 assert len(result.data["pets"]["edges"]) == 0
def test_choice_in_filter_without_enum(query):
"""
Test in filter o an choice field not using an enum (Film.genre).
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
documentary_film = Film.objects.create(genre="do")
documentary_film.reporters.add(john_doe)
action_film = Film.objects.create(genre="ac")
action_film.reporters.add(john_doe)
other_film = Film.objects.create(genre="ot")
other_film.reporters.add(john_doe)
other_film.reporters.add(jean_bon)
schema = Schema(query=query)
query = """
query {
films (genre_In: ["do", "ac"]) {
edges {
node {
genre
reporters {
edges {
node {
lastName
}
}
}
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["films"]["edges"] == [
{
"node": {
"genre": "do",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
{
"node": {
"genre": "ac",
"reporters": {"edges": [{"node": {"lastName": "Doe"}}]},
}
},
]
def test_fk_id_in_filter(query):
"""
Test in filter on an foreign key relationship.
"""
john_doe = Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com"
)
jean_bon = Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com"
)
sara_croche = Reporter.objects.create(
first_name="Sara", last_name="Croche", email="sara@croche.com"
)
Article.objects.create(
headline="A",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=john_doe,
editor=john_doe,
)
Article.objects.create(
headline="B",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=jean_bon,
editor=jean_bon,
)
Article.objects.create(
headline="C",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=sara_croche,
editor=sara_croche,
)
schema = Schema(query=query)
query = """
query {{
articles (reporter_In: [{}, {}]) {{
edges {{
node {{
headline
reporter {{
lastName
}}
}}
}}
}}
}}
""".format(
john_doe.id,
jean_bon.id,
)
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A", "reporter": {"lastName": "Doe"}}},
{"node": {"headline": "B", "reporter": {"lastName": "Bon"}}},
]
def test_enum_in_filter(query):
"""
Test in filter on a choice field using an enum (Reporter.reporter_type).
"""
Reporter.objects.create(
first_name="John", last_name="Doe", email="john@doe.com", reporter_type=1
)
Reporter.objects.create(
first_name="Jean", last_name="Bon", email="jean@bon.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jane", last_name="Doe", email="jane@doe.com", reporter_type=2
)
Reporter.objects.create(
first_name="Jack", last_name="Black", email="jack@black.com", reporter_type=None
)
schema = Schema(query=query)
query = """
query {
reporters (reporterType_In: [A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]
query = """
query {
reporters (reporterType_In: [A_2, A_1]) {
edges {
node {
email
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["reporters"]["edges"] == [
{"node": {"email": "john@doe.com"}},
{"node": {"email": "jean@bon.com"}},
{"node": {"email": "jane@doe.com"}},
]

View File

@ -25,6 +25,7 @@ class PetNode(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
filter_fields = { filter_fields = {
"name": ["exact", "in"], "name": ["exact", "in"],
"age": ["exact", "in", "range"], "age": ["exact", "in", "range"],
@ -101,14 +102,14 @@ def test_range_filter_with_invalid_input():
# Empty list # Empty list
result = schema.execute(query, variables={"rangeValue": []}) result = schema.execute(query, variables={"rangeValue": []})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error
# Only one item in the list # Only one item in the list
result = schema.execute(query, variables={"rangeValue": [1]}) result = schema.execute(query, variables={"rangeValue": [1]})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error
# More than 2 items in the list # More than 2 items in the list
result = schema.execute(query, variables={"rangeValue": [1, 2, 3]}) result = schema.execute(query, variables={"rangeValue": [1, 2, 3]})
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == f"['{expected_error}']" assert result.errors[0].message == expected_error

View File

@ -0,0 +1,151 @@
import pytest
from django_filters import FilterSet
import graphene
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = []
if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import (
DjangoFilterConnectionField,
TypedFilter,
ListFilter,
)
else:
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
@pytest.fixture
def schema():
class ArticleFilterSet(FilterSet):
class Meta:
model = Article
fields = {
"lang": ["exact", "in"],
}
lang__contains = TypedFilter(
field_name="lang", lookup_expr="icontains", input_type=graphene.String
)
lang__in_str = ListFilter(
field_name="lang",
lookup_expr="in",
input_type=graphene.List(graphene.String),
)
first_n = TypedFilter(input_type=graphene.Int, method="first_n_filter")
only_first = TypedFilter(
input_type=graphene.Boolean, method="only_first_filter"
)
def first_n_filter(self, queryset, _name, value):
return queryset[:value]
def only_first_filter(self, queryset, _name, value):
if value:
return queryset[:1]
else:
return queryset
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
fields = "__all__"
filterset_class = ArticleFilterSet
class Query(graphene.ObjectType):
articles = DjangoFilterConnectionField(ArticleType)
schema = graphene.Schema(query=Query)
return schema
def test_typed_filter_schema(schema):
"""
Check that the type provided in the filter is reflected in the schema.
"""
schema_str = str(schema)
filters = {
"offset": "Int",
"before": "String",
"after": "String",
"first": "Int",
"last": "Int",
"lang": "TestsArticleLangChoices",
"lang_In": "[TestsArticleLangChoices]",
"lang_Contains": "String",
"lang_InStr": "[String]",
"firstN": "Int",
"onlyFirst": "Boolean",
}
all_articles_filters = (
schema_str.split(" articles(")[1]
.split("): ArticleTypeConnection\n")[0]
.split(", ")
)
for filter_field, gql_type in filters.items():
assert f"{filter_field}: {gql_type}" in all_articles_filters
def test_typed_filters_work(schema):
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "C"}},
]
query = "query { articles (firstN: 2) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
]

View File

@ -1,53 +1,103 @@
import graphene import graphene
from django import forms
from django_filters.utils import get_model_field from django_filters.utils import get_model_field, get_field_parts
from django_filters.filters import Filter, BaseCSVFilter from django_filters.filters import Filter, BaseCSVFilter
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter
from .filterset import custom_filterset_factory, setup_filterset from .filterset import custom_filterset_factory, setup_filterset
from .filters import InFilter, RangeFilter from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
def get_field_type(registry, model, field_name):
"""
Try to get a model field corresponding Graphql type from the DjangoObjectType.
"""
object_type = registry.get_type_for_model(model)
if object_type:
object_type_field = object_type._meta.fields.get(field_name)
if object_type_field:
field_type = object_type_field.type
if isinstance(field_type, graphene.NonNull):
field_type = field_type.of_type
return field_type
return None
def get_filtering_args_from_filterset(filterset_class, type): def get_filtering_args_from_filterset(filterset_class, type):
""" Inspect a FilterSet and produce the arguments to pass to """
a Graphene Field. These arguments will be available to Inspect a FilterSet and produce the arguments to pass to a Graphene Field.
filter against in the GraphQL These arguments will be available to filter against in the GraphQL API.
""" """
from ..forms.converter import convert_form_field from ..forms.converter import convert_form_field
args = {} args = {}
model = filterset_class._meta.model model = filterset_class._meta.model
registry = type._meta.registry
for name, filter_field in filterset_class.base_filters.items(): for name, filter_field in filterset_class.base_filters.items():
form_field = None
filter_type = filter_field.lookup_expr filter_type = filter_field.lookup_expr
if name in filterset_class.declared_filters:
# Get the filter field from the explicitly declared filter
form_field = filter_field.field
field = convert_form_field(form_field)
else:
# Get the filter field with no explicit type declaration
model_field = get_model_field(model, filter_field.field_name)
if filter_type != "isnull" and hasattr(model_field, "formfield"):
form_field = model_field.formfield(
required = filter_field.extra.get("required", False) required = filter_field.extra.get("required", False)
) field_type = None
form_field = None
# Fallback to field defined on filter if we can't get it from the if (
# model field isinstance(filter_field, TypedFilter)
and filter_field.input_type is not None
):
# First check if the filter input type has been explicitely given
field_type = filter_field.input_type
else:
if name not in filterset_class.declared_filters or isinstance(
filter_field, TypedFilter
):
# Get the filter field for filters that are no explicitly declared.
if filter_type == "isnull":
field = graphene.Boolean(required=required)
else:
model_field = get_model_field(model, filter_field.field_name)
# Get the form field either from:
# 1. the formfield corresponding to the model field
# 2. the field defined on filter
if hasattr(model_field, "formfield"):
form_field = model_field.formfield(required=required)
if not form_field: if not form_field:
form_field = filter_field.field form_field = filter_field.field
field = convert_form_field(form_field) # First try to get the matching field type from the GraphQL DjangoObjectType
if model_field:
if (
isinstance(form_field, forms.ModelChoiceField)
or isinstance(form_field, forms.ModelMultipleChoiceField)
or isinstance(form_field, GlobalIDMultipleChoiceField)
or isinstance(form_field, GlobalIDFormField)
):
# Foreign key have dynamic types and filtering on a foreign key actually means filtering on its ID.
field_type = get_field_type(
registry, model_field.related_model, "id"
)
else:
field_type = get_field_type(
registry, model_field.model, model_field.name
)
if filter_type in {"in", "range", "contains", "overlap"}: if not field_type:
# Replace CSV filters (`in`, `range`, `contains`, `overlap`) argument type to be a list of # Fallback on converting the form field either because:
# the same type as the field. See comments in # - it's an explicitly declared filters
# `replace_csv_filters` method for more details. # - we did not manage to get the type from the model type
field = graphene.List(field.get_type()) form_field = form_field or filter_field.field
field_type = convert_form_field(form_field).get_type()
field_type = field.Argument() if isinstance(filter_field, ListFilter) or isinstance(
field_type.description = str(filter_field.label) if filter_field.label else None filter_field, RangeFilter
args[name] = field_type ):
# Replace InFilter/RangeFilter filters (`in`, `range`) argument type to be a list of
# the same type as the field. See comments in `replace_csv_filters` method for more details.
field_type = graphene.List(field_type)
args[name] = graphene.Argument(
field_type,
description=filter_field.label,
required=required,
)
return args return args
@ -69,18 +119,26 @@ def get_filterset_class(filterset_class, **meta):
def replace_csv_filters(filterset_class): def replace_csv_filters(filterset_class):
""" """
Replace the "in", "contains", "overlap" and "range" filters (that are not explicitly declared) to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore Replace the "in" and "range" filters (that are not explicitly declared)
but regular Filter objects that simply use the input value as filter argument on the queryset. to not be BaseCSVFilter (BaseInFilter, BaseRangeFilter) objects anymore
but our custom InFilter/RangeFilter filter class that use the input
value as filter argument on the queryset.
This is because those BaseCSVFilter are expecting a string as input with comma separated value but with GraphQl we This is because those BaseCSVFilter are expecting a string as input with
can actually have a list as input and have a proper type verification of each value in the list. comma separated values.
But with GraphQl we can actually have a list as input and have a proper
type verification of each value in the list.
See issue https://github.com/graphql-python/graphene-django/issues/1068. See issue https://github.com/graphql-python/graphene-django/issues/1068.
""" """
for name, filter_field in list(filterset_class.base_filters.items()): for name, filter_field in list(filterset_class.base_filters.items()):
# Do not touch any declared filters
if name in filterset_class.declared_filters:
continue
filter_type = filter_field.lookup_expr filter_type = filter_field.lookup_expr
if filter_type in {"in", "contains", "overlap"}: if filter_type == "in":
filterset_class.base_filters[name] = InFilter( filterset_class.base_filters[name] = ListFilter(
field_name=filter_field.field_name, field_name=filter_field.field_name,
lookup_expr=filter_field.lookup_expr, lookup_expr=filter_field.lookup_expr,
label=filter_field.label, label=filter_field.label,
@ -88,7 +146,6 @@ def replace_csv_filters(filterset_class):
exclude=filter_field.exclude, exclude=filter_field.exclude,
**filter_field.extra **filter_field.extra
) )
elif filter_type == "range": elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter( filterset_class.base_filters[name] = RangeFilter(
field_name=filter_field.field_name, field_name=filter_field.field_name,

View File

@ -3,7 +3,19 @@ from functools import singledispatch
from django import forms from django import forms
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time from graphene import (
ID,
Boolean,
Decimal,
Float,
Int,
List,
String,
UUID,
Date,
DateTime,
Time,
)
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
@ -57,12 +69,18 @@ def convert_form_field_to_nullboolean(field):
return Boolean(description=get_form_field_description(field)) return Boolean(description=get_form_field_description(field))
@convert_form_field.register(forms.DecimalField)
@convert_form_field.register(forms.FloatField) @convert_form_field.register(forms.FloatField)
def convert_form_field_to_float(field): def convert_form_field_to_float(field):
return Float(description=get_form_field_description(field), required=field.required) return Float(description=get_form_field_description(field), required=field.required)
@convert_form_field.register(forms.DecimalField)
def convert_form_field_to_decimal(field):
return Decimal(
description=get_form_field_description(field), required=field.required
)
@convert_form_field.register(forms.MultipleChoiceField) @convert_form_field.register(forms.MultipleChoiceField)
def convert_form_field_to_string_list(field): def convert_form_field_to_string_list(field):
return List( return List(

View File

@ -14,10 +14,6 @@ from graphene.types.utils import yank_fields_from_attrs
from graphene_django.constants import MUTATION_ERRORS_FLAG from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.registry import get_global_registry from graphene_django.registry import get_global_registry
from django.core.exceptions import ValidationError
from django.db import connection
from ..types import ErrorType from ..types import ErrorType
from .converter import convert_form_field from .converter import convert_form_field
@ -99,12 +95,15 @@ class DjangoFormMutation(BaseDjangoFormMutation):
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field) _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField) input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoFormMutation, cls).__init_subclass_with_meta__( super().__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options _meta=_meta, input_fields=input_fields, **options
) )
@classmethod @classmethod
def perform_mutate(cls, form, info): def perform_mutate(cls, form, info):
if hasattr(form, "save"):
# `save` method won't exist on plain Django forms, but this mutation can
# in theory be used with `ModelForm`s as well and we do want to save them.
form.save() form.save()
return cls(errors=[], **form.cleaned_data) return cls(errors=[], **form.cleaned_data)
@ -118,7 +117,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
class Meta: class Meta:
abstract = True abstract = True
errors = graphene.List(ErrorType) errors = graphene.List(graphene.NonNull(ErrorType), required=True)
@classmethod @classmethod
def __init_subclass_with_meta__( def __init_subclass_with_meta__(
@ -128,7 +127,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
return_field_name=None, return_field_name=None,
only_fields=(), only_fields=(),
exclude_fields=(), exclude_fields=(),
**options **options,
): ):
if not form_class: if not form_class:
@ -148,7 +147,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
registry = get_global_registry() registry = get_global_registry()
model_type = registry.get_type_for_model(model) model_type = registry.get_type_for_model(model)
if not model_type: if not model_type:
raise Exception("No type registered for model: {}".format(model.__name__)) raise Exception(f"No type registered for model: {model.__name__}")
if not return_field_name: if not return_field_name:
model_name = model.__name__ model_name = model.__name__
@ -164,7 +163,7 @@ class DjangoModelFormMutation(BaseDjangoFormMutation):
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field) _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField) input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( super().__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options _meta=_meta, input_fields=input_fields, **options
) )

View File

@ -1,11 +1,12 @@
from django import forms from django import forms
from py.test import raises from pytest import raises
import graphene import graphene
from graphene import ( from graphene import (
String, String,
Int, Int,
Boolean, Boolean,
Decimal,
Float, Float,
ID, ID,
UUID, UUID,
@ -97,8 +98,8 @@ def test_should_float_convert_float():
assert_conversion(forms.FloatField, Float) assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float(): def test_should_decimal_convert_decimal():
assert_conversion(forms.DecimalField, Float) assert_conversion(forms.DecimalField, Decimal)
def test_should_multiple_choice_convert_list(): def test_should_multiple_choice_convert_list():

View File

@ -1,7 +1,7 @@
import pytest import pytest
from django import forms from django import forms
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from py.test import raises from pytest import raises
from graphene import Field, ObjectType, Schema, String from graphene import Field, ObjectType, Schema, String
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType

View File

@ -1,3 +1 @@
import graphene
from ..types import ErrorType # noqa Import ErrorType for backwards compatability from ..types import ErrorType # noqa Import ErrorType for backwards compatability

View File

@ -48,14 +48,14 @@ class CommandArguments(BaseCommand):
class Command(CommandArguments): class Command(CommandArguments):
help = "Dump Graphene schema as a JSON or GraphQL file" help = "Dump Graphene schema as a JSON or GraphQL file"
can_import_settings = True can_import_settings = True
requires_system_checks = False requires_system_checks = []
def save_json_file(self, out, schema_dict, indent): def save_json_file(self, out, schema_dict, indent):
with open(out, "w") as outfile: with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent, sort_keys=True) json.dump(schema_dict, outfile, indent=indent, sort_keys=True)
def save_graphql_file(self, out, schema): def save_graphql_file(self, out, schema):
with open(out, "w") as outfile: with open(out, "w", encoding="utf-8") as outfile:
outfile.write(print_schema(schema.graphql_schema)) outfile.write(print_schema(schema.graphql_schema))
def get_schema(self, schema, out, indent): def get_schema(self, schema, out, indent):
@ -73,16 +73,12 @@ class Command(CommandArguments):
elif file_extension == ".json": elif file_extension == ".json":
self.save_json_file(out, schema_dict, indent) self.save_json_file(out, schema_dict, indent)
else: else:
raise CommandError( raise CommandError(f'Unrecognised file format "{file_extension}"')
'Unrecognised file format "{}"'.format(file_extension)
)
style = getattr(self, "style", None) style = getattr(self, "style", None)
success = getattr(style, "SUCCESS", lambda x: x) success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write( self.stdout.write(success(f"Successfully dumped GraphQL schema to {out}"))
success("Successfully dumped GraphQL schema to {}".format(out))
)
def handle(self, *args, **options): def handle(self, *args, **options):
options_schema = options.get("schema") options_schema = options.get("schema")

View File

@ -1,4 +1,4 @@
class Registry(object): class Registry:
def __init__(self): def __init__(self):
self._registry = {} self._registry = {}
self._field_registry = {} self._field_registry = {}

View File

@ -114,7 +114,7 @@ class SerializerMutation(ClientIDMutation):
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field) _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField) input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(SerializerMutation, cls).__init_subclass_with_meta__( super().__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options _meta=_meta, input_fields=input_fields, **options
) )

View File

@ -72,7 +72,7 @@ def convert_serializer_to_input_type(serializer_class):
for name, field in serializer.fields.items() for name, field in serializer.fields.items()
} }
ret_type = type( ret_type = type(
"{}Input".format(serializer.__class__.__name__), f"{serializer.__class__.__name__}Input",
(graphene.InputObjectType,), (graphene.InputObjectType,),
items, items,
) )
@ -110,11 +110,15 @@ def convert_serializer_field_to_bool(field):
@get_graphene_type_from_serializer_field.register(serializers.FloatField) @get_graphene_type_from_serializer_field.register(serializers.FloatField)
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
def convert_serializer_field_to_float(field): def convert_serializer_field_to_float(field):
return graphene.Float return graphene.Float
@get_graphene_type_from_serializer_field.register(serializers.DecimalField)
def convert_serializer_field_to_decimal(field):
return graphene.Decimal
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField) @get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field): def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime return graphene.types.datetime.DateTime

View File

@ -3,7 +3,7 @@ import copy
import graphene import graphene
from django.db import models from django.db import models
from graphene import InputObjectType from graphene import InputObjectType
from py.test import raises from pytest import raises
from rest_framework import serializers from rest_framework import serializers
from ..serializer_converter import convert_serializer_field from ..serializer_converter import convert_serializer_field
@ -133,9 +133,9 @@ def test_should_float_convert_float():
assert_conversion(serializers.FloatField, graphene.Float) assert_conversion(serializers.FloatField, graphene.Float)
def test_should_decimal_convert_float(): def test_should_decimal_convert_decimal():
assert_conversion( assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2 serializers.DecimalField, graphene.Decimal, max_digits=4, decimal_places=2
) )

View File

@ -1,6 +1,6 @@
import datetime import datetime
from py.test import raises from pytest import raises
from rest_framework import serializers from rest_framework import serializers
from graphene import Field, ResolveInfo from graphene import Field, ResolveInfo

View File

@ -11,15 +11,11 @@ This module provides the `graphene_settings` object, that is used to access
Graphene settings, checking for user settings first, then falling Graphene settings, checking for user settings first, then falling
back to the defaults. back to the defaults.
""" """
from __future__ import unicode_literals
from django.conf import settings from django.conf import settings
from django.test.signals import setting_changed from django.test.signals import setting_changed
try:
import importlib # Available in Python 3.1+ import importlib # Available in Python 3.1+
except ImportError:
from django.utils import importlib # Will be removed in Django 1.9
# Copied shamelessly from Django REST Framework # Copied shamelessly from Django REST Framework
@ -44,7 +40,9 @@ DEFAULTS = {
# This sets headerEditorEnabled GraphiQL option, for details go to # This sets headerEditorEnabled GraphiQL option, for details go to
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options # https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
"GRAPHIQL_HEADER_EDITOR_ENABLED": True, "GRAPHIQL_HEADER_EDITOR_ENABLED": True,
"GRAPHIQL_SHOULD_PERSIST_HEADERS": False,
"ATOMIC_MUTATIONS": False, "ATOMIC_MUTATIONS": False,
"TESTING_ENDPOINT": "/graphql",
} }
if settings.DEBUG: if settings.DEBUG:
@ -79,7 +77,7 @@ def import_from_string(val, setting_name):
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % ( msg = "Could not import '{}' for Graphene setting '{}'. {}: {}.".format(
val, val,
setting_name, setting_name,
e.__class__.__name__, e.__class__.__name__,
@ -88,7 +86,7 @@ def import_from_string(val, setting_name):
raise ImportError(msg) raise ImportError(msg)
class GrapheneSettings(object): class GrapheneSettings:
""" """
A settings object, that allows API settings to be accessed as properties. A settings object, that allows API settings to be accessed as properties.
For example: For example:

View File

@ -10,14 +10,6 @@
history, history,
location, location,
) { ) {
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ("; " + document.cookie).split("; csrftoken=");
if (cookies.length == 2) {
csrftoken = cookies.pop().split(";").shift();
} else {
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
}
// Collect the URL parameters // Collect the URL parameters
var parameters = {}; var parameters = {};
@ -68,9 +60,19 @@
var headers = opts.headers || {}; var headers = opts.headers || {};
headers['Accept'] = headers['Accept'] || 'application/json'; headers['Accept'] = headers['Accept'] || 'application/json';
headers['Content-Type'] = headers['Content-Type'] || 'application/json'; headers['Content-Type'] = headers['Content-Type'] || 'application/json';
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ("; " + document.cookie).split("; csrftoken=");
if (cookies.length == 2) {
csrftoken = cookies.pop().split(";").shift();
} else {
csrftoken = document.querySelector("[name=csrfmiddlewaretoken]").value;
}
if (csrftoken) { if (csrftoken) {
headers['X-CSRFToken'] = csrftoken headers['X-CSRFToken'] = csrftoken
} }
return fetch(fetchURL, { return fetch(fetchURL, {
method: "post", method: "post",
headers: headers, headers: headers,
@ -123,8 +125,8 @@
if (operationType === "subscription") { if (operationType === "subscription") {
return { return {
subscribe: function (observer) { subscribe: function (observer) {
subscriptionClient.request(graphQLParams).subscribe(observer);
activeSubscription = subscriptionClient; activeSubscription = subscriptionClient;
return subscriptionClient.request(graphQLParams, opts).subscribe(observer);
}, },
}; };
} else { } else {
@ -176,6 +178,7 @@
onEditVariables: onEditVariables, onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName, onEditOperationName: onEditOperationName,
headerEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled, headerEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled,
shouldPersistHeaders: GRAPHENE_SETTINGS.graphiqlShouldPersistHeaders,
query: parameters.query, query: parameters.query,
}; };
if (parameters.variables) { if (parameters.variables) {

View File

@ -46,6 +46,7 @@ add "&raw" to the end of the URL within a browser.
subscriptionPath: "{{subscription_path}}", subscriptionPath: "{{subscription_path}}",
{% endif %} {% endif %}
graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }}, graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }},
graphiqlShouldPersistHeaders: {{ graphiql_should_persist_headers|yesno:"true,false" }},
}; };
</script> </script>
<script src="{% static 'graphene_django/graphiql.js' %}"></script> <script src="{% static 'graphene_django/graphiql.js' %}"></script>

View File

@ -8,8 +8,8 @@ import graphene
from graphene import Field, ResolveInfo from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from py.test import raises from pytest import raises
from py.test import mark from pytest import mark
from rest_framework import serializers from rest_framework import serializers
from ...types import DjangoObjectType from ...types import DjangoObjectType

View File

@ -1,5 +1,3 @@
from __future__ import absolute_import
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -13,6 +11,9 @@ class Person(models.Model):
class Pet(models.Model): class Pet(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
age = models.PositiveIntegerField() age = models.PositiveIntegerField()
owner = models.ForeignKey(
"Person", on_delete=models.CASCADE, null=True, blank=True, related_name="pets"
)
class FilmDetails(models.Model): class FilmDetails(models.Model):
@ -26,7 +27,7 @@ class Film(models.Model):
genre = models.CharField( genre = models.CharField(
max_length=2, max_length=2,
help_text="Genre", help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")], choices=[("do", "Documentary"), ("ac", "Action"), ("ot", "Other")],
default="ot", default="ot",
) )
reporters = models.ManyToManyField("Reporter", related_name="films") reporters = models.ManyToManyField("Reporter", related_name="films")
@ -34,7 +35,7 @@ class Film(models.Model):
class DoeReporterManager(models.Manager): class DoeReporterManager(models.Manager):
def get_queryset(self): def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") return super().get_queryset().filter(last_name="Doe")
class Reporter(models.Model): class Reporter(models.Model):
@ -54,7 +55,7 @@ class Reporter(models.Model):
) )
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name) return f"{self.first_name} {self.last_name}"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
@ -64,7 +65,7 @@ class Reporter(models.Model):
when a CNNReporter is pulled from the database, it is still when a CNNReporter is pulled from the database, it is still
of type Reporter. This was added to test proxy model support. of type Reporter. This was added to test proxy model support.
""" """
super(Reporter, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if self.reporter_type == 2: # quick and dirty way without enums if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter self.__class__ = CNNReporter
@ -74,7 +75,7 @@ class Reporter(models.Model):
class CNNReporterManager(models.Manager): class CNNReporterManager(models.Manager):
def get_queryset(self): def get_queryset(self):
return super(CNNReporterManager, self).get_queryset().filter(reporter_type=2) return super().get_queryset().filter(reporter_type=2)
class CNNReporter(Reporter): class CNNReporter(Reporter):
@ -91,8 +92,8 @@ class CNNReporter(Reporter):
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField(auto_now_add=True)
pub_date_time = models.DateTimeField() pub_date_time = models.DateTimeField(auto_now_add=True)
reporter = models.ForeignKey( reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles" Reporter, on_delete=models.CASCADE, related_name="articles"
) )

View File

@ -2,7 +2,7 @@ from textwrap import dedent
from django.core import management from django.core import management
from io import StringIO from io import StringIO
from mock import mock_open, patch from unittest.mock import mock_open, patch
from graphene import ObjectType, Schema, String from graphene import ObjectType, Schema, String
@ -53,6 +53,5 @@ def test_generate_graphql_file_on_call_graphql_schema():
"""\ """\
type Query { type Query {
hi: String hi: String
} }"""
"""
) )

View File

@ -3,13 +3,14 @@ from collections import namedtuple
import pytest import pytest
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from py.test import raises from pytest import raises
import graphene import graphene
from graphene import NonNull from graphene import NonNull
from graphene.relay import ConnectionField, Node from graphene.relay import ConnectionField, Node
from graphene.types.datetime import Date, DateTime, Time from graphene.types.datetime import Date, DateTime, Time
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.scalars import BigInt
from ..compat import ( from ..compat import (
ArrayField, ArrayField,
@ -111,6 +112,15 @@ def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True) assert_conversion(models.AutoField, graphene.ID, primary_key=True)
def test_should_big_auto_convert_id():
assert_conversion(models.BigAutoField, graphene.ID, primary_key=True)
def test_should_small_auto_convert_id():
if hasattr(models, "SmallAutoField"):
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
def test_should_uuid_convert_id(): def test_should_uuid_convert_id():
assert_conversion(models.UUIDField, graphene.UUID) assert_conversion(models.UUIDField, graphene.UUID)
@ -131,8 +141,8 @@ def test_should_small_integer_convert_int():
assert_conversion(models.SmallIntegerField, graphene.Int) assert_conversion(models.SmallIntegerField, graphene.Int)
def test_should_big_integer_convert_int(): def test_should_big_integer_convert_big_int():
assert_conversion(models.BigIntegerField, graphene.Int) assert_conversion(models.BigIntegerField, BigInt)
def test_should_integer_convert_int(): def test_should_integer_convert_int():

View File

@ -408,3 +408,95 @@ class TestDjangoListField:
{"firstName": "Debra", "articles": []}, {"firstName": "Debra", "articles": []},
] ]
} }
def test_resolve_list_external_resolver(self):
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return [ReporterModel.objects.get(first_name="Debra")]
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Debra"}]}
def test_get_queryset_filter_external_resolver(self):
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
fields = ("first_name", "articles")
@classmethod
def get_queryset(cls, queryset, info):
# Only get reporters with at least 1 article
return queryset.annotate(article_count=Count("articles")).filter(
article_count__gt=0
)
def resolve_reporters(_, info):
return ReporterModel.objects.all()
class Query(ObjectType):
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
schema = Schema(query=Query)
query = """
query {
reporters {
firstName
}
}
"""
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
ArticleModel.objects.create(
headline="Amazing news",
reporter=r1,
pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
editor=r1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {"reporters": [{"firstName": "Tara"}]}

View File

@ -1,5 +1,5 @@
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from py.test import raises from pytest import raises
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField

View File

@ -0,0 +1,361 @@
import pytest
import graphene
from graphene.relay import Node
from graphql_relay import to_global_id
from ..fields import DjangoConnectionField
from ..types import DjangoObjectType
from .models import Article, Reporter
class TestShouldCallGetQuerySetOnForeignKey:
"""
Check that the get_queryset method is called in both forward and reversed direction
of a foreignkey on types.
(see issue #1111)
"""
@pytest.fixture(autouse=True)
def setup_schema(self):
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("admin"):
return queryset
raise Exception("Not authorized to access reporters.")
class ArticleType(DjangoObjectType):
class Meta:
model = Article
@classmethod
def get_queryset(cls, queryset, info):
return queryset.exclude(headline__startswith="Draft")
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType, id=graphene.ID(required=True))
article = graphene.Field(ArticleType, id=graphene.ID(required=True))
def resolve_reporter(self, info, id):
return (
ReporterType.get_queryset(Reporter.objects, info)
.filter(id=id)
.last()
)
def resolve_article(self, info, id):
return (
ArticleType.get_queryset(Article.objects, info).filter(id=id).last()
)
self.schema = graphene.Schema(query=Query)
self.reporter = Reporter.objects.create(first_name="Jane", last_name="Doe")
self.articles = [
Article.objects.create(
headline="A fantastic article",
reporter=self.reporter,
editor=self.reporter,
),
Article.objects.create(
headline="Draft: My next best seller",
reporter=self.reporter,
editor=self.reporter,
),
]
def test_get_queryset_called_on_field(self):
# If a user tries to access an article it is fine as long as it's not a draft one
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
}
}
"""
# Non-draft
result = self.schema.execute(query, variables={"id": self.articles[0].id})
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
}
# Draft
result = self.schema.execute(query, variables={"id": self.articles[1].id})
assert not result.errors
assert result.data["article"] is None
# If a non admin user tries to access a reporter they should get our authorization error
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(query, variables={"id": self.reporter.id})
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.reporter.id},
context_value={"admin": True},
)
assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(query, variables={"id": self.articles[0].id})
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.articles[0].id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
headline
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.reporter.id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": [{"headline": "A fantastic article"}],
}
class TestShouldCallGetQuerySetOnForeignKeyNode:
"""
Check that the get_queryset method is called in both forward and reversed direction
of a foreignkey on types using a node interface.
(see issue #1111)
"""
@pytest.fixture(autouse=True)
def setup_schema(self):
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("admin"):
return queryset
raise Exception("Not authorized to access reporters.")
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node,)
@classmethod
def get_queryset(cls, queryset, info):
return queryset.exclude(headline__startswith="Draft")
class Query(graphene.ObjectType):
reporter = Node.Field(ReporterType)
article = Node.Field(ArticleType)
self.schema = graphene.Schema(query=Query)
self.reporter = Reporter.objects.create(first_name="Jane", last_name="Doe")
self.articles = [
Article.objects.create(
headline="A fantastic article",
reporter=self.reporter,
editor=self.reporter,
),
Article.objects.create(
headline="Draft: My next best seller",
reporter=self.reporter,
editor=self.reporter,
),
]
def test_get_queryset_called_on_node(self):
# If a user tries to access an article it is fine as long as it's not a draft one
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
}
}
"""
# Non-draft
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
}
# Draft
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[1].id)}
)
assert not result.errors
assert result.data["article"] is None
# If a non admin user tries to access a reporter they should get our authorization error
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query, variables={"id": to_global_id("ReporterType", self.reporter.id)}
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ReporterType", self.reporter.id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ArticleType", self.articles[0].id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
edges {
node {
headline
}
}
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ReporterType", self.reporter.id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": {"edges": [{"node": {"headline": "A fantastic article"}}]},
}

View File

@ -6,7 +6,7 @@ from django.db import models
from django.db.models import Q from django.db.models import Q
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject
from graphql_relay import to_global_id from graphql_relay import to_global_id
from py.test import raises from pytest import raises
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node
@ -15,7 +15,7 @@ from ..compat import IntegerRangeField, MissingType
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from ..utils import DJANGO_FILTER_INSTALLED from ..utils import DJANGO_FILTER_INSTALLED
from .models import Article, CNNReporter, Film, FilmDetails, Reporter from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter
def test_should_query_only_fields(): def test_should_query_only_fields():
@ -251,8 +251,8 @@ def test_should_node():
def test_should_query_onetoone_fields(): def test_should_query_onetoone_fields():
film = Film(id=1) film = Film.objects.create(id=1)
film_details = FilmDetails(id=1, film=film) film_details = FilmDetails.objects.create(id=1, film=film)
class FilmNode(DjangoObjectType): class FilmNode(DjangoObjectType):
class Meta: class Meta:
@ -421,6 +421,7 @@ def test_should_query_node_filtering():
interfaces = (Node,) interfaces = (Node,)
fields = "__all__" fields = "__all__"
filter_fields = ("lang",) filter_fields = ("lang",)
convert_choices_to_enum = False
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -546,6 +547,7 @@ def test_should_query_node_multiple_filtering():
interfaces = (Node,) interfaces = (Node,)
fields = "__all__" fields = "__all__"
filter_fields = ("lang", "headline") filter_fields = ("lang", "headline")
convert_choices_to_enum = False
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1149,9 +1151,9 @@ def test_connection_should_limit_after_to_list_length():
REPORTERS = [ REPORTERS = [
dict( dict(
first_name="First {}".format(i), first_name=f"First {i}",
last_name="Last {}".format(i), last_name=f"Last {i}",
email="johndoe+{}@example.com".format(i), email=f"johndoe+{i}@example.com",
a_choice=1, a_choice=1,
) )
for i in range(6) for i in range(6)
@ -1241,6 +1243,7 @@ def test_should_have_next_page(graphene_settings):
} }
@pytest.mark.parametrize("max_limit", [100, 4])
class TestBackwardPagination: class TestBackwardPagination:
def setup_schema(self, graphene_settings, max_limit): def setup_schema(self, graphene_settings, max_limit):
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit graphene_settings.RELAY_CONNECTION_MAX_LIMIT = max_limit
@ -1251,6 +1254,7 @@ class TestBackwardPagination:
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1258,8 +1262,8 @@ class TestBackwardPagination:
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
return schema return schema
def do_queries(self, schema): def test_query_last(self, graphene_settings, max_limit):
# Simply last 3 schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_last = """ query_last = """
query { query {
allReporters(last: 3) { allReporters(last: 3) {
@ -1279,7 +1283,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 3", "First 4", "First 5"] ] == ["First 3", "First 4", "First 5"]
# Use a combination of first and last def test_query_first_and_last(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_and_last = """ query_first_and_last = """
query { query {
allReporters(first: 4, last: 3) { allReporters(first: 4, last: 3) {
@ -1299,7 +1304,8 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 1", "First 2", "First 3"] ] == ["First 1", "First 2", "First 3"]
# Use a combination of first and last and after def test_query_first_last_and_after(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """ query_first_last_and_after = """
query queryAfter($after: String) { query queryAfter($after: String) {
allReporters(first: 4, last: 3, after: $after) { allReporters(first: 4, last: 3, after: $after) {
@ -1314,7 +1320,8 @@ class TestBackwardPagination:
after = base64.b64encode(b"arrayconnection:0").decode() after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute( result = schema.execute(
query_first_last_and_after, variable_values=dict(after=after) query_first_last_and_after,
variable_values=dict(after=after),
) )
assert not result.errors assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3 assert len(result.data["allReporters"]["edges"]) == 3
@ -1322,20 +1329,35 @@ class TestBackwardPagination:
e["node"]["firstName"] for e in result.data["allReporters"]["edges"] e["node"]["firstName"] for e in result.data["allReporters"]["edges"]
] == ["First 2", "First 3", "First 4"] ] == ["First 2", "First 3", "First 4"]
def test_should_query(self, graphene_settings): def test_query_last_and_before(self, graphene_settings, max_limit):
schema = self.setup_schema(graphene_settings, max_limit=max_limit)
query_first_last_and_after = """
query queryAfter($before: String) {
allReporters(last: 1, before: $before) {
edges {
node {
firstName
}
}
}
}
""" """
Backward pagination should work as expected
"""
schema = self.setup_schema(graphene_settings, max_limit=100)
self.do_queries(schema)
def test_should_query_with_low_max_limit(self, graphene_settings): result = schema.execute(
""" query_first_last_and_after,
When doing backward pagination (using last) in combination with a max limit higher than the number of objects )
we should really retrieve the last ones. assert not result.errors
""" assert len(result.data["allReporters"]["edges"]) == 1
schema = self.setup_schema(graphene_settings, max_limit=4) assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 5"
self.do_queries(schema)
before = base64.b64encode(b"arrayconnection:5").decode()
result = schema.execute(
query_first_last_and_after,
variable_values=dict(before=before),
)
assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1
assert result.data["allReporters"]["edges"][0]["node"]["firstName"] == "First 4"
def test_should_preserve_prefetch_related(django_assert_num_queries): def test_should_preserve_prefetch_related(django_assert_num_queries):
@ -1455,6 +1477,7 @@ def test_connection_should_enable_offset_filtering():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1476,7 +1499,11 @@ def test_connection_should_enable_offset_filtering():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": {"edges": [{"node": {"firstName": "Some", "lastName": "Guy"}},]} "allReporters": {
"edges": [
{"node": {"firstName": "Some", "lastName": "Guy"}},
]
}
} }
assert result.data == expected assert result.data == expected
@ -1494,6 +1521,7 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1516,7 +1544,9 @@ def test_connection_should_enable_offset_filtering_higher_than_max_limit(
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": { "allReporters": {
"edges": [{"node": {"firstName": "Some", "lastName": "Lady"}},] "edges": [
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
} }
} }
assert result.data == expected assert result.data == expected
@ -1527,6 +1557,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1561,6 +1592,7 @@ def test_connection_should_allow_offset_filtering_with_after():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node,) interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
@ -1583,6 +1615,149 @@ def test_connection_should_allow_offset_filtering_with_after():
result = schema.execute(query, variable_values=dict(after=after)) result = schema.execute(query, variable_values=dict(after=after))
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}},]} "allReporters": {
"edges": [
{"node": {"firstName": "Jane", "lastName": "Roe"}},
]
}
} }
assert result.data == expected assert result.data == expected
def test_connection_should_succeed_if_last_higher_than_number_of_objects():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
fields = "__all__"
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
query = """
query ReporterPromiseConnectionQuery ($last: Int) {
allReporters(last: $last) {
edges {
node {
firstName
lastName
}
}
}
}
"""
result = schema.execute(query, variable_values=dict(last=2))
assert not result.errors
expected = {"allReporters": {"edges": []}}
assert result.data == expected
Reporter.objects.create(first_name="John", last_name="Doe")
Reporter.objects.create(first_name="Some", last_name="Guy")
Reporter.objects.create(first_name="Jane", last_name="Roe")
Reporter.objects.create(first_name="Some", last_name="Lady")
result = schema.execute(query, variable_values=dict(last=2))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
result = schema.execute(query, variable_values=dict(last=4))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "John", "lastName": "Doe"}},
{"node": {"firstName": "Some", "lastName": "Guy"}},
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
result = schema.execute(query, variable_values=dict(last=20))
assert not result.errors
expected = {
"allReporters": {
"edges": [
{"node": {"firstName": "John", "lastName": "Doe"}},
{"node": {"firstName": "Some", "lastName": "Guy"}},
{"node": {"firstName": "Jane", "lastName": "Roe"}},
{"node": {"firstName": "Some", "lastName": "Lady"}},
]
}
}
assert result.data == expected
def test_should_query_nullable_foreign_key():
class PetType(DjangoObjectType):
class Meta:
model = Pet
class PersonType(DjangoObjectType):
class Meta:
model = Person
class Query(graphene.ObjectType):
pet = graphene.Field(PetType, name=graphene.String(required=True))
person = graphene.Field(PersonType, name=graphene.String(required=True))
def resolve_pet(self, info, name):
return Pet.objects.filter(name=name).first()
def resolve_person(self, info, name):
return Person.objects.filter(name=name).first()
schema = graphene.Schema(query=Query)
person = Person.objects.create(name="Jane")
pets = [
Pet.objects.create(name="Stray dog", age=1),
Pet.objects.create(name="Jane's dog", owner=person, age=1),
]
query_pet = """
query getPet($name: String!) {
pet(name: $name) {
owner {
name
}
}
}
"""
result = schema.execute(query_pet, variables={"name": "Stray dog"})
assert not result.errors
assert result.data["pet"] == {
"owner": None,
}
result = schema.execute(query_pet, variables={"name": "Jane's dog"})
assert not result.errors
assert result.data["pet"] == {
"owner": {"name": "Jane"},
}
query_owner = """
query getOwner($name: String!) {
person(name: $name) {
pets {
name
}
}
}
"""
result = schema.execute(query_owner, variables={"name": "Jane"})
assert not result.errors
assert result.data["person"] == {
"pets": [{"name": "Jane's dog"}],
}

View File

@ -1,4 +1,4 @@
from py.test import raises from pytest import raises
from ..registry import Registry from ..registry import Registry
from ..types import DjangoObjectType from ..types import DjangoObjectType

View File

@ -3,7 +3,7 @@ from textwrap import dedent
import pytest import pytest
from django.db import models from django.db import models
from mock import patch from unittest.mock import patch
from graphene import Connection, Field, Interface, ObjectType, Schema, String from graphene import Connection, Field, Interface, ObjectType, Schema, String
from graphene.relay import Node from graphene.relay import Node
@ -104,7 +104,7 @@ def test_django_objecttype_with_custom_meta():
@classmethod @classmethod
def __init_subclass_with_meta__(cls, **options): def __init_subclass_with_meta__(cls, **options):
options.setdefault("_meta", ArticleTypeOptions(cls)) options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options) super().__init_subclass_with_meta__(**options)
class Article(ArticleType): class Article(ArticleType):
class Meta: class Meta:
@ -183,7 +183,7 @@ def test_schema_representation():
pets: [Reporter!]! pets: [Reporter!]!
aChoice: TestsReporterAChoiceChoices aChoice: TestsReporterAChoiceChoices
reporterType: TestsReporterReporterTypeChoices reporterType: TestsReporterReporterTypeChoices
articles(offset: Int = null, before: String = null, after: String = null, first: Int = null, last: Int = null): ArticleConnection! articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection!
} }
\"""An enumeration.\""" \"""An enumeration.\"""
@ -244,8 +244,7 @@ def test_schema_representation():
\"""The ID of the object\""" \"""The ID of the object\"""
id: ID! id: ID!
): Node ): Node
} }"""
"""
) )
assert str(schema) == expected assert str(schema) == expected
@ -485,7 +484,7 @@ def test_django_objecttype_neither_fields_nor_exclude():
def custom_enum_name(field): def custom_enum_name(field):
return "CustomEnum{}".format(field.name.title()) return f"CustomEnum{field.name.title()}"
class TestDjangoObjectType: class TestDjangoObjectType:
@ -525,8 +524,7 @@ class TestDjangoObjectType:
id: ID! id: ID!
kind: String! kind: String!
cuteness: Int! cuteness: Int!
} }"""
"""
) )
def test_django_objecttype_convert_choices_enum_list(self, PetModel): def test_django_objecttype_convert_choices_enum_list(self, PetModel):
@ -560,8 +558,7 @@ class TestDjangoObjectType:
\"""Dog\""" \"""Dog\"""
DOG DOG
} }"""
"""
) )
def test_django_objecttype_convert_choices_enum_empty_list(self, PetModel): def test_django_objecttype_convert_choices_enum_empty_list(self, PetModel):
@ -586,8 +583,7 @@ class TestDjangoObjectType:
id: ID! id: ID!
kind: String! kind: String!
cuteness: Int! cuteness: Int!
} }"""
"""
) )
def test_django_objecttype_convert_choices_enum_naming_collisions( def test_django_objecttype_convert_choices_enum_naming_collisions(
@ -621,8 +617,7 @@ class TestDjangoObjectType:
\"""Dog\""" \"""Dog\"""
DOG DOG
} }"""
"""
) )
def test_django_objecttype_choices_custom_enum_name( def test_django_objecttype_choices_custom_enum_name(
@ -660,8 +655,7 @@ class TestDjangoObjectType:
\"""Dog\""" \"""Dog\"""
DOG DOG
} }"""
"""
) )
@ -671,6 +665,7 @@ def test_django_objecttype_name_connection_propagation():
class Meta: class Meta:
model = ReporterModel model = ReporterModel
name = "CustomReporterName" name = "CustomReporterName"
fields = "__all__"
filter_fields = ["email"] filter_fields = ["email"]
interfaces = (Node,) interfaces = (Node,)

View File

@ -2,7 +2,7 @@ import json
import pytest import pytest
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from mock import patch from unittest.mock import patch
from ..utils import camelize, get_model_fields, GraphQLTestCase from ..utils import camelize, get_model_fields, GraphQLTestCase
from .models import Film, Reporter from .models import Film, Reporter
@ -11,11 +11,11 @@ from ..utils.testing import graphql_query
def test_get_model_fields_no_duplication(): def test_get_model_fields_no_duplication():
reporter_fields = get_model_fields(Reporter) reporter_fields = get_model_fields(Reporter)
reporter_name_set = set([field[0] for field in reporter_fields]) reporter_name_set = {field[0] for field in reporter_fields}
assert len(reporter_fields) == len(reporter_name_set) assert len(reporter_fields) == len(reporter_name_set)
film_fields = get_model_fields(Film) film_fields = get_model_fields(Film)
film_name_set = set([field[0] for field in film_fields]) film_name_set = {field[0] for field in film_fields}
assert len(film_fields) == len(film_name_set) assert len(film_fields) == len(film_name_set)
@ -54,7 +54,7 @@ def test_graphql_test_case_operation_name(post_mock):
tc._pre_setup() tc._pre_setup()
tc.setUpClass() tc.setUpClass()
tc.query("query { }", operation_name="QueryName") tc.query("query { }", operation_name="QueryName")
body = json.loads(post_mock.call_args.args[1]) body = json.loads(post_mock.call_args[0][1])
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request # `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
assert ( assert (
"operationName", "operationName",
@ -66,7 +66,7 @@ def test_graphql_test_case_operation_name(post_mock):
@patch("graphene_django.utils.testing.Client.post") @patch("graphene_django.utils.testing.Client.post")
def test_graphql_query_case_operation_name(post_mock): def test_graphql_query_case_operation_name(post_mock):
graphql_query("query { }", operation_name="QueryName") graphql_query("query { }", operation_name="QueryName")
body = json.loads(post_mock.call_args.args[1]) body = json.loads(post_mock.call_args[0][1])
# `operationName` field from https://graphql.org/learn/serving-over-http/#post-request # `operationName` field from https://graphql.org/learn/serving-over-http/#post-request
assert ( assert (
"operationName", "operationName",
@ -83,6 +83,6 @@ def client_query(client):
def test_pytest_fixture_usage(client_query): def test_pytest_fixture_usage(client_query):
response = graphql_query("query { test }") response = client_query("query { test }")
content = json.loads(response.content) content = json.loads(response.content)
assert content == {"data": {"test": "Hello World"}} assert content == {"data": {"test": "Hello World"}}

View File

@ -2,7 +2,7 @@ import json
import pytest import pytest
from mock import patch from unittest.mock import patch
from django.db import connection from django.db import connection
@ -109,12 +109,10 @@ def test_reports_validation_errors(client):
{ {
"message": "Cannot query field 'unknownOne' on type 'QueryRoot'.", "message": "Cannot query field 'unknownOne' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 9}], "locations": [{"line": 1, "column": 9}],
"path": None,
}, },
{ {
"message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.", "message": "Cannot query field 'unknownTwo' on type 'QueryRoot'.",
"locations": [{"line": 1, "column": 21}], "locations": [{"line": 1, "column": 21}],
"path": None,
}, },
] ]
} }
@ -135,8 +133,6 @@ def test_errors_when_missing_operation_name(client):
"errors": [ "errors": [
{ {
"message": "Must provide operation name if query contains multiple operations.", "message": "Must provide operation name if query contains multiple operations.",
"locations": None,
"path": None,
} }
] ]
} }
@ -477,7 +473,6 @@ def test_handles_syntax_errors_caught_by_graphql(client):
{ {
"locations": [{"column": 1, "line": 1}], "locations": [{"column": 1, "line": 1}],
"message": "Syntax Error: Unexpected Name 'syntaxerror'.", "message": "Syntax Error: Unexpected Name 'syntaxerror'.",
"path": None,
} }
] ]
} }
@ -512,7 +507,7 @@ def test_handles_invalid_json_bodies(client):
def test_handles_django_request_error(client, monkeypatch): def test_handles_django_request_error(client, monkeypatch):
def mocked_read(*args): def mocked_read(*args):
raise IOError("foo-bar") raise OSError("foo-bar")
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)

View File

@ -1,8 +1,8 @@
from django.conf.urls import url from django.urls import path
from ..views import GraphQLView from ..views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r"^graphql/batch", GraphQLView.as_view(batch=True)), path("graphql/batch", GraphQLView.as_view(batch=True)),
url(r"^graphql", GraphQLView.as_view(graphiql=True)), path("graphql", GraphQLView.as_view(graphiql=True)),
] ]

View File

@ -1,4 +1,4 @@
from django.conf.urls import url from django.urls import path
from ..views import GraphQLView from ..views import GraphQLView
from .schema_view import schema from .schema_view import schema
@ -10,4 +10,4 @@ class CustomGraphQLView(GraphQLView):
pretty = True pretty = True
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())] urlpatterns = [path("graphql/inherited/", CustomGraphQLView.as_view())]

View File

@ -1,6 +1,6 @@
from django.conf.urls import url from django.urls import path
from ..views import GraphQLView from ..views import GraphQLView
from .schema_view import schema from .schema_view import schema
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))] urlpatterns = [path("graphql", GraphQLView.as_view(schema=schema, pretty=True))]

View File

@ -2,11 +2,8 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Type from typing import Type
from django.db.models import Model
from django.utils.functional import SimpleLazyObject
import graphene import graphene
from graphene import Field from django.db.models import Model
from graphene.relay import Connection, Node from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs from graphene.types.utils import yank_fields_from_attrs
@ -21,7 +18,6 @@ from .utils import (
is_valid_django_model, is_valid_django_model,
) )
ALL_FIELDS = "__all__" ALL_FIELDS = "__all__"
@ -108,12 +104,7 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
( (
'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. ' 'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.' 'Either remove the custom field or remove the field from the "exclude" list.'
).format( ).format(field_name=name, type_=type_)
field_name=name,
app_label=model._meta.app_label,
object_name=model._meta.object_name,
type_=type_,
)
) )
else: else:
if not hasattr(model, name): if not hasattr(model, name):
@ -131,7 +122,7 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
class DjangoObjectTypeOptions(ObjectTypeOptions): class DjangoObjectTypeOptions(ObjectTypeOptions):
model = None # type: Model model = None # type: Type[Model]
registry = None # type: Registry registry = None # type: Registry
connection = None # type: Type[Connection] connection = None # type: Type[Connection]
@ -177,11 +168,9 @@ class DjangoObjectType(ObjectType):
if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class): if not DJANGO_FILTER_INSTALLED and (filter_fields or filterset_class):
raise Exception( raise Exception(
(
"Can only set filter_fields or filterset_class if " "Can only set filter_fields or filterset_class if "
"Django-Filter is installed" "Django-Filter is installed"
) )
)
assert not (fields and exclude), ( assert not (fields and exclude), (
"Cannot set both 'fields' and 'exclude' options on " "Cannot set both 'fields' and 'exclude' options on "
@ -225,19 +214,19 @@ class DjangoObjectType(ObjectType):
"Creating a DjangoObjectType without either the `fields` " "Creating a DjangoObjectType without either the `fields` "
"or the `exclude` option is deprecated. Add an explicit `fields " "or the `exclude` option is deprecated. Add an explicit `fields "
"= '__all__'` option on DjangoObjectType {class_name} to use all " "= '__all__'` option on DjangoObjectType {class_name} to use all "
"fields".format(class_name=cls.__name__,), "fields".format(class_name=cls.__name__),
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
django_fields = yank_fields_from_attrs( django_fields = yank_fields_from_attrs(
construct_fields(model, registry, fields, exclude, convert_choices_to_enum), construct_fields(model, registry, fields, exclude, convert_choices_to_enum),
_as=Field, _as=graphene.Field,
) )
if use_connection is None and interfaces: if use_connection is None and interfaces:
use_connection = any( use_connection = any(
(issubclass(interface, Node) for interface in interfaces) issubclass(interface, Node) for interface in interfaces
) )
if use_connection and not connection: if use_connection and not connection:
@ -264,7 +253,7 @@ class DjangoObjectType(ObjectType):
_meta.fields = django_fields _meta.fields = django_fields
_meta.connection = connection _meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__( super().__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options _meta=_meta, interfaces=interfaces, **options
) )

View File

@ -1,9 +1,11 @@
import json import json
import warnings import warnings
from django.test import Client, TestCase from django.test import Client, TestCase, TransactionTestCase
DEFAULT_GRAPHQL_URL = "/graphql/" from graphene_django.settings import graphene_settings
DEFAULT_GRAPHQL_URL = "/graphql"
def graphql_query( def graphql_query(
@ -19,7 +21,7 @@ def graphql_query(
Args: Args:
query (string) - GraphQL query to run query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"), supply the operation_name. For annon queries ("{ ... }"),
should be None (default). should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``, to this value. If both ``input_data`` and ``variables``,
@ -28,7 +30,9 @@ def graphql_query(
variables (dict) - If provided, the "variables" field in GraphQL will be variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value. set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value. will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
client (django.test.Client) - Test client. Defaults to django.test.Client. client (django.test.Client) - Test client. Defaults to django.test.Client.
graphql_url (string) - URL to graphql endpoint. Defaults to "/graphql". graphql_url (string) - URL to graphql endpoint. Defaults to "/graphql".
@ -38,7 +42,7 @@ def graphql_query(
if client is None: if client is None:
client = Client() client = Client()
if not graphql_url: if not graphql_url:
graphql_url = DEFAULT_GRAPHQL_URL graphql_url = graphene_settings.TESTING_ENDPOINT
body = {"query": query} body = {"query": query}
if operation_name: if operation_name:
@ -61,13 +65,13 @@ def graphql_query(
return resp return resp
class GraphQLTestCase(TestCase): class GraphQLTestMixin:
""" """
Based on: https://www.sam.today/blog/testing-graphql-with-graphene-django/ Based on: https://www.sam.today/blog/testing-graphql-with-graphene-django/
""" """
# URL to graphql endpoint # URL to graphql endpoint
GRAPHQL_URL = DEFAULT_GRAPHQL_URL GRAPHQL_URL = graphene_settings.TESTING_ENDPOINT
def query( def query(
self, query, operation_name=None, input_data=None, variables=None, headers=None self, query, operation_name=None, input_data=None, variables=None, headers=None
@ -76,7 +80,7 @@ class GraphQLTestCase(TestCase):
Args: Args:
query (string) - GraphQL query to run query (string) - GraphQL query to run
operation_name (string) - If the query is a mutation or named query, you must operation_name (string) - If the query is a mutation or named query, you must
supply the op_name. For annon queries ("{ ... }"), supply the operation_name. For annon queries ("{ ... }"),
should be None (default). should be None (default).
input_data (dict) - If provided, the $input variable in GraphQL will be set input_data (dict) - If provided, the $input variable in GraphQL will be set
to this value. If both ``input_data`` and ``variables``, to this value. If both ``input_data`` and ``variables``,
@ -85,7 +89,9 @@ class GraphQLTestCase(TestCase):
variables (dict) - If provided, the "variables" field in GraphQL will be variables (dict) - If provided, the "variables" field in GraphQL will be
set to this value. set to this value.
headers (dict) - If provided, the headers in POST request to GRAPHQL_URL headers (dict) - If provided, the headers in POST request to GRAPHQL_URL
will be set to this value. will be set to this value. Keys should be prepended with
"HTTP_" (e.g. to specify the "Authorization" HTTP header,
use "HTTP_AUTHORIZATION" as the key).
Returns: Returns:
Response object from client Response object from client
@ -139,3 +145,11 @@ class GraphQLTestCase(TestCase):
""" """
content = json.loads(resp.content) content = json.loads(resp.content)
self.assertIn("errors", list(content.keys()), msg or content) self.assertIn("errors", list(content.keys()), msg or content)
class GraphQLTestCase(GraphQLTestMixin, TestCase):
pass
class GraphQLTransactionTestCase(GraphQLTestMixin, TransactionTestCase):
pass

Some files were not shown because too many files have changed in this diff Show More