Merge branch 'master' into convert-geometry-field

This commit is contained in:
Firas K 2019-03-19 20:21:35 +02:00 committed by GitHub
commit ec73e916b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
89 changed files with 2805 additions and 1800 deletions

View File

@ -11,9 +11,6 @@ install:
pip install -e .[test] pip install -e .[test]
pip install psycopg2 # Required for Django postgres fields testing pip install psycopg2 # Required for Django postgres fields testing
pip install django==$DJANGO_VERSION pip install django==$DJANGO_VERSION
if (($(echo "$DJANGO_VERSION <= 1.9" | bc -l))); then # DRF dropped 1.8 and 1.9 support at 3.7.0
pip install djangorestframework==3.6.4
fi
python setup.py develop python setup.py develop
elif [ "$TEST_TYPE" = lint ]; then elif [ "$TEST_TYPE" = lint ]; then
pip install flake8 pip install flake8
@ -38,13 +35,19 @@ env:
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
- python: '3.4'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.5'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.6'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.5'
env: TEST_TYPE=build DJANGO_VERSION=2.1
- python: '3.6'
env: TEST_TYPE=build DJANGO_VERSION=2.1
- python: '2.7' - python: '2.7'
env: TEST_TYPE=build DJANGO_VERSION=1.8 env: TEST_TYPE=lint
- python: '2.7' - python: '3.6'
env: TEST_TYPE=build DJANGO_VERSION=1.9
- python: '2.7'
env: TEST_TYPE=build DJANGO_VERSION=1.10
- python: '2.7'
env: TEST_TYPE=lint env: TEST_TYPE=lint
deploy: deploy:
provider: pypi provider: pypi

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016-Present Syrus Akbary
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +1,2 @@
include README.md include README.md LICENSE
recursive-include graphene_django/templates * recursive-include graphene_django/templates *

View File

@ -20,6 +20,7 @@ pip install "graphene-django>=2.0"
```python ```python
INSTALLED_APPS = ( INSTALLED_APPS = (
# ... # ...
'django.contrib.staticfiles', # Required for GraphiQL
'graphene_django', 'graphene_django',
) )

View File

@ -20,7 +20,7 @@ Let's use a simple example model.
Limiting Field Access Limiting Field Access
--------------------- ---------------------
This is easy, simply use the ``only_fields`` meta attribute. To limit fields in a GraphQL query simply use the ``only_fields`` meta attribute.
.. code:: python .. code:: python
@ -61,10 +61,11 @@ define a resolve method for that field and return the desired queryset.
from .models import Post from .models import Post
class Query(ObjectType): class Query(ObjectType):
all_posts = DjangoFilterConnectionField(CategoryNode) all_posts = DjangoFilterConnectionField(PostNode)
def resolve_all_posts(self, info):
return Post.objects.filter(published=True)
def resolve_all_posts(self, args, info):
return Post.objects.filter(published=True)
User-based Queryset Filtering User-based Queryset Filtering
----------------------------- -----------------------------
@ -79,7 +80,7 @@ with the context argument.
from .models import Post from .models import Post
class Query(ObjectType): class Query(ObjectType):
my_posts = DjangoFilterConnectionField(CategoryNode) my_posts = DjangoFilterConnectionField(PostNode)
def resolve_my_posts(self, info): def resolve_my_posts(self, info):
# context will reference to the Django request # context will reference to the Django request
@ -95,7 +96,7 @@ schema is simple.
result = schema.execute(query, context_value=request) result = schema.execute(query, context_value=request)
Filtering ID-based node access Filtering ID-based Node Access
------------------------------ ------------------------------
In order to add authorization to id-based node access, we need to add a In order to add authorization to id-based node access, we need to add a
@ -113,23 +114,25 @@ method to your ``DjangoObjectType``.
interfaces = (relay.Node, ) interfaces = (relay.Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
try: try:
post = cls._meta.model.objects.get(id=id) post = cls._meta.model.objects.get(id=id)
except cls._meta.model.DoesNotExist: except cls._meta.model.DoesNotExist:
return None return None
if post.published or context.user == post.owner: if post.published or info.context.user == post.owner:
return post return post
return None return None
Adding login required
Adding Login Required
--------------------- ---------------------
If you want to use the standard Django LoginRequiredMixin_ you can create your own view, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``: To restrict users from accessing the GraphQL API page the standard Django LoginRequiredMixin_ can be used to create your own standard Django Class Based View, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``.:
.. code:: python .. code:: python
#views.py
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from graphene_django.views import GraphQLView from graphene_django.views import GraphQLView
@ -137,7 +140,9 @@ If you want to use the standard Django LoginRequiredMixin_ you can create your o
class PrivateGraphQLView(LoginRequiredMixin, GraphQLView): class PrivateGraphQLView(LoginRequiredMixin, GraphQLView):
pass pass
After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``: After this, you can use the new ``PrivateGraphQLView`` in the project's URL Configuration file ``url.py``:
For Django 1.9 and below:
.. code:: python .. code:: python
@ -145,5 +150,14 @@ After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``:
# some other urls # some other urls
url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
] ]
For Django 2.0 and above:
.. code:: python
urlpatterns = [
# some other urls
path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
]
.. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin .. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin

View File

@ -2,9 +2,9 @@ Filtering
========= =========
Graphene integrates with Graphene integrates with
`django-filter <https://django-filter.readthedocs.org>`__ to provide `django-filter <https://django-filter.readthedocs.io/en/master/>`__ (2.x for
filtering of results. See the `usage Python 3 or 1.x for Python 2) to provide filtering of results. See the `usage
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__ documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``. for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``. This filtering is automatically available when implementing a ``relay.Node``.
@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows:
.. code:: bash .. code:: bash
# You'll need to django-filter # You'll need to django-filter
pip install django-filter pip install django-filter>=2
Note: The techniques below are demoed in the `cookbook example Note: The techniques below are demoed in the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__. app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__.
@ -26,7 +26,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering ``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__ documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
for full details on the range of options available. for full details on the range of options available.
For example: For example:
@ -127,7 +127,7 @@ create your own ``Filterset`` as follows:
all_animals = DjangoFilterConnectionField(AnimalNode, all_animals = DjangoFilterConnectionField(AnimalNode,
filterset_class=AnimalFilter) filterset_class=AnimalFilter)
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/latest/guide/usage.html#request-based-filtering>`__ The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/master/guide/usage.html#request-based-filtering>`__
in a ``django_filters.FilterSet`` instance. You can use this to customize your in a ``django_filters.FilterSet`` instance. You can use this to customize your
filters to be context-dependent. We could modify the ``AnimalFilter`` above to filters to be context-dependent. We could modify the ``AnimalFilter`` above to
pre-filter animals owned by the authenticated user (set in ``context.user``). pre-filter animals owned by the authenticated user (set in ``context.user``).
@ -145,4 +145,4 @@ pre-filter animals owned by the authenticated user (set in ``context.user``).
@property @property
def qs(self): def qs(self):
# The query context can be found in self.request. # The query context can be found in self.request.
return super(AnimalFilter, self).filter(owner=self.request.user) return super(AnimalFilter, self).qs.filter(owner=self.request.user)

68
docs/form-mutations.rst Normal file
View File

@ -0,0 +1,68 @@
Integration with Django forms
=============================
Graphene-Django comes with mutation classes that will convert the fields on Django forms into inputs on a mutation.
*Note: the API is experimental and will likely change in the future.*
FormMutation
------------
.. code:: python
class MyForm(forms.Form):
name = forms.CharField()
class MyMutation(FormMutation):
class Meta:
form_class = MyForm
``MyMutation`` will automatically receive an ``input`` argument. This argument should be a ``dict`` where the key is ``name`` and the value is a string.
ModelFormMutation
-----------------
``ModelFormMutation`` will pull the fields from a ``ModelForm``.
.. code:: python
class Pet(models.Model):
name = models.CharField()
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = ('name',)
# This will get returned when the mutation completes successfully
class PetType(DjangoObjectType):
class Meta:
model = Pet
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
``PetMutation`` will grab the fields from ``PetForm`` and turn them into inputs. If the form is valid then the mutation
will lookup the ``DjangoObjectType`` for the ``Pet`` model and return that under the key ``pet``. Otherwise it will
return a list of errors.
You can change the input name (default is ``input``) and the return field name (default is the model name lowercase).
.. code:: python
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
input_field_name = 'data'
return_field_name = 'my_pet'
Form validation
---------------
Form mutations will call ``is_valid()`` on your forms.
If the form is valid then ``form_valid(form, info)`` is called on the mutation. Override this method to change how
the form is saved or to return a different Graphene object type.
If the form is *not* valid then a list of errors will be returned. These errors have two fields: ``field``, a string
containing the name of the invalid form field, and ``messages``, a list of strings with the validation messages.

View File

@ -12,4 +12,5 @@ Contents:
authorization authorization
debug debug
rest-framework rest-framework
form-mutations
introspection introspection

View File

@ -1,3 +1,3 @@
sphinx sphinx
# Docs template # Docs template
https://github.com/graphql-python/graphene-python.org/archive/docs.zip http://graphene-python.org/sphinx_graphene_theme.zip

View File

@ -19,3 +19,46 @@ You can create a Mutation based on a serializer by using the
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
Create/Update Operations
---------------------
By default ModelSerializers accept create and update operations. To
customize this use the `model_operations` attribute. The update
operation looks up models by the primary key by default. You can
customize the look up with the lookup attribute.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['create', 'update']
lookup_field = 'id'
Overriding Update Queries
-------------------------
Use the method `get_serializer_kwargs` to override how
updates are applied.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
if 'id' in input:
instance = Post.objects.filter(id=input['id'], owner=info.context.user).first()
if instance:
return {'instance': instance, 'data': input, 'partial': True}
else:
raise http.Http404
return {'data': input, 'partial': True}

View File

@ -68,8 +68,8 @@ Let's get started with these models:
class Ingredient(models.Model): class Ingredient(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
notes = models.TextField() notes = models.TextField()
category = models.ForeignKey(Category, related_name='ingredients', category = models.ForeignKey(
on_delete=models.CASCADE) Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self): def __str__(self):
return self.name return self.name
@ -84,6 +84,7 @@ Add ingredients as INSTALLED_APPS:
'cookbook.ingredients', 'cookbook.ingredients',
] ]
Don't forget to create & run migrations: Don't forget to create & run migrations:
.. code:: bash .. code:: bash
@ -112,6 +113,18 @@ Alternatively you can use the Django admin interface to create some data
yourself. You'll need to run the development server (see below), and yourself. You'll need to run the development server (see below), and
create a login for yourself too (``./manage.py createsuperuser``). create a login for yourself too (``./manage.py createsuperuser``).
Register models with admin panel:
.. code:: python
# cookbook/ingredients/admin.py
from django.contrib import admin
from cookbook.ingredients.models import Category, Ingredient
admin.site.register(Category)
admin.site.register(Ingredient)
Hello GraphQL - Schema and Object Types Hello GraphQL - Schema and Object Types
--------------------------------------- ---------------------------------------
@ -166,9 +179,9 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
return Ingredient.objects.select_related('category').all() return Ingredient.objects.select_related('category').all()
Note that the above ``Query`` class is marked as 'abstract'. This is Note that the above ``Query`` class is a mixin, inheriting from
because we will now create a project-level query which will combine all ``object``. This is because we will now create a project-level query
our app-level queries. class which will combine all our app-level mixins.
Create the parent project-level ``cookbook/schema.py``: Create the parent project-level ``cookbook/schema.py``:

View File

@ -10,7 +10,7 @@ app <https://github.com/graphql-python/graphene-django/tree/master/examples/cook
A good idea is to check the following things first: A good idea is to check the following things first:
* `Graphene Relay documentation <http://docs.graphene-python.org/en/latest/relay/>`__ * `Graphene Relay documentation <http://docs.graphene-python.org/en/latest/relay/>`__
* `GraphQL Relay Specification <https://facebook.github.io/relay/docs/graphql-relay-specification.html>`__ * `GraphQL Relay Specification <https://facebook.github.io/relay/docs/en/graphql-server-specification.html>`__
Setup the Django project Setup the Django project
------------------------ ------------------------
@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
.. code:: python .. code:: python
# cookbook/ingredients/schema.py # cookbook/ingredients/schema.py
from graphene import relay, ObjectType, AbstractType from graphene import relay, ObjectType
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, ) interfaces = (relay.Node, )
class Query(AbstractType): class Query(object):
category = relay.Node.Field(CategoryNode) category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -3,7 +3,7 @@ Cookbook Example Django Project
This example project demos integration between Graphene and Django. This example project demos integration between Graphene and Django.
The project contains two apps, one named `ingredients` and another The project contains two apps, one named `ingredients` and another
named `recepies`. named `recipes`.
Getting started Getting started
--------------- ---------------

View File

@ -0,0 +1,17 @@
# Generated by Django 2.0 on 2018-10-18 17:46
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('ingredients', '0002_auto_20161104_0050'),
]
operations = [
migrations.AlterModelOptions(
name='category',
options={'verbose_name_plural': 'Categories'},
),
]

View File

@ -2,6 +2,8 @@ from django.db import models
class Category(models.Model): class Category(models.Model):
class Meta:
verbose_name_plural = 'Categories'
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
def __str__(self): def __str__(self):
@ -11,7 +13,7 @@ class Category(models.Model):
class Ingredient(models.Model): class Ingredient(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
notes = models.TextField(null=True, blank=True) notes = models.TextField(null=True, blank=True)
category = models.ForeignKey(Category, related_name='ingredients') category = models.ForeignKey(Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self): def __str__(self):
return self.name return self.name

View File

@ -1,7 +1,7 @@
import graphene import graphene
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
from cookbook.ingredients.models import Category, Ingredient from .models import Category, Ingredient
class CategoryType(DjangoObjectType): class CategoryType(DjangoObjectType):
@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType):
model = Ingredient model = Ingredient
class Query(graphene.AbstractType): class Query(object):
category = graphene.Field(CategoryType, category = graphene.Field(CategoryType,
id=graphene.Int(), id=graphene.Int(),
name=graphene.String()) name=graphene.String())
@ -25,17 +25,14 @@ class Query(graphene.AbstractType):
name=graphene.String()) name=graphene.String())
all_ingredients = graphene.List(IngredientType) all_ingredients = graphene.List(IngredientType)
def resolve_all_categories(self, args, context, info): def resolve_all_categories(self, context):
return Category.objects.all() return Category.objects.all()
def resolve_all_ingredients(self, args, context, info): def resolve_all_ingredients(self, context):
# We can easily optimize query count in the resolve method # We can easily optimize query count in the resolve method
return Ingredient.objects.select_related('category').all() return Ingredient.objects.select_related('category').all()
def resolve_category(self, args, context, info): def resolve_category(self, context, id=None, name=None):
id = args.get('id')
name = args.get('name')
if id is not None: if id is not None:
return Category.objects.get(pk=id) return Category.objects.get(pk=id)
@ -44,10 +41,7 @@ class Query(graphene.AbstractType):
return None return None
def resolve_ingredient(self, args, context, info): def resolve_ingredient(self, context, id=None, name=None):
id = args.get('id')
name = args.get('name')
if id is not None: if id is not None:
return Ingredient.objects.get(pk=id) return Ingredient.objects.get(pk=id)

View File

@ -0,0 +1,2 @@
# Create your tests here.

View File

@ -0,0 +1,2 @@
# Create your views here.

View File

@ -0,0 +1,18 @@
# Generated by Django 2.0 on 2018-10-18 17:28
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('recipes', '0002_auto_20161104_0106'),
]
operations = [
migrations.AlterField(
model_name='recipeingredient',
name='unit',
field=models.CharField(choices=[('unit', 'Units'), ('kg', 'Kilograms'), ('l', 'Litres'), ('st', 'Shots')], max_length=20),
),
]

View File

@ -1,16 +1,18 @@
from django.db import models from django.db import models
from cookbook.ingredients.models import Ingredient from ..ingredients.models import Ingredient
class Recipe(models.Model): class Recipe(models.Model):
title = models.CharField(max_length=100) title = models.CharField(max_length=100)
instructions = models.TextField() instructions = models.TextField()
def __str__(self):
return self.title
class RecipeIngredient(models.Model): class RecipeIngredient(models.Model):
recipe = models.ForeignKey(Recipe, related_name='amounts') recipe = models.ForeignKey(Recipe, related_name='amounts', on_delete=models.CASCADE)
ingredient = models.ForeignKey(Ingredient, related_name='used_by') ingredient = models.ForeignKey(Ingredient, related_name='used_by', on_delete=models.CASCADE)
amount = models.FloatField() amount = models.FloatField()
unit = models.CharField(max_length=20, choices=( unit = models.CharField(max_length=20, choices=(
('unit', 'Units'), ('unit', 'Units'),

View File

@ -1,7 +1,7 @@
import graphene import graphene
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
from cookbook.recipes.models import Recipe, RecipeIngredient from .models import Recipe, RecipeIngredient
class RecipeType(DjangoObjectType): class RecipeType(DjangoObjectType):
@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType):
model = RecipeIngredient model = RecipeIngredient
class Query(graphene.AbstractType): class Query(object):
recipe = graphene.Field(RecipeType, recipe = graphene.Field(RecipeType,
id=graphene.Int(), id=graphene.Int(),
title=graphene.String()) title=graphene.String())
@ -24,10 +24,7 @@ class Query(graphene.AbstractType):
id=graphene.Int()) id=graphene.Int())
all_recipeingredients = graphene.List(RecipeIngredientType) all_recipeingredients = graphene.List(RecipeIngredientType)
def resolve_recipe(self, args, context, info): def resolve_recipe(self, context, id=None, title=None):
id = args.get('id')
title = args.get('title')
if id is not None: if id is not None:
return Recipe.objects.get(pk=id) return Recipe.objects.get(pk=id)
@ -36,17 +33,15 @@ class Query(graphene.AbstractType):
return None return None
def resolve_recipeingredient(self, args, context, info): def resolve_recipeingredient(self, context, id=None):
id = args.get('id')
if id is not None: if id is not None:
return RecipeIngredient.objects.get(pk=id) return RecipeIngredient.objects.get(pk=id)
return None return None
def resolve_all_recipes(self, args, context, info): def resolve_all_recipes(self, context):
return Recipe.objects.all() return Recipe.objects.all()
def resolve_all_recipeingredients(self, args, context, info): def resolve_all_recipeingredients(self, context):
related = ['recipe', 'ingredient'] related = ['recipe', 'ingredient']
return RecipeIngredient.objects.select_related(*related).all() return RecipeIngredient.objects.select_related(*related).all()

View File

@ -0,0 +1,2 @@
# Create your tests here.

View File

@ -0,0 +1,2 @@
# Create your views here.

View File

@ -44,13 +44,12 @@ INSTALLED_APPS = [
'cookbook.recipes.apps.RecipesConfig', 'cookbook.recipes.apps.RecipesConfig',
] ]
MIDDLEWARE_CLASSES = [ MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
] ]

View File

@ -1,10 +1,10 @@
from django.conf.urls import url from django.urls import path
from django.contrib import admin from django.contrib import admin
from graphene_django.views import GraphQLView from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r'^admin/', admin.site.urls), path('admin/', admin.site.urls),
url(r'^graphql', GraphQLView.as_view(graphiql=True)), path('graphql/', GraphQLView.as_view(graphiql=True)),
] ]

View File

@ -1,4 +1,4 @@
graphene graphene
graphene-django graphene-django
graphql-core graphql-core>=2.1rc1
django==1.9 django==2.1.2

View File

@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.ingredients.models import Category, Ingredient from cookbook.ingredients.models import Category, Ingredient
@admin.register(Ingredient) @admin.register(Ingredient)
class IngredientAdmin(admin.ModelAdmin): class IngredientAdmin(admin.ModelAdmin):
list_display = ("id","name","category") list_display = ('id', 'name', 'category')
list_editable = ("name","category") list_editable = ('name', 'category')
admin.site.register(Category) admin.site.register(Category)

View File

@ -10,7 +10,7 @@ class Category(models.Model):
class Ingredient(models.Model): class Ingredient(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
notes = models.TextField(null=True,blank=True) notes = models.TextField(null=True, blank=True)
category = models.ForeignKey(Category, related_name='ingredients') category = models.ForeignKey(Category, related_name='ingredients')
def __str__(self): def __str__(self):

View File

@ -1,5 +1,5 @@
from cookbook.ingredients.models import Category, Ingredient from cookbook.ingredients.models import Category, Ingredient
from graphene import AbstractType, Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType):
} }
class Query(AbstractType): class Query(object):
category = Node.Field(CategoryNode) category = Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.recipes.models import Recipe, RecipeIngredient from cookbook.recipes.models import Recipe, RecipeIngredient
class RecipeIngredientInline(admin.TabularInline): class RecipeIngredientInline(admin.TabularInline):
model = RecipeIngredient model = RecipeIngredient
@admin.register(Recipe) @admin.register(Recipe)
class RecipeAdmin(admin.ModelAdmin): class RecipeAdmin(admin.ModelAdmin):
inlines = [RecipeIngredientInline] inlines = [RecipeIngredientInline]

View File

@ -8,6 +8,7 @@ class Recipe(models.Model):
instructions = models.TextField() instructions = models.TextField()
__unicode__ = lambda self: self.title __unicode__ = lambda self: self.title
class RecipeIngredient(models.Model): class RecipeIngredient(models.Model):
recipe = models.ForeignKey(Recipe, related_name='amounts') recipe = models.ForeignKey(Recipe, related_name='amounts')
ingredient = models.ForeignKey(Ingredient, related_name='used_by') ingredient = models.ForeignKey(Ingredient, related_name='used_by')

View File

@ -1,5 +1,5 @@
from cookbook.recipes.models import Recipe, RecipeIngredient from cookbook.recipes.models import Recipe, RecipeIngredient
from graphene import AbstractType, Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType):
} }
class Query(AbstractType): class Query(object):
recipe = Node.Field(RecipeNode) recipe = Node.Field(RecipeNode)
all_recipes = DjangoFilterConnectionField(RecipeNode) all_recipes = DjangoFilterConnectionField(RecipeNode)

View File

@ -5,7 +5,9 @@ import graphene
from graphene_django.debug import DjangoDebug from graphene_django.debug import DjangoDebug
class Query(cookbook.recipes.schema.Query, cookbook.ingredients.schema.Query, graphene.ObjectType): class Query(cookbook.ingredients.schema.Query,
cookbook.recipes.schema.Query,
graphene.ObjectType):
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')

View File

@ -1,3 +1,4 @@
# flake8: noqa
""" """
Django settings for cookbook project. Django settings for cookbook project.

View File

@ -3,6 +3,7 @@ from django.contrib import admin
from graphene_django.views import GraphQLView from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r'^admin/', admin.site.urls), url(r'^admin/', admin.site.urls),
url(r'^graphql', GraphQLView.as_view(graphiql=True)), url(r'^graphql', GraphQLView.as_view(graphiql=True)),

View File

@ -1,5 +1,5 @@
graphene graphene
graphene-django graphene-django
graphql-core graphql-core>=2.1rc1
django==1.9 django==1.9
django-filter==0.11.0 django-filter>=2

View File

@ -0,0 +1,2 @@
[flake8]
exclude=migrations,.git,__pycache__

View File

@ -5,7 +5,7 @@ from django.db import models
class Character(models.Model): class Character(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
ship = models.ForeignKey('Ship', blank=True, null=True, related_name='characters') ship = models.ForeignKey('Ship', on_delete=models.CASCADE, blank=True, null=True, related_name='characters')
def __str__(self): def __str__(self):
return self.name return self.name
@ -13,7 +13,7 @@ class Character(models.Model):
class Faction(models.Model): class Faction(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
hero = models.ForeignKey(Character) hero = models.ForeignKey(Character, on_delete=models.CASCADE)
def __str__(self): def __str__(self):
return self.name return self.name
@ -21,7 +21,7 @@ class Faction(models.Model):
class Ship(models.Model): class Ship(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
faction = models.ForeignKey(Faction, related_name='ships') faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name='ships')
def __str__(self): def __str__(self):
return self.name return self.name

View File

@ -1,14 +1,6 @@
from .types import ( from .types import DjangoObjectType
DjangoObjectType, from .fields import DjangoConnectionField
)
from .fields import (
DjangoConnectionField,
)
__version__ = '2.0.0' __version__ = "2.2.0"
__all__ = [ __all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]

View File

@ -5,13 +5,7 @@ class MissingType(object):
try: try:
# Postgres fields are only available in Django with psycopg2 installed # Postgres fields are only available in Django with psycopg2 installed
# and we cannot have psycopg2 on PyPy # and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField from django.contrib.postgres.fields import (ArrayField, HStoreField,
JSONField, RangeField)
except ImportError: except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
try:
# Postgres fields are only available in Django 1.9+
from django.contrib.postgres.fields import JSONField
except ImportError:
JSONField = MissingType

View File

@ -2,9 +2,22 @@ from django.contrib.gis.db.models import GeometryField
from django.db import models from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, from graphene import (
NonNull, String, UUID) ID,
from graphene.types.datetime import DateTime, Time Boolean,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
UUID,
DateTime,
Date,
Time,
)
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name from graphql import assert_valid_name
@ -34,7 +47,7 @@ def get_choices(choices):
else: else:
name = convert_choice_name(value) name = convert_choice_name(value)
while name in converted_names: while name in converted_names:
name += '_' + str(len(converted_names)) name += "_" + str(len(converted_names))
converted_names.append(name) converted_names.append(name)
description = help_text description = help_text
yield name, value, description yield name, value, description
@ -45,16 +58,15 @@ def convert_django_field_with_choices(field, registry=None):
converted = registry.get_converted_field(field) converted = registry.get_converted_field(field)
if converted: if converted:
return converted return converted
choices = getattr(field, 'choices', None) choices = getattr(field, "choices", None)
if choices: if choices:
meta = field.model._meta meta = field.model._meta
name = to_camel_case('{}_{}'.format(meta.object_name, field.name)) name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices)) choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices] named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices} named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object): class EnumWithDescriptionsType(object):
@property @property
def description(self): def description(self):
return named_choices_descriptions[self.name] return named_choices_descriptions[self.name]
@ -71,8 +83,8 @@ def convert_django_field_with_choices(field, registry=None):
@singledispatch @singledispatch
def convert_django_field(field, registry=None): def convert_django_field(field, registry=None):
raise Exception( raise Exception(
"Don't know how to convert the Django field %s (%s)" % "Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
(field, field.__class__)) )
@convert_django_field.register(models.CharField) @convert_django_field.register(models.CharField)
@ -82,6 +94,7 @@ def convert_django_field(field, registry=None):
@convert_django_field.register(models.URLField) @convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField) @convert_django_field.register(models.FileField)
@convert_django_field.register(models.FilePathField)
@convert_django_field.register(GeometryField) @convert_django_field.register(GeometryField)
def convert_field_to_string(field, registry=None): def convert_field_to_string(field, registry=None):
return String(description=field.help_text, required=not field.null) return String(description=field.help_text, required=not field.null)
@ -123,9 +136,14 @@ def convert_field_to_float(field, registry=None):
return Float(description=field.help_text, required=not field.null) return Float(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateTimeField)
def convert_datetime_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateField) @convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None): def convert_date_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null) return Date(description=field.help_text, required=not field.null)
@convert_django_field.register(models.TimeField) @convert_django_field.register(models.TimeField)
@ -144,7 +162,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr # We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance # is not available in the OneToOneRel instance
null = getattr(field, 'null', True) null = getattr(field, "null", True)
return Field(_type, required=not null) return Field(_type, required=not null)
return Dynamic(dynamic_type) return Dynamic(dynamic_type)
@ -168,6 +186,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta # defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields: if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type) return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type) return DjangoConnectionField(_type)

View File

@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug from .types import DjangoDebug
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug'] __all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

View File

@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object): class DjangoDebugContext(object):
def __init__(self): def __init__(self):
self.debug_promise = None self.debug_promise = None
self.promises = [] self.promises = []
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object): class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args): def resolve(self, next, root, info, **args):
context = info.context context = info.context
django_debug = getattr(context, 'django_debug', None) django_debug = getattr(context, "django_debug", None)
if not django_debug: if not django_debug:
if context is None: if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts') raise Exception("DjangoDebug cannot be executed in None contexts")
try: try:
context.django_debug = DjangoDebugContext() context.django_debug = DjangoDebugContext()
except Exception: except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format( raise Exception(
context.__class__.__name__ "DjangoDebug need the context to be writable, context received: {}.".format(
)) context.__class__.__name__
if info.schema.get_type('DjangoDebug') == info.return_type: )
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise() return context.django_debug.get_debug_promise()
promise = next(root, info, **args) promise = next(root, info, **args)
context.django_debug.add_promise(promise) context.django_debug.add_promise(promise)

View File

@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local): class ThreadLocalState(local):
def __init__(self): def __init__(self):
self.enabled = True self.enabled = True
@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel): def wrap_cursor(connection, panel):
if not hasattr(connection, '_graphene_cursor'): if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor connection._graphene_cursor = connection.cursor
def cursor(): def cursor():
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection): def unwrap_cursor(connection):
if hasattr(connection, '_graphene_cursor'): if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor connection.cursor = previous_cursor
del connection._graphene_cursor del connection._graphene_cursor
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params: if not params:
return params return params
if isinstance(params, dict): if isinstance(params, dict):
return dict((key, self._quote_expr(value)) return dict((key, self._quote_expr(value)) for key, value in params.items())
for key, value in params.items())
return list(map(self._quote_expr, params)) return list(map(self._quote_expr, params))
def _decode(self, param): def _decode(self, param):
try: try:
return force_text(param, strings_only=True) return force_text(param, strings_only=True)
except UnicodeDecodeError: except UnicodeDecodeError:
return '(encoded string)' return "(encoded string)"
def _record(self, method, sql, params): def _record(self, method, sql, params):
start_time = time() start_time = time()
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params) return method(sql, params)
finally: finally:
stop_time = time() stop_time = time()
duration = (stop_time - start_time) duration = stop_time - start_time
_params = '' _params = ""
try: try:
_params = json.dumps(list(map(self._decode, params))) _params = json.dumps(list(map(self._decode, params)))
except Exception: except Exception:
pass # object not JSON serializable pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default') alias = getattr(self.db, "alias", "default")
conn = self.db.connection conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown') vendor = getattr(conn, "vendor", "unknown")
params = { params = {
'vendor': vendor, "vendor": vendor,
'alias': alias, "alias": alias,
'sql': self.db.ops.last_executed_query( "sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)), self.cursor, sql, self._quote_params(params)
'duration': duration, ),
'raw_sql': sql, "duration": duration,
'params': _params, "raw_sql": sql,
'start_time': start_time, "params": _params,
'stop_time': stop_time, "start_time": start_time,
'is_slow': duration > 10, "stop_time": stop_time,
'is_select': sql.lower().strip().startswith('select'), "is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
} }
if vendor == 'postgresql': if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might # If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an # be in a state where checking isolation_level raises an
# exception. # exception.
try: try:
iso_level = conn.isolation_level iso_level = conn.isolation_level
except conn.InternalError: except conn.InternalError:
iso_level = 'unknown' iso_level = "unknown"
params.update({ params.update(
'trans_id': self.logger.get_transaction_id(alias), {
'trans_status': conn.get_transaction_status(), "trans_id": self.logger.get_transaction_id(alias),
'iso_level': iso_level, "trans_status": conn.get_transaction_status(),
'encoding': conn.encoding, "iso_level": iso_level,
}) "encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params) _sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility # We keep `sql` to maintain backwards compatibility

View File

@ -2,19 +2,53 @@ from graphene import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType): class DjangoDebugSQL(ObjectType):
vendor = String() class Meta:
alias = String() description = (
sql = String() "Represents a single database query made to a Django managed DB."
duration = Float() )
raw_sql = String()
params = String() vendor = String(
start_time = Float() required=True,
stop_time = Float() description=(
is_slow = Boolean() "The type of database being used (e.g. postrgesql, mysql, sqlite)."
is_select = Boolean() ),
)
alias = String(
required=True,
description="The Django database alias (e.g. 'default').",
)
sql = String(description="The actual SQL sent to this database.")
duration = Float(
required=True,
description="Duration of this database query in seconds.",
)
raw_sql = String(
required=True,
description="The raw SQL of this query, without params.",
)
params = String(
required=True,
description="JSON encoded database query parameters.",
)
start_time = Float(
required=True,
description="Start time of this database query.",
)
stop_time = Float(
required=True,
description="Stop time of this database query.",
)
is_slow = Boolean(
required=True,
description="Whether this database query took more than 10 seconds.",
)
is_select = Boolean(
required=True,
description="Whether this database query was a SELECT.",
)
# Postgres # Postgres
trans_id = String() trans_id = String(description="Postgres transaction ID if available.")
trans_status = String() trans_status = String(description="Postgres transaction status if available.")
iso_level = String() iso_level = String(description="Postgres isolation level if available.")
encoding = String() encoding = String(description="Postgres connection encoding if available.")

View File

@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object): class context(object):
pass pass
# from examples.starwars_django.models import Character # from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
def test_should_query_field(): def test_should_query_field():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args): def resolve_reporter(self, info, **args):
return Reporter.objects.first() return Reporter.objects.first()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
reporter { reporter {
lastName lastName
@ -47,43 +47,40 @@ def test_should_query_field():
} }
} }
} }
''' """
expected = { expected = {
'reporter': { "reporter": {"lastName": "ABA"},
'lastName': 'ABA', "__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
}, },
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
} }
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_list(): def test_should_query_list():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType) all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters { allReporters {
lastName lastName
@ -94,45 +91,38 @@ def test_should_query_list():
} }
} }
} }
''' """
expected = { expected = {
'allReporters': [{ "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
'lastName': 'ABA', "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().query)
}]
}
} }
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_connection(): def test_should_query_connection():
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters(first:1) { allReporters(first:1) {
edges { edges {
@ -147,48 +137,41 @@ def test_should_query_connection():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data["allReporters"] == expected["allReporters"]
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter(): def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA') r1 = Reporter(last_name="ABA")
r1.save() r1.save()
r2 = Reporter(last_name='Griffin') r2 = Reporter(last_name="Griffin")
r2.save() r2.save()
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name']) all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S") s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = """
query ReporterQuery { query ReporterQuery {
allReporters(first:1) { allReporters(first:1) {
edges { edges {
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data["allReporters"] == expected["allReporters"]
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query) query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query assert result.data["__debug"]["sql"][1]["rawSql"] == query

View File

@ -4,4 +4,10 @@ from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):
sql = List(DjangoDebugSQL) class Meta:
description = "Debugging information for the current query."
sql = List(
DjangoDebugSQL,
description="Executed SQL queries for this API query.",
)

View File

@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs): def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs) super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.on = kwargs.pop('on', False) self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop( self.max_limit = kwargs.pop(
'max_limit', "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
) )
self.enforce_first_or_last = kwargs.pop( self.enforce_first_or_last = kwargs.pop(
'enforce_first_or_last', "enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
) )
super(DjangoConnectionField, self).__init__(*args, **kwargs) super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property @property
def type(self): def type(self):
from .types import DjangoObjectType from .types import DjangoObjectType
_type = super(ConnectionField, self).type _type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" assert issubclass(
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) _type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
return _type._meta.connection return _type._meta.connection
@property @property
@ -67,6 +69,10 @@ class DjangoConnectionField(ConnectionField):
@classmethod @classmethod
def merge_querysets(cls, default_queryset, queryset): def merge_querysets(cls, default_queryset, queryset):
if default_queryset.query.distinct and not queryset.query.distinct:
queryset = queryset.distinct()
elif queryset.query.distinct and not default_queryset.query.distinct:
default_queryset = default_queryset.distinct()
return queryset & default_queryset return queryset & default_queryset
@classmethod @classmethod
@ -96,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection return connection
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(
enforce_first_or_last, root, info, **args): cls,
first = args.get('first') resolver,
last = args.get('last') connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
if enforce_first_or_last: if enforce_first_or_last:
assert first or last, ( assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.' "You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name) ).format(info.field_name)
if max_limit: if max_limit:
if first: if first:
assert first <= max_limit, ( assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit) ).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit) args["first"] = min(first, max_limit)
if last: if last:
assert last <= max_limit, ( assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit) ).format(last, info.field_name, max_limit)
args['last'] = min(last, max_limit) args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@ -134,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type, self.type,
self.get_manager(), self.get_manager(),
self.max_limit, self.max_limit,
self.enforce_first_or_last self.enforce_first_or_last,
) )

View File

@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED: if not DJANGO_FILTER_INSTALLED:
warnings.warn( warnings.warn(
"Use of django filtering requires the django-filter package " "Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`", ImportWarning "be installed. You can do so using `pip install django-filter`",
ImportWarning,
) )
else: else:
from .fields import DjangoFilterConnectionField from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = ['DjangoFilterConnectionField', __all__ = [
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter'] "DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
]

View File

@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField): class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(
def __init__(self, type, fields=None, order_by=None, self,
extra_filter_meta=None, filterset_class=None, type,
*args, **kwargs): fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields self._fields = fields
self._provided_filterset_class = filterset_class self._provided_filterset_class = filterset_class
self._filterset_class = None self._filterset_class = None
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self): def filterset_class(self):
if not self._filterset_class: if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model, meta = dict(model=self.model, fields=fields)
fields=fields)
if self._extra_filter_meta: if self._extra_filter_meta:
meta.update(self._extra_filter_meta) meta.update(self._extra_filter_meta)
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta) self._filterset_class = get_filterset_class(
self._provided_filterset_class, **meta
)
return self._filterset_class return self._filterset_class
@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126 # See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (default_queryset.query.low_mark and queryset.query.low_mark), ( assert not (
'Received two sliced querysets (low mark) in the connection, please slice only in one.' default_queryset.query.low_mark and queryset.query.low_mark
) ), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (default_queryset.query.high_mark and queryset.query.high_mark), ( assert not (
'Received two sliced querysets (high mark) in the connection, please slice only in one.' default_queryset.query.high_mark and queryset.query.high_mark
) ), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits() default_queryset.query.clear_limits()
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset) queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high) queryset.query.set_limits(low, high)
return queryset return queryset
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(
enforce_first_or_last, filterset_class, filtering_args, cls,
root, info, **args): resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( qs = filterset_class(
data=filter_kwargs, data=filter_kwargs,
queryset=default_manager.get_queryset(), queryset=default_manager.get_queryset(),
request=info.context request=info.context,
).qs ).qs
return super(DjangoFilterConnectionField, cls).connection_resolver( return super(DjangoFilterConnectionField, cls).connection_resolver(
@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit, self.max_limit,
self.enforce_first_or_last, self.enforce_first_or_last,
self.filterset_class, self.filterset_class,
self.filtering_args self.filtering_args,
) )

View File

@ -1,8 +1,7 @@
import itertools import itertools
from django.db import models from django.db import models
from django.utils.text import capfirst from django_filters import Filter, MultipleChoiceFilter, VERSION
from django_filters import Filter, MultipleChoiceFilter
from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
@ -15,7 +14,10 @@ class GlobalIDFilter(Filter):
field_class = GlobalIDFormField field_class = GlobalIDFormField
def filter(self, qs, value): def filter(self, qs, value):
_type, _id = from_global_id(value) """ Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id) return super(GlobalIDFilter, self).filter(qs, _id)
@ -28,71 +30,76 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = { GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: { models.AutoField: {"filter_class": GlobalIDFilter},
'filter_class': GlobalIDFilter, models.OneToOneField: {"filter_class": GlobalIDFilter},
}, models.ForeignKey: {"filter_class": GlobalIDFilter},
models.OneToOneField: { models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
'filter_class': GlobalIDFilter, models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter},
}, models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter},
models.ForeignKey: {
'filter_class': GlobalIDFilter,
},
models.ManyToManyField: {
'filter_class': GlobalIDMultipleChoiceFilter,
}
} }
class GrapheneFilterSetMixin(BaseFilterSet): class GrapheneFilterSetMixin(BaseFilterSet):
FILTER_DEFAULTS = dict(itertools.chain( """ A django_filters.filterset.BaseFilterSet with default filter overrides
FILTER_FOR_DBFIELD_DEFAULTS.items(), to handle global IDs """
GRAPHENE_FILTER_SET_OVERRIDES.items()
))
@classmethod FILTER_DEFAULTS = dict(
def filter_for_reverse_field(cls, f, name): itertools.chain(
"""Handles retrieving filters for reverse relationships FILTER_FOR_DBFIELD_DEFAULTS.items(),
GRAPHENE_FILTER_SET_OVERRIDES.items()
)
)
We override the default implementation so that we can handle
Global IDs (the default implementation expects database # To support a Django 1.11 + Python 2.7 combination django-filter must be
primary keys) # < 2.x.x. To support the earlier version of django-filter, the
""" # filter_for_reverse_field method must be present on GrapheneFilterSetMixin and
rel = f.field.rel # must not be present for later versions of django-filter.
default = { if VERSION[0] < 2:
'name': name, from django.utils.text import capfirst
'label': capfirst(rel.related_name)
} class GrapheneFilterSetMixinPython2(GrapheneFilterSetMixin):
if rel.multiple:
# For to-many relationships @classmethod
return GlobalIDMultipleChoiceFilter(**default) def filter_for_reverse_field(cls, f, name):
else: """Handles retrieving filters for reverse relationships
# For to-one relationships We override the default implementation so that we can handle
return GlobalIDFilter(**default) Global IDs (the default implementation expects database
primary keys)
"""
try:
rel = f.field.remote_field
except AttributeError:
rel = f.field.rel
default = {"name": name, "label": capfirst(rel.related_name)}
if rel.multiple:
# For to-many relationships
return GlobalIDMultipleChoiceFilter(**default)
else:
# For to-one relationships
return GlobalIDFilter(**default)
GrapheneFilterSetMixin = GrapheneFilterSetMixinPython2
def setup_filterset(filterset_class): def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality """ Wrap a provided filterset in Graphene-specific functionality
""" """
return type( return type(
'Graphene{}'.format(filterset_class.__name__), "Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin), (filterset_class, GrapheneFilterSetMixin),
{}, {},
) )
def custom_filterset_factory(model, filterset_base_class=FilterSet, def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
**meta):
""" Create a filterset for the given model using the provided meta data """ Create a filterset for the given model using the provided meta data
""" """
meta.update({ meta.update({"model": model})
'model': model, meta_class = type(str("Meta"), (object,), meta)
})
meta_class = type(str('Meta'), (object,), meta)
filterset = type( filterset = type(
str('%sFilterSet' % model._meta.object_name), str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin), (filterset_base_class, GrapheneFilterSetMixin),
{ {"Meta": meta_class},
'Meta': meta_class
}
) )
return filterset return filterset

View File

@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet): class ArticleFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = { fields = {
'headline': ['exact', 'icontains'], "headline": ["exact", "icontains"],
'pub_date': ['gt', 'lt', 'exact'], "pub_date": ["gt", "lt", "exact"],
'reporter': ['exact'], "reporter": ["exact"],
} }
order_by = OrderingFilter(fields=('pub_date',)) order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet): class ReporterFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', 'last_name', 'email', 'pets'] fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=('pub_date',)) order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet): class PetFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Pet model = Pet
fields = ['name'] fields = ["name"]

View File

@ -5,8 +5,7 @@ import pytest
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField, from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
GlobalIDMultipleChoiceField)
from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED:
import django_filters import django_filters
from django_filters import FilterSet, NumberFilter from django_filters import FilterSet, NumberFilter
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField, from graphene_django.filter import (
GlobalIDMultipleChoiceFilter) GlobalIDFilter,
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter,
)
from graphene_django.filter.tests.filters import (
ArticleFilter,
PetFilter,
ReporterFilter,
)
else: else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible')) pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
pytestmark.append(pytest.mark.django_db) pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
class ArticleNode(DjangoObjectType):
class ArticleNode(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ('headline', ) filter_fields = ("headline",)
class ReporterNode(DjangoObjectType): class ReporterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class PetNode(DjangoObjectType): class PetNode(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node, ) interfaces = (Node,)
# schema = Schema() # schema = Schema()
@ -59,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments): def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order_by') ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field) args = get_args(field)
actual = [ actual = [name for name in args if name not in ignore and not name.startswith("_")]
name assert set(arguments) == set(
for name in args actual
if name not in ignore and not name.startswith('_') ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
]
assert set(arguments) == set(actual), \
'Expected arguments ({}) did not match actual ({})'.format(
arguments,
actual
)
def assert_orderable(field): def assert_orderable(field):
args = get_args(field) args = get_args(field)
assert 'order_by' in args, \ assert "order_by" in args, "Field cannot be ordered"
'Field cannot be ordered'
def assert_not_orderable(field): def assert_not_orderable(field):
args = get_args(field) args = get_args(field)
assert 'order_by' not in args, \ assert "order_by" not in args, "Field can be ordered"
'Field can be ordered'
def test_filter_explicit_filterset_arguments(): def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
assert_arguments(field, assert_arguments(
'headline', 'headline__icontains', field,
'pub_date', 'pub_date__gt', 'pub_date__lt', "headline",
'reporter', "headline__icontains",
) "pub_date",
"pub_date__gt",
"pub_date__lt",
"reporter",
)
def test_filter_shortcut_filterset_arguments_list(): def test_filter_shortcut_filterset_arguments_list():
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
assert_arguments(field, assert_arguments(field, "pub_date", "reporter")
'pub_date',
'reporter',
)
def test_filter_shortcut_filterset_arguments_dict(): def test_filter_shortcut_filterset_arguments_dict():
field = DjangoFilterConnectionField(ArticleNode, fields={ field = DjangoFilterConnectionField(
'headline': ['exact', 'icontains'], ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
'reporter': ['exact'], )
}) assert_arguments(field, "headline", "headline__icontains", "reporter")
assert_arguments(field,
'headline', 'headline__icontains',
'reporter',
)
def test_filter_explicit_filterset_orderable(): def test_filter_explicit_filterset_orderable():
@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta(): def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ field = DjangoFilterConnectionField(
'exclude': ('headline', ) ArticleNode, extra_filter_meta={"exclude": ("headline",)}
}) )
assert 'headline' not in field.filterset_class.get_fields() assert "headline" not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context(): def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet): class ArticleContextFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
exclude = set() exclude = set()
@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context():
return qs.filter(reporter=self.request.reporter) return qs.filter(reporter=self.request.reporter)
class Query(ObjectType): class Query(ObjectType):
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter) context_articles = DjangoFilterConnectionField(
ArticleNode, filterset_class=ArticleContextFilter
)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1, editor=r1) Article.objects.create(
Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2, editor=r2) headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
class context(object): class context(object):
reporter = r2 reporter = r2
query = ''' query = """
query { query {
contextArticles { contextArticles {
edges { edges {
@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query, context_value=context()) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert len(result.data['contextArticles']['edges']) == 1 assert len(result.data["contextArticles"]["edges"]) == 1
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2' assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode) field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles') assert_arguments(field, "first_name", "articles")
assert_not_orderable(field) assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related(): def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['headline', 'reporter'] filter_fields = ["headline", "reporter"]
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
schema = Schema(query=Query) schema = Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields['articles'].get_type() articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, 'headline', 'reporter') assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field) assert_not_orderable(articles_field)
def test_filter_filterset_related_results(): def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ['first_name', 'articles'] filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType): class ArticleFilterNode(DjangoObjectType):
class Meta: class Meta:
interfaces = (Node, ) interfaces = (Node,)
model = Article model = Article
filter_fields = ['headline', 'reporter'] filter_fields = ["headline", "reporter"]
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -243,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1) Article.objects.create(
Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2) headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
)
query = ''' query = """
query { query {
allReporters { allReporters {
edges { edges {
@ -264,123 +277,134 @@ def test_filter_filterset_related_results():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
# We should only get back a single article for each reporter # We should only get back a single article for each reporter
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 assert (
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
)
assert (
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
)
def test_global_id_field_implicit(): def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id']) field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id'] id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit(): def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet): class ArticleIdFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = ['id'] fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id'] id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_filterset_descriptions(): def test_filterset_descriptions():
class ArticleIdFilter(django_filters.FilterSet): class ArticleIdFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Article model = Article
fields = ['id'] fields = ["id"]
max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time") max_time = django_filters.NumberFilter(
method="filter_max_time", label="The maximum time"
)
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args['max_time'] max_time = field.args["max_time"]
assert isinstance(max_time, Argument) assert isinstance(max_time, Argument)
assert max_time.type == Float assert max_time.type == Float
assert max_time.description == 'The maximum time' assert max_time.description == "The maximum time"
def test_global_id_field_relation(): def test_global_id_field_relation():
field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['reporter'] id_filter = filterset_class.base_filters["reporter"]
assert isinstance(id_filter, GlobalIDFilter) assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField assert id_filter.field_class == GlobalIDFormField
def test_global_id_multiple_field_implicit(): def test_global_id_multiple_field_implicit():
field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets'] multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit(): def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['pets'] fields = ["pets"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets'] multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_implicit_reverse(): def test_global_id_multiple_field_implicit_reverse():
field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles'] multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit_reverse(): def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet): class ReporterPetsFilter(django_filters.FilterSet):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['articles'] fields = ["articles"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles'] multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results(): def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = { filter_fields = {"first_name": ["icontains"]}
'first_name': ['icontains']
}
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') r1 = Reporter.objects.create(
r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') first_name="A test user", last_name="Last Name", email="test1@test.com"
r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') )
r2 = Reporter.objects.create(
first_name="Other test user",
last_name="Other Last Name",
email="test2@test.com",
)
r3 = Reporter.objects.create(
first_name="Random", last_name="RandomLast", email="random@test.com"
)
query = ''' query = """
query { query {
allReporters(firstName_Icontains: "test") { allReporters(firstName_Icontains: "test") {
edges { edges {
@ -390,12 +414,12 @@ def test_filter_filterset_related_results():
} }
} }
} }
''' """
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
# We should only get two reporters # We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2 assert len(result.data["allReporters"]["edges"]) == 2
def test_recursive_filter_connection(): def test_recursive_filter_connection():
@ -407,77 +431,73 @@ def test_recursive_filter_connection():
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode) all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode assert (
ReporterFilterNode._meta.fields["child_reporters"].node_type
== ReporterFilterNode
)
def test_should_query_filter_node_limit(): def test_should_query_filter_node_limit():
class ReporterFilter(FilterSet): class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit') limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value): def filter_limit(self, queryset, name, value):
return queryset[:value] return queryset[:value]
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', ] fields = ["first_name"]
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node,)
filter_fields = ('lang', ) filter_fields = ("lang",)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
ReporterType, ReporterType, filterset_class=ReporterFilter
filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice') return Reporter.objects.order_by("a_choice")
Reporter.objects.create( Reporter.objects.create(
first_name='Bob', first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
) )
r = Reporter.objects.create( r = Reporter.objects.create(
first_name='John', first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
last_name='Doe',
email='johndoe@example.com',
a_choice=1
) )
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline="Article Node 1",
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang="es",
) )
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline="Article Node 2",
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang="en",
) )
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(limit: 1) { allReporters(limit: 1) {
edges { edges {
@ -496,24 +516,23 @@ def test_should_query_filter_node_limit():
} }
} }
} }
''' """
expected = { expected = {
'allReporters': { "allReporters": {
'edges': [{ "edges": [
'node': { {
'id': 'UmVwb3J0ZXJUeXBlOjI=', "node": {
'firstName': 'John', "id": "UmVwb3J0ZXJUeXBlOjI=",
'articles': { "firstName": "John",
'edges': [{ "articles": {
'node': { "edges": [
'id': 'QXJ0aWNsZVR5cGU6MQ==', {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
'lang': 'ES' ]
} },
}]
} }
} }
}] ]
} }
} }
@ -524,45 +543,37 @@ def test_should_query_filter_node_limit():
def test_should_query_filter_node_double_limit_raises(): def test_should_query_filter_node_double_limit_raises():
class ReporterFilter(FilterSet): class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit') limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value): def filter_limit(self, queryset, name, value):
return queryset[:value] return queryset[:value]
class Meta: class Meta:
model = Reporter model = Reporter
fields = ['first_name', ] fields = ["first_name"]
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField( all_reporters = DjangoFilterConnectionField(
ReporterType, ReporterType, filterset_class=ReporterFilter
filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2] return Reporter.objects.order_by("a_choice")[:2]
Reporter.objects.create( Reporter.objects.create(
first_name='Bob', first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
) )
r = Reporter.objects.create( r = Reporter.objects.create(
first_name='John', first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
last_name='Doe',
email='johndoe@example.com',
a_choice=1
) )
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(limit: 1) { allReporters(limit: 1) {
edges { edges {
@ -573,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises():
} }
} }
} }
''' """
result = schema.execute(query) result = schema.execute(query)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert str(result.errors[0]) == ( assert str(result.errors[0]) == (
'Received two sliced querysets (high mark) in the connection, please slice only in one.' "Received two sliced querysets (high mark) in the connection, please slice only in one."
) )
def test_order_by_is_perserved(): def test_order_by_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = () filter_fields = ()
class Query(ObjectType): class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean()) all_reporters = DjangoFilterConnectionField(
ReporterType, reverse_order=Boolean()
)
def resolve_all_reporters(self, info, reverse_order=False, **args): def resolve_all_reporters(self, info, reverse_order=False, **args):
reporters = Reporter.objects.order_by('first_name') reporters = Reporter.objects.order_by("first_name")
if reverse_order: if reverse_order:
return reporters.reverse() return reporters.reverse()
return reporters return reporters
Reporter.objects.create( Reporter.objects.create(first_name="b")
first_name='b', r = Reporter.objects.create(first_name="a")
)
r = Reporter.objects.create(
first_name='a',
)
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1) { allReporters(first: 1) {
edges { edges {
@ -617,23 +627,14 @@ def test_order_by_is_perserved():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
'allReporters': {
'edges': [{
'node': {
'firstName': 'a',
}
}]
}
}
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
reverse_query = """
reverse_query = '''
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1, reverseOrder: true) { allReporters(first: 1, reverseOrder: true) {
edges { edges {
@ -643,33 +644,26 @@ def test_order_by_is_perserved():
} }
} }
} }
''' """
reverse_expected = { reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
'allReporters': {
'edges': [{
'node': {
'firstName': 'b',
}
}]
}
}
reverse_result = schema.execute(reverse_query) reverse_result = schema.execute(reverse_query)
assert not reverse_result.errors assert not reverse_result.errors
assert reverse_result.data == reverse_expected assert reverse_result.data == reverse_expected
def test_annotation_is_perserved(): def test_annotation_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
full_name = String() full_name = String()
def resolve_full_name(instance, info, **args): def resolve_full_name(instance, info, **args):
return instance.full_name return instance.full_name
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
filter_fields = () filter_fields = ()
class Query(ObjectType): class Query(ObjectType):
@ -677,17 +671,16 @@ def test_annotation_is_perserved():
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate( return Reporter.objects.annotate(
full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField()) full_name=Concat(
"first_name", Value(" "), "last_name", output_field=TextField()
)
) )
Reporter.objects.create( Reporter.objects.create(first_name="John", last_name="Doe")
first_name='John',
last_name='Doe',
)
schema = Schema(query=Query) schema = Schema(query=Query)
query = ''' query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(first: 1) { allReporters(first: 1) {
edges { edges {
@ -697,16 +690,8 @@ def test_annotation_is_perserved():
} }
} }
} }
''' """
expected = { expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
'allReporters': {
'edges': [{
'node': {
'fullName': 'John Doe',
}
}]
}
}
result = schema.execute(query) result = schema.execute(query)

View File

@ -8,7 +8,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
a Graphene Field. These arguments will be available to a Graphene Field. These arguments will be available to
filter against in the GraphQL filter against in the GraphQL
""" """
from ..form_converter import convert_form_field from ..forms.converter import convert_form_field
args = {} args = {}
for name, filter_field in six.iteritems(filterset_class.base_filters): for name, filter_field in six.iteritems(filterset_class.base_filters):

View File

@ -0,0 +1 @@
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa

View File

@ -1,24 +1,24 @@
from django import forms from django import forms
from django.forms.fields import BaseTemporalField from django.core.exceptions import ImproperlyConfigured
from graphene import ID, Boolean, Float, Int, List, String, UUID from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from .utils import import_single_dispatch from ..utils import import_single_dispatch
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@singledispatch @singledispatch
def convert_form_field(field): def convert_form_field(field):
raise Exception( raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) " "Don't know how to convert the Django form field %s (%s) "
"to Graphene type" % "to Graphene type" % (field, field.__class__)
(field, field.__class__)
) )
@convert_form_field.register(BaseTemporalField) @convert_form_field.register(forms.fields.BaseTemporalField)
@convert_form_field.register(forms.CharField) @convert_form_field.register(forms.CharField)
@convert_form_field.register(forms.EmailField) @convert_form_field.register(forms.EmailField)
@convert_form_field.register(forms.SlugField) @convert_form_field.register(forms.SlugField)
@ -63,6 +63,21 @@ def convert_form_field_to_list(field):
return List(ID, required=field.required) return List(ID, required=field.required)
@convert_form_field.register(forms.DateField)
def convert_form_field_to_date(field):
return Date(description=field.help_text, required=field.required)
@convert_form_field.register(forms.DateTimeField)
def convert_form_field_to_datetime(field):
return DateTime(description=field.help_text, required=field.required)
@convert_form_field.register(forms.TimeField)
def convert_form_field_to_time(field):
return Time(description=field.help_text, required=field.required)
@convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField) @convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_id(field): def convert_form_field_to_id(field):

View File

@ -8,9 +8,7 @@ from graphql_relay import from_global_id
class GlobalIDFormField(Field): class GlobalIDFormField(Field):
default_error_messages = { default_error_messages = {"invalid": _("Invalid ID specified.")}
'invalid': _('Invalid ID specified.'),
}
def clean(self, value): def clean(self, value):
if not value and not self.required: if not value and not self.required:
@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
try: try:
_type, _id = from_global_id(value) _type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages["invalid"])
try: try:
CharField().clean(_id) CharField().clean(_id)
CharField().clean(_type) CharField().clean(_type)
except ValidationError: except ValidationError:
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages["invalid"])
return value return value
class GlobalIDMultipleChoiceField(MultipleChoiceField): class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = { default_error_messages = {
'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
'invalid_list': _('Enter a list of values.'), "invalid_list": _("Enter a list of values."),
} }
def valid_value(self, value): def valid_value(self, value):

View File

@ -0,0 +1,192 @@
# from django import forms
from collections import OrderedDict
import graphene
from graphene import Field, InputField
from graphene.relay.mutation import ClientIDMutation
from graphene.types.mutation import MutationOptions
# from graphene.types.inputobjecttype import (
# InputObjectTypeOptions,
# InputObjectType,
# )
from graphene.types.utils import yank_fields_from_attrs
from graphene_django.registry import get_global_registry
from .converter import convert_form_field
from .types import ErrorType
def fields_for_form(form, only_fields, exclude_fields):
fields = OrderedDict()
for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name
in exclude_fields # or
# name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_form_field(field)
return fields
class BaseDjangoFormMutation(ClientIDMutation):
class Meta:
abstract = True
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
form = cls.get_form(root, info, **input)
if form.is_valid():
return cls.perform_mutate(form, info)
else:
errors = [
ErrorType(field=key, messages=value)
for key, value in form.errors.items()
]
return cls(errors=errors)
@classmethod
def get_form(cls, root, info, **input):
form_kwargs = cls.get_form_kwargs(root, info, **input)
return cls._meta.form_class(**form_kwargs)
@classmethod
def get_form_kwargs(cls, root, info, **input):
kwargs = {"data": input}
pk = input.pop("id", None)
if pk:
instance = cls._meta.model._default_manager.get(pk=pk)
kwargs["instance"] = instance
return kwargs
# class DjangoFormInputObjectTypeOptions(InputObjectTypeOptions):
# form_class = None
# class DjangoFormInputObjectType(InputObjectType):
# class Meta:
# abstract = True
# @classmethod
# def __init_subclass_with_meta__(cls, form_class=None,
# only_fields=(), exclude_fields=(), _meta=None, **options):
# if not _meta:
# _meta = DjangoFormInputObjectTypeOptions(cls)
# assert isinstance(form_class, forms.Form), (
# 'form_class must be an instance of django.forms.Form'
# )
# _meta.form_class = form_class
# form = form_class()
# fields = fields_for_form(form, only_fields, exclude_fields)
# super(DjangoFormInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, fields=fields, **options)
class DjangoFormMutationOptions(MutationOptions):
form_class = None
class DjangoFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls, form_class=None, only_fields=(), exclude_fields=(), **options
):
if not form_class:
raise Exception("form_class is required for DjangoFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
output_fields = fields_for_form(form, only_fields, exclude_fields)
_meta = DjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def perform_mutate(cls, form, info):
form.save()
return cls(errors=[])
class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions):
model = None
return_field_name = None
class DjangoModelFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls,
form_class=None,
model=None,
return_field_name=None,
only_fields=(),
exclude_fields=(),
**options
):
if not form_class:
raise Exception("form_class is required for DjangoModelFormMutation")
if not model:
model = form_class._meta.model
if not model:
raise Exception("model is required for DjangoModelFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
if "id" not in exclude_fields:
input_fields["id"] = graphene.ID()
registry = get_global_registry()
model_type = registry.get_type_for_model(model)
return_field_name = return_field_name
if not return_field_name:
model_name = model.__name__
return_field_name = model_name[:1].lower() + model_name[1:]
output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)
_meta = DjangoModelDjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.model = model
_meta.return_field_name = return_field_name
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def perform_mutate(cls, form, info):
obj = form.save()
kwargs = {cls._meta.return_field_name: obj}
return cls(errors=[], **kwargs)

View File

View File

@ -0,0 +1,114 @@
from django import forms
from py.test import raises
import graphene
from graphene import (
String,
Int,
Boolean,
Float,
ID,
UUID,
List,
NonNull,
DateTime,
Date,
Time,
)
from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text="Custom Help Text")
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == "Custom Help Text"
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert "Don't know how to convert the Django form field" in str(excinfo.value)
def test_should_date_convert_date():
assert_conversion(forms.DateField, Date)
def test_should_time_convert_time():
assert_conversion(forms.TimeField, Time)
def test_should_date_time_convert_date_time():
assert_conversion(forms.DateTimeField, DateTime)
def test_should_char_convert_string():
assert_conversion(forms.CharField, String)
def test_should_email_convert_string():
assert_conversion(forms.EmailField, String)
def test_should_slug_convert_string():
assert_conversion(forms.SlugField, String)
def test_should_url_convert_string():
assert_conversion(forms.URLField, String)
def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, String)
def test_should_base_field_convert_string():
assert_conversion(forms.Field, String)
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, String, "[0-9]+")
def test_should_uuid_convert_string():
if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID)
def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, Int)
def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, Boolean)
assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, Boolean)
assert not isinstance(field.type, NonNull)
def test_should_float_convert_float():
assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, Float)
def test_should_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == ID
def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, ID)

View File

@ -0,0 +1,141 @@
from django import forms
from django.test import TestCase
from py.test import raises
from graphene_django.tests.models import Pet, Film, FilmDetails
from ..mutation import DjangoFormMutation, DjangoModelFormMutation
class MyForm(forms.Form):
text = forms.CharField()
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = '__all__'
def test_needs_form_class():
with raises(Exception) as exc:
class MyMutation(DjangoFormMutation):
pass
assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
def test_has_output_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "text" in MyMutation.Input._meta.fields
class ModelFormMutationTests(TestCase):
def test_default_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("pet", PetMutation._meta.fields)
def test_default_input_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("name", PetMutation.Input._meta.fields)
self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
self.assertIn("id", PetMutation.Input._meta.fields)
def test_exclude_fields_input_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
exclude_fields = ['id']
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("name", PetMutation.Input._meta.fields)
self.assertIn("age", PetMutation.Input._meta.fields)
self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
self.assertNotIn("id", PetMutation.Input._meta.fields)
def test_return_field_name_is_camelcased(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
model = FilmDetails
self.assertEqual(PetMutation._meta.model, FilmDetails)
self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
def test_custom_return_field_name(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
model = Film
return_field_name = "animal"
self.assertEqual(PetMutation._meta.model, Film)
self.assertEqual(PetMutation._meta.return_field_name, "animal")
self.assertIn("animal", PetMutation._meta.fields)
def test_model_form_mutation_mutate(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
pet = Pet.objects.create(name="Axel", age=10)
result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia", age=10)
self.assertEqual(Pet.objects.count(), 1)
pet.refresh_from_db()
self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, [])
def test_model_form_mutation_updates_existing_(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None, name="Mia", age=10)
self.assertEqual(Pet.objects.count(), 1)
pet = Pet.objects.get()
self.assertEqual(pet.name, "Mia")
self.assertEqual(pet.age, 10)
self.assertEqual(result.errors, [])
def test_model_form_mutation_mutate_invalid_form(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None)
# A pet was not created
self.assertEqual(Pet.objects.count(), 0)
fields_w_error = [e.field for e in result.errors]
self.assertEqual(len(result.errors), 2)
self.assertIn("name", fields_w_error)
self.assertEqual(result.errors[0].messages, ["This field is required."])
self.assertIn("age", fields_w_error)
self.assertEqual(result.errors[1].messages, ["This field is required."])

View File

@ -0,0 +1,6 @@
import graphene
class ErrorType(graphene.ObjectType):
field = graphene.String()
messages = graphene.List(graphene.String)

View File

@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings
class CommandArguments(BaseCommand): class CommandArguments(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--schema', "--schema",
type=str, type=str,
dest='schema', dest="schema",
default=graphene_settings.SCHEMA, default=graphene_settings.SCHEMA,
help='Django app containing schema to dump, e.g. myproject.core.schema.schema') help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
parser.add_argument( parser.add_argument(
'--out', "--out",
type=str, type=str,
dest='out', dest="out",
default=graphene_settings.SCHEMA_OUTPUT, default=graphene_settings.SCHEMA_OUTPUT,
help='Output file (default: schema.json)') help="Output file, --out=- prints to stdout (default: schema.json)",
)
parser.add_argument( parser.add_argument(
'--indent', "--indent",
type=int, type=int,
dest='indent', dest="indent",
default=graphene_settings.SCHEMA_INDENT, default=graphene_settings.SCHEMA_INDENT,
help='Output file indent (default: None)') help="Output file indent (default: None)",
)
class Command(CommandArguments): class Command(CommandArguments):
help = 'Dump Graphene schema JSON to file' help = "Dump Graphene schema JSON to file"
can_import_settings = True can_import_settings = True
def save_file(self, out, schema_dict, indent): def save_file(self, out, schema_dict, indent):
with open(out, 'w') as outfile: with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent) json.dump(schema_dict, outfile, indent=indent)
def handle(self, *args, **options): def handle(self, *args, **options):
options_schema = options.get('schema') options_schema = options.get("schema")
if options_schema and type(options_schema) is str: if options_schema and type(options_schema) is str:
module_str, schema_name = options_schema.rsplit('.', 1) module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str) mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name) schema = getattr(mod, schema_name)
@ -53,16 +55,21 @@ class Command(CommandArguments):
else: else:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
out = options.get('out') or graphene_settings.SCHEMA_OUTPUT out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema: if not schema:
raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema') raise CommandError(
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
)
indent = options.get('indent') indent = options.get("indent")
schema_dict = {'data': schema.introspect()} schema_dict = {"data": schema.introspect()}
self.save_file(out, schema_dict, indent) if out == '-':
self.stdout.write(json.dumps(schema_dict, indent=indent))
else:
self.save_file(out, schema_dict, indent)
style = getattr(self, 'style', None) style = getattr(self, "style", None)
success = getattr(style, 'SUCCESS', lambda x: x) success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out)) self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))

View File

@ -1,20 +1,21 @@
class Registry(object): class Registry(object):
def __init__(self): def __init__(self):
self._registry = {} self._registry = {}
self._field_registry = {} self._field_registry = {}
def register(self, cls): def register(self, cls):
from .types import DjangoObjectType from .types import DjangoObjectType
assert issubclass( assert issubclass(
cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( cls, DjangoObjectType
cls.__name__) ), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
assert cls._meta.registry == self, 'Registry for a Model have to match.' cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, ( # assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# ) # )
if not getattr(cls._meta, 'skip_registry', False): if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls self._registry[cls._meta.model] = cls
def get_type_for_model(self, model): def get_type_for_model(self, model):

View File

@ -1,20 +1,21 @@
from collections import OrderedDict from collections import OrderedDict
from django.shortcuts import get_object_or_404
import graphene import graphene
from graphene.types import Field, InputField from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import ( from graphene.types.objecttype import yank_fields_from_attrs
yank_fields_from_attrs
)
from .serializer_converter import ( from .serializer_converter import convert_serializer_field
convert_serializer_field
)
from .types import ErrorType from .types import ErrorType
class SerializerMutationOptions(MutationOptions): class SerializerMutationOptions(MutationOptions):
lookup_field = None
model_class = None
model_operations = ["create", "update"]
serializer_class = None serializer_class = None
@ -23,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_excluded = ( is_excluded = (
name in exclude_fields # or name
in exclude_fields # or
# name in already_created_fields # name in already_created_fields
) )
@ -39,37 +41,86 @@ class SerializerMutation(ClientIDMutation):
abstract = True abstract = True
errors = graphene.List( errors = graphene.List(
ErrorType, ErrorType, description="May contain more than one error for same field."
description='May contain more than one error for same field.'
) )
@classmethod @classmethod
def __init_subclass_with_meta__(cls, serializer_class=None, def __init_subclass_with_meta__(
only_fields=(), exclude_fields=(), **options): cls,
lookup_field=None,
serializer_class=None,
model_class=None,
model_operations=["create", "update"],
only_fields=(),
exclude_fields=(),
**options
):
if not serializer_class: if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation') raise Exception("serializer_class is required for the SerializerMutation")
if "update" not in model_operations and "create" not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class() serializer = serializer_class()
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) if model_class is None:
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) serializer_meta = getattr(serializer_class, "Meta", None)
if serializer_meta:
model_class = getattr(serializer_meta, "model", None)
if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=True
)
output_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=False
)
_meta = SerializerMutationOptions(cls) _meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field
_meta.model_operations = model_operations
_meta.serializer_class = serializer_class _meta.serializer_class = serializer_class
_meta.fields = yank_fields_from_attrs( _meta.model_class = model_class
output_fields, _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
_as=Field,
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(SerializerMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
) )
input_fields = yank_fields_from_attrs( @classmethod
input_fields, def get_serializer_kwargs(cls, root, info, **input):
_as=InputField, lookup_field = cls._meta.lookup_field
) model_class = cls._meta.model_class
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
if model_class:
if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(
model_class, **{lookup_field: input[lookup_field]}
)
elif "create" in cls._meta.model_operations:
instance = None
else:
raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field
)
)
return {
"instance": instance,
"data": input,
"context": {"request": info.context},
}
return {"data": input, "context": {"request": info.context}}
@classmethod @classmethod
def mutate_and_get_payload(cls, root, info, **input): def mutate_and_get_payload(cls, root, info, **input):
serializer = cls._meta.serializer_class(data=input) kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
if serializer.is_valid(): if serializer.is_valid():
return cls.perform_mutate(serializer, info) return cls.perform_mutate(serializer, info)

View File

@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
graphql_type = get_graphene_type_from_serializer_field(field) graphql_type = get_graphene_type_from_serializer_field(field)
args = [] args = []
kwargs = { kwargs = {"description": field.help_text, "required": is_input and field.required}
'description': field.help_text,
'required': is_input and field.required,
}
# if it is a tuple or a list it means that we are returning # if it is a tuple or a list it means that we are returning
# the graphql type and the child type # the graphql type and the child type
if isinstance(graphql_type, (list, tuple)): if isinstance(graphql_type, (list, tuple)):
kwargs['of_type'] = graphql_type[1] kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0] graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer): if isinstance(field, serializers.ModelSerializer):
@ -46,6 +43,15 @@ def convert_serializer_field(field, is_input=True):
global_registry = get_global_registry() global_registry = get_global_registry()
field_model = field.Meta.model field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)] args = [global_registry.get_type_for_model(field_model)]
elif isinstance(field, serializers.ListSerializer):
field = field.child
if is_input:
kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
else:
del kwargs["of_type"]
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
return graphql_type(*args, **kwargs) return graphql_type(*args, **kwargs)
@ -59,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
} }
return type( return type(
'{}Input'.format(serializer.__class__.__name__), "{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,), (graphene.InputObjectType,),
items items,
) )
@ -75,6 +81,12 @@ def convert_serializer_to_field(field):
return graphene.Field return graphene.Field
@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
def convert_list_serializer_to_field(field):
child_type = get_graphene_type_from_serializer_field(field.child)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.IntegerField) @get_graphene_type_from_serializer_field.register(serializers.IntegerField)
def convert_serializer_field_to_int(field): def convert_serializer_field_to_int(field):
return graphene.Int return graphene.Int
@ -92,9 +104,13 @@ def convert_serializer_field_to_float(field):
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField) @get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime
@get_graphene_type_from_serializer_field.register(serializers.DateField) @get_graphene_type_from_serializer_field.register(serializers.DateField)
def convert_serializer_field_to_date_time(field): def convert_serializer_field_to_date_time(field):
return graphene.types.datetime.DateTime return graphene.types.datetime.Date
@get_graphene_type_from_serializer_field.register(serializers.TimeField) @get_graphene_type_from_serializer_field.register(serializers.TimeField)

View File

@ -1,8 +1,10 @@
import copy import copy
from rest_framework import serializers
from py.test import raises
import graphene import graphene
from django.db import models
from graphene import InputObjectType
from py.test import raises
from rest_framework import serializers
from ..serializer_converter import convert_serializer_field from ..serializer_converter import convert_serializer_field
from ..types import DictType from ..types import DictType
@ -14,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
# Remove `source=` from the field declaration. # Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute # since we are reusing the same child in when testing the required attribute
if 'child' in kwargs: if "child" in kwargs:
kwargs['child'] = copy.deepcopy(kwargs['child']) kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs) field = rest_framework_field(**kwargs)
@ -23,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
def assert_conversion(rest_framework_field, graphene_field, **kwargs): def assert_conversion(rest_framework_field, graphene_field, **kwargs):
graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) graphene_type = _get_type(
rest_framework_field, help_text="Custom Help Text", **kwargs
)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type( graphene_type_required = _get_type(
rest_framework_field, help_text='Custom Help Text', required=True, **kwargs rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
) )
assert isinstance(graphene_type_required, graphene_field) assert isinstance(graphene_type_required, graphene_field)
@ -37,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
def test_should_unknown_rest_framework_field_raise_exception(): def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
convert_serializer_field(None) convert_serializer_field(None)
assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string(): def test_should_char_convert_string():
@ -65,20 +69,19 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string(): def test_should_regex_convert_string():
assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(serializers, 'UUIDField'): if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String) assert_conversion(serializers.UUIDField, graphene.String)
def test_should_model_convert_field(): def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = None model = None
fields = '__all__' fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False) assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
@ -87,8 +90,8 @@ def test_should_date_time_convert_datetime():
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
def test_should_date_convert_datetime(): def test_should_date_convert_date():
assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) assert_conversion(serializers.DateField, graphene.types.datetime.Date)
def test_should_time_convert_time(): def test_should_time_convert_time():
@ -108,7 +111,9 @@ def test_should_float_convert_float():
def test_should_decimal_convert_float(): def test_should_decimal_convert_float():
assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
)
def test_should_list_convert_to_list(): def test_should_list_convert_to_list():
@ -118,7 +123,7 @@ def test_should_list_convert_to_list():
field_a = assert_conversion( field_a = assert_conversion(
serializers.ListField, serializers.ListField,
graphene.List, graphene.List,
child=serializers.IntegerField(min_value=0, max_value=100) child=serializers.IntegerField(min_value=0, max_value=100),
) )
assert field_a.of_type == graphene.Int assert field_a.of_type == graphene.Int
@ -128,6 +133,34 @@ def test_should_list_convert_to_list():
assert field_b.of_type == graphene.String assert field_b.of_type == graphene.String
def test_should_list_serializer_convert_to_list():
class FooModel(models.Model):
pass
class ChildSerializer(serializers.ModelSerializer):
class Meta:
model = FooModel
fields = "__all__"
class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True)
class Meta:
model = FooModel
fields = "__all__"
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=True
)
assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=False
)
assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None
def test_should_dict_convert_dict(): def test_should_dict_convert_dict():
assert_conversion(serializers.DictField, DictType) assert_conversion(serializers.DictField, DictType)
@ -141,7 +174,7 @@ def test_should_file_convert_string():
def test_should_filepath_convert_string(): def test_should_filepath_convert_string():
assert_conversion(serializers.FilePathField, graphene.String, path='/') assert_conversion(serializers.FilePathField, graphene.String, path="/")
def test_should_ip_convert_string(): def test_should_ip_convert_string():
@ -157,6 +190,8 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string(): def test_should_multiplechoicefield_convert_to_list_of_string():
field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3]) field = assert_conversion(
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
)
assert field.of_type == graphene.String assert field.of_type == graphene.String

View File

@ -1,6 +1,6 @@
import datetime import datetime
from graphene import Field from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from py.test import raises from py.test import raises
from py.test import mark from py.test import mark
@ -11,10 +11,30 @@ from ..models import MyFakeModel
from ..mutation import SerializerMutation from ..mutation import SerializerMutation
def mock_info():
return ResolveInfo(
None,
None,
None,
None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None,
)
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = MyFakeModel model = MyFakeModel
fields = '__all__' fields = "__all__"
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer): class MySerializer(serializers.Serializer):
@ -27,10 +47,11 @@ class MySerializer(serializers.Serializer):
def test_needs_serializer_class(): def test_needs_serializer_class():
with raises(Exception) as exc: with raises(Exception) as exc:
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
pass pass
assert str(exc.value) == 'serializer_class is required for the SerializerMutation' assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields(): def test_has_fields():
@ -38,9 +59,9 @@ def test_has_fields():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
assert 'text' in MyMutation._meta.fields assert "text" in MyMutation._meta.fields
assert 'model' in MyMutation._meta.fields assert "model" in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields assert "errors" in MyMutation._meta.fields
def test_has_input_fields(): def test_has_input_fields():
@ -48,25 +69,24 @@ def test_has_input_fields():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
assert 'text' in MyMutation.Input._meta.fields assert "text" in MyMutation.Input._meta.fields
assert 'model' in MyMutation.Input._meta.fields assert "model" in MyMutation.Input._meta.fields
def test_exclude_fields(): def test_exclude_fields():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MyModelSerializer serializer_class = MyModelSerializer
exclude_fields = ['created'] exclude_fields = ["created"]
assert 'cool_name' in MyMutation._meta.fields assert "cool_name" in MyMutation._meta.fields
assert 'created' not in MyMutation._meta.fields assert "created" not in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields assert "errors" in MyMutation._meta.fields
assert 'cool_name' in MyMutation.Input._meta.fields assert "cool_name" in MyMutation.Input._meta.fields
assert 'created' not in MyMutation.Input._meta.fields assert "created" not in MyMutation.Input._meta.fields
def test_nested_model(): def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType): class MyFakeModelGrapheneType(DjangoObjectType):
class Meta: class Meta:
model = MyFakeModel model = MyFakeModel
@ -75,61 +95,85 @@ def test_nested_model():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
model_field = MyMutation._meta.fields['model'] model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field) assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model'] model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType) assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields assert "cool_name" in model_input_type._meta.fields
assert 'created' in model_input_type._meta.fields assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success(): def test_mutate_and_get_payload_success():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(None, None, **{ result = MyMutation.mutate_and_get_payload(
'text': 'value', None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
'model': { )
'cool_name': 'other_value'
}
})
assert result.errors is None assert result.errors is None
@mark.django_db @mark.django_db
def test_model_mutate_and_get_payload_success(): def test_model_add_mutate_and_get_payload_success():
class MyMutation(SerializerMutation): result = MyModelMutation.mutate_and_get_payload(
class Meta: None, mock_info(), **{"cool_name": "Narf"}
serializer_class = MyModelSerializer )
result = MyMutation.mutate_and_get_payload(None, None, **{
'cool_name': 'Narf',
})
assert result.errors is None assert result.errors is None
assert result.cool_name == 'Narf' assert result.cool_name == "Narf"
assert isinstance(result.created, datetime.datetime) assert isinstance(result.created, datetime.datetime)
def test_mutate_and_get_payload_error():
@mark.django_db
def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
)
assert result.errors is None
assert result.cool_name == "New Narf"
@mark.django_db
def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["update"]
with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert '"id" required' in str(exc.value)
def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
# missing required fields # missing required fields
result = MyMutation.mutate_and_get_payload(None, None, **{}) result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0 assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error(): def test_model_mutate_and_get_payload_error():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
# missing required fields # missing required fields
result = MyMutation.mutate_and_get_payload(None, None, **{}) result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0 assert len(result.errors) > 0
def test_invalid_serializer_operations():
with raises(Exception) as exc:
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["Add"]
assert "model_operations" in str(exc.value)

View File

@ -26,27 +26,22 @@ except ImportError:
# Copied shamelessly from Django REST Framework # Copied shamelessly from Django REST Framework
DEFAULTS = { DEFAULTS = {
'SCHEMA': None, "SCHEMA": None,
'SCHEMA_OUTPUT': 'schema.json', "SCHEMA_OUTPUT": "schema.json",
'SCHEMA_INDENT': None, "SCHEMA_INDENT": None,
'MIDDLEWARE': (), "MIDDLEWARE": (),
# Set to True if the connection fields must have # Set to True if the connection fields must have
# either the first or last argument # either the first or last argument
'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields # Max items returned in ConnectionFields / FilterConnectionFields
'RELAY_CONNECTION_MAX_LIMIT': 100, "RELAY_CONNECTION_MAX_LIMIT": 100,
} }
if settings.DEBUG: if settings.DEBUG:
DEFAULTS['MIDDLEWARE'] += ( DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
'graphene_django.debug.DjangoDebugMiddleware',
)
# List of settings that may be in string import notation. # List of settings that may be in string import notation.
IMPORT_STRINGS = ( IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
'MIDDLEWARE',
'SCHEMA',
)
def perform_import(val, setting_name): def perform_import(val, setting_name):
@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
""" """
try: try:
# Nod to tastypie's use of importlib. # Nod to tastypie's use of importlib.
parts = val.split('.') parts = val.split(".")
module_path, class_name = '.'.join(parts[:-1]), parts[-1] module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg) raise ImportError(msg)
@ -96,8 +96,8 @@ class GrapheneSettings(object):
@property @property
def user_settings(self): def user_settings(self):
if not hasattr(self, '_user_settings'): if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, 'GRAPHENE', {}) self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings return self._user_settings
def __getattr__(self, attr): def __getattr__(self, attr):
@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs): def reload_graphene_settings(*args, **kwargs):
global graphene_settings global graphene_settings
setting, value = kwargs['setting'], kwargs['value'] setting, value = kwargs["setting"], kwargs["value"]
if setting == 'GRAPHENE': if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS) graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)

View File

@ -0,0 +1,99 @@
(function() {
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ('; ' + document.cookie).split('; csrftoken=');
if (cookies.length == 2)
csrftoken = cookies.pop().split(';').shift();
// Collect the URL parameters
var parameters = {};
window.location.hash.substr(1).split('&').forEach(function (entry) {
var eq = entry.indexOf('=');
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] =
decodeURIComponent(entry.slice(eq + 1));
}
});
// Produce a Location fragment string from a parameter object.
function locationQuery(params) {
return '#' + Object.keys(params).map(function (key) {
return encodeURIComponent(key) + '=' +
encodeURIComponent(params[key]);
}).join('&');
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetchURL = locationQuery(otherParams);
// Defines a GraphQL fetcher using the fetch API.
function graphQLFetcher(graphQLParams) {
var headers = {
'Accept': 'application/json',
'Content-Type': 'application/json'
};
if (csrftoken) {
headers['X-CSRFToken'] = csrftoken;
}
return fetch(fetchURL, {
method: 'post',
headers: headers,
body: JSON.stringify(graphQLParams),
credentials: 'include',
}).then(function (response) {
return response.text();
}).then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters));
}
var options = {
fetcher: graphQLFetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
query: parameters.query,
}
if (parameters.variables) {
options.variables = parameters.variables;
}
if (parameters.operation_name) {
options.operationName = parameters.operation_name;
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, options),
document.body
);
})();

View File

@ -5,6 +5,7 @@ exploring GraphQL.
If you wish to receive JSON, provide the header "Accept: application/json" or If you wish to receive JSON, provide the header "Accept: application/json" or
add "&raw" to the end of the URL within a browser. add "&raw" to the end of the URL within a browser.
--> -->
{% load static %}
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@ -16,108 +17,13 @@ add "&raw" to the end of the URL within a browser.
width: 100%; width: 100%;
} }
</style> </style>
<link href="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.css" rel="stylesheet" /> <link href="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script> <script src="//cdn.jsdelivr.net/npm/whatwg-fetch@2.0.3/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react.min.js"></script> <script src="//cdn.jsdelivr.net/npm/react@16.2.0/umd/react.production.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react-dom.min.js"></script> <script src="//cdn.jsdelivr.net/npm/react-dom@16.2.0/umd/react-dom.production.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.min.js"></script> <script src="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"></script>
</head> </head>
<body> <body>
<script> <script src="{% static 'graphene_django/graphiql.js' %}"></script>
// Parse the cookie value for a CSRF token
var csrftoken;
var cookies = ('; ' + document.cookie).split('; csrftoken=');
if (cookies.length == 2)
csrftoken = cookies.pop().split(';').shift();
// Collect the URL parameters
var parameters = {};
window.location.search.substr(1).split('&').forEach(function (entry) {
var eq = entry.indexOf('=');
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] =
decodeURIComponent(entry.slice(eq + 1));
}
});
// Produce a Location query string from a parameter object.
function locationQuery(params) {
return '?' + Object.keys(params).map(function (key) {
return encodeURIComponent(key) + '=' +
encodeURIComponent(params[key]);
}).join('&');
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetchURL = locationQuery(otherParams);
// Defines a GraphQL fetcher using the fetch API.
function graphQLFetcher(graphQLParams) {
var headers = {
'Accept': 'application/json',
'Content-Type': 'application/json'
};
if (csrftoken) {
headers['X-CSRFToken'] = csrftoken;
}
return fetch(fetchURL, {
method: 'post',
headers: headers,
body: JSON.stringify(graphQLParams),
credentials: 'include',
}).then(function (response) {
return response.text();
}).then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters));
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: graphQLFetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
query: '{{ query|escapejs }}',
response: '{{ result|escapejs }}',
{% if variables %}
variables: '{{ variables|escapejs }}',
{% endif %}
{% if operation_name %}
operationName: '{{ operation_name|escapejs }}',
{% endif %}
}),
document.body
);
</script>
</body> </body>
</html> </html>

View File

@ -3,56 +3,103 @@ from __future__ import absolute_import
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
CHOICES = ( CHOICES = ((1, "this"), (2, _("that")))
(1, 'this'),
(2, _('that'))
)
class Pet(models.Model): class Pet(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
age = models.PositiveIntegerField()
class FilmDetails(models.Model): class FilmDetails(models.Model):
location = models.CharField(max_length=30) location = models.CharField(max_length=30)
film = models.OneToOneField('Film', related_name='details') film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details"
)
class Film(models.Model): class Film(models.Model):
reporters = models.ManyToManyField('Reporter', genre = models.CharField(
related_name='films') max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
class DoeReporterManager(models.Manager): class DoeReporterManager(models.Manager):
def get_queryset(self): def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model): class Reporter(models.Model):
first_name = models.CharField(max_length=30) first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30)
email = models.EmailField() email = models.EmailField()
pets = models.ManyToManyField('self') pets = models.ManyToManyField("self")
a_choice = models.CharField(max_length=30, choices=CHOICES) a_choice = models.CharField(max_length=30, choices=CHOICES)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
def __str__(self): # __unicode__ on Python 2 reporter_type = models.IntegerField(
"Reporter Type",
null=True,
blank=True,
choices=[(1, u"Regular"), (2, u"CNN Reporter")],
)
def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name) return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs):
"""
Override the init method so that during runtime, Django
can know that this object can be a CNNReporter by casting
it to the proxy model. Otherwise, as far as Django knows,
when a CNNReporter is pulled from the database, it is still
of type Reporter. This was added to test proxy model support.
"""
super(Reporter, self).__init__(*args, **kwargs)
if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter
class CNNReporter(Reporter):
"""
This class is a proxy model for Reporter, used for testing
proxy model support
"""
class Meta:
proxy = True
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField()
reporter = models.ForeignKey(Reporter, related_name='articles') pub_date_time = models.DateTimeField()
editor = models.ForeignKey(Reporter, related_name='edited_articles_+') reporter = models.ForeignKey(
lang = models.CharField(max_length=2, help_text='Language', choices=[ Reporter, on_delete=models.CASCADE, related_name="articles"
('es', 'Spanish'), )
('en', 'English') editor = models.ForeignKey(
], default='es') Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
importance = models.IntegerField('Importance', null=True, blank=True, )
choices=[(1, u'Very important'), (2, u'Not as important')]) lang = models.CharField(
max_length=2,
help_text="Language",
choices=[("es", "Spanish"), ("en", "English")],
default="es",
)
importance = models.IntegerField(
"Importance",
null=True,
blank=True,
choices=[(1, u"Very important"), (2, u"Not as important")],
)
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return self.headline return self.headline
class Meta: class Meta:
ordering = ('headline',) ordering = ("headline",)

View File

@ -6,10 +6,9 @@ from .models import Article, Reporter
class Character(DjangoObjectType): class Character(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (relay.Node, ) interfaces = (relay.Node,)
def get_node(self, info, id): def get_node(self, info, id):
pass pass
@ -20,7 +19,7 @@ class Human(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (relay.Node, ) interfaces = (relay.Node,)
def resolve_raises(self, info): def resolve_raises(self, info):
raise Exception("This field should raise exception") raise Exception("This field should raise exception")

View File

@ -12,10 +12,10 @@ class QueryRoot(ObjectType):
raise Exception("Throws!") raise Exception("Throws!")
def resolve_request(self, info): def resolve_request(self, info):
return info.context.GET.get('q') return info.context.GET.get("q")
def resolve_test(self, info, who=None): def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World') return "Hello %s" % (who or "World")
class MutationRoot(ObjectType): class MutationRoot(ObjectType):

View File

@ -3,8 +3,8 @@ from mock import patch
from six import StringIO from six import StringIO
@patch('graphene_django.management.commands.graphql_schema.Command.save_file') @patch("graphene_django.management.commands.graphql_schema.Command.save_file")
def test_generate_file_on_call_graphql_schema(savefile_mock, settings): def test_generate_file_on_call_graphql_schema(savefile_mock, settings):
out = StringIO() out = StringIO()
management.call_command('graphql_schema', schema='', stdout=out) management.call_command("graphql_schema", schema="", stdout=out)
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue() assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()

View File

@ -6,7 +6,7 @@ from py.test import raises
import graphene import graphene
from graphene.relay import ConnectionField, Node from graphene.relay import ConnectionField, Node
from graphene.types.datetime import DateTime, Time from graphene.types.datetime import DateTime, Date, Time
from graphene.types.json import JSONString from graphene.types.json import JSONString
from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType
@ -20,11 +20,11 @@ from .models import Article, Film, FilmDetails, Reporter
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs) field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs)
graphene_type = convert_django_field(field) graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
assert field.description == 'Custom Help Text' assert field.description == "Custom Help Text"
nonnull_field = django_field(null=False, *args, **kwargs) nonnull_field = django_field(null=False, *args, **kwargs)
if not nonnull_field.null: if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field) nonnull_graphene_type = convert_django_field(nonnull_field)
@ -37,11 +37,15 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception(): def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
convert_django_field(None) convert_django_field(None)
assert 'Don\'t know how to convert the Django field' in str(excinfo.value) assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime)
def test_should_date_convert_string(): def test_should_date_convert_string():
assert_conversion(models.DateField, DateTime) assert_conversion(models.DateField, Date)
def test_should_time_convert_string(): def test_should_time_convert_string():
@ -84,6 +88,10 @@ def test_should_image_convert_string():
assert_conversion(models.ImageField, graphene.String) assert_conversion(models.ImageField, graphene.String)
def test_should_url_convert_string():
assert_conversion(models.FilePathField, graphene.String)
def test_should_auto_convert_id(): def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True) assert_conversion(models.AutoField, graphene.ID, primary_key=True)
@ -126,70 +134,69 @@ def test_should_nullboolean_convert_boolean():
def test_field_with_choices_convert_enum(): def test_field_with_choices_convert_enum():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('es', 'Spanish'), help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
('en', 'English') )
))
class TranslatedModel(models.Model): class TranslatedModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
graphene_type = convert_django_field_with_choices(field) graphene_type = convert_django_field_with_choices(field)
assert isinstance(graphene_type, graphene.Enum) assert isinstance(graphene_type, graphene.Enum)
assert graphene_type._meta.name == 'TranslatedModelLanguage' assert graphene_type._meta.name == "TranslatedModelLanguage"
assert graphene_type._meta.enum.__members__['ES'].value == 'es' assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish' assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__['EN'].value == 'en' assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__['EN'].description == 'English' assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_grouped_choices(): def test_field_with_grouped_choices():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('Europe', ( help_text="Language",
('es', 'Spanish'), choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
('en', 'English'), )
)),
))
class GroupedChoicesModel(models.Model): class GroupedChoicesModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
def test_field_with_choices_gettext(): def test_field_with_choices_gettext():
field = models.CharField(help_text='Language', choices=( field = models.CharField(
('es', _('Spanish')), help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
('en', _('English')) )
))
class TranslatedChoicesModel(models.Model): class TranslatedChoicesModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
def test_field_with_choices_collision(): def test_field_with_choices_collision():
field = models.CharField(help_text='Timezone', choices=( field = models.CharField(
('Etc/GMT+1+2', 'Fake choice to produce double collision'), help_text="Timezone",
('Etc/GMT+1', 'Greenwich Mean Time +1'), choices=(
('Etc/GMT-1', 'Greenwich Mean Time -1'), ("Etc/GMT+1+2", "Fake choice to produce double collision"),
)) ("Etc/GMT+1", "Greenwich Mean Time +1"),
("Etc/GMT-1", "Greenwich Mean Time -1"),
),
)
class CollisionChoicesModel(models.Model): class CollisionChoicesModel(models.Model):
timezone = field timezone = field
class Meta: class Meta:
app_label = 'test' app_label = "test"
convert_django_field_with_choices(field) convert_django_field_with_choices(field)
@ -206,11 +213,12 @@ def test_should_manytomany_convert_connectionorlist():
def test_should_manytomany_convert_connectionorlist_list(): def test_should_manytomany_convert_connectionorlist_list():
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field) assert isinstance(dynamic_field, graphene.Field)
@ -220,12 +228,13 @@ def test_should_manytomany_convert_connectionorlist_list():
def test_should_manytomany_convert_connectionorlist_connection(): def test_should_manytomany_convert_connectionorlist_connection():
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node,)
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField) assert isinstance(dynamic_field, ConnectionField)
@ -233,16 +242,12 @@ def test_should_manytomany_convert_connectionorlist_connection():
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Reporter.articles, 'rel', None) or \
getattr(Reporter.articles, 'related')
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
graphene_field = convert_django_field(related, A._meta.registry) graphene_field = convert_django_field(Reporter.articles.rel,
A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field) assert isinstance(dynamic_field, graphene.Field)
@ -251,57 +256,53 @@ def test_should_manytoone_convert_connectionorlist():
def test_should_onetoone_reverse_convert_model(): def test_should_onetoone_reverse_convert_model():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Film.details, 'rel', None) or \
getattr(Film.details, 'related')
class A(DjangoObjectType): class A(DjangoObjectType):
class Meta: class Meta:
model = FilmDetails model = FilmDetails
graphene_field = convert_django_field(related, A._meta.registry) graphene_field = convert_django_field(Film.details.related,
A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic) assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type() dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field) assert isinstance(dynamic_field, graphene.Field)
assert dynamic_field.type == A assert dynamic_field.type == A
@pytest.mark.skipif(ArrayField is MissingType, @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
reason="ArrayField should exist")
def test_should_postgres_array_convert_list(): def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100)
)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String assert field.type.of_type.of_type == graphene.String
@pytest.mark.skipif(ArrayField is MissingType, @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list(): def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) field = assert_conversion(
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String assert field.type.of_type.of_type.of_type == graphene.String
@pytest.mark.skipif(HStoreField is MissingType, @pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string(): def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString) assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(JSONField is MissingType, @pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
reason="JSONField should exist")
def test_should_postgres_json_convert_string(): def test_should_postgres_json_convert_string():
assert_conversion(JSONField, JSONString) assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType, @pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
reason="RangeField should exist")
def test_should_postgres_range_convert_list(): def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List) field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type, graphene.List)

View File

@ -1,103 +0,0 @@
from django import forms
from py.test import raises
import graphene
from graphene import ID, List, NonNull
from ..form_converter import convert_form_field
from .models import Reporter
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text='Custom Help Text')
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert 'Don\'t know how to convert the Django form field' in str(excinfo.value)
def test_should_date_convert_string():
assert_conversion(forms.DateField, graphene.String)
def test_should_time_convert_string():
assert_conversion(forms.TimeField, graphene.String)
def test_should_date_time_convert_string():
assert_conversion(forms.DateTimeField, graphene.String)
def test_should_char_convert_string():
assert_conversion(forms.CharField, graphene.String)
def test_should_email_convert_string():
assert_conversion(forms.EmailField, graphene.String)
def test_should_slug_convert_string():
assert_conversion(forms.SlugField, graphene.String)
def test_should_url_convert_string():
assert_conversion(forms.URLField, graphene.String)
def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, graphene.String)
def test_should_base_field_convert_string():
assert_conversion(forms.Field, graphene.String)
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, graphene.String, '[0-9]+')
def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'):
assert_conversion(forms.UUIDField, graphene.UUID)
def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, graphene.Int)
def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, graphene.Boolean)
assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, graphene.Boolean)
assert not isinstance(field.type, NonNull)
def test_should_float_convert_float():
assert_conversion(forms.FloatField, graphene.Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, graphene.Float)
def test_should_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(Reporter.objects.all())
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == ID
def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(Reporter.objects.all())
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene.ID)

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from py.test import raises from py.test import raises
from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
def test_global_id_valid(): def test_global_id_valid():
field = GlobalIDFormField() field = GlobalIDFormField()
field.clean('TXlUeXBlOmFiYw==') field.clean("TXlUeXBlOmFiYw==")
def test_global_id_invalid(): def test_global_id_invalid():
field = GlobalIDFormField() field = GlobalIDFormField()
with raises(ValidationError): with raises(ValidationError):
field.clean('badvalue') field.clean("badvalue")
def test_global_id_multiple_valid(): def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField() field = GlobalIDMultipleChoiceField()
field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw==']) field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
def test_global_id_multiple_invalid(): def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField() field = GlobalIDMultipleChoiceField()
with raises(ValidationError): with raises(ValidationError):
field.clean(['badvalue', 'another bad avue']) field.clean(["badvalue", "another bad avue"])
def test_global_id_none(): def test_global_id_none():

File diff suppressed because it is too large Load Diff

View File

@ -7,47 +7,47 @@ from .models import Reporter
def test_should_raise_if_no_model(): def test_should_raise_if_no_model():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character1(DjangoObjectType): class Character1(DjangoObjectType):
pass pass
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_raise_if_model_is_invalid(): def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Character2(DjangoObjectType):
class Meta: class Meta:
model = 1 model = 1
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_map_fields_correctly(): def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType): class ReporterType2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
registry = Registry() registry = Registry()
fields = list(ReporterType2._meta.fields.keys()) fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [ assert fields[:-2] == [
'id', "id",
'first_name', "first_name",
'last_name', "last_name",
'email', "email",
'pets', "pets",
'a_choice', "a_choice",
"reporter_type",
] ]
assert sorted(fields[-2:]) == [ assert sorted(fields[-2:]) == ["articles", "films"]
'articles',
'films',
]
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
only_fields = ('id', 'email') only_fields = ("id", "email")
assert list(Reporter2._meta.fields.keys()) == ['id', 'email'] assert list(Reporter2._meta.fields.keys()) == ["id", "email"]

View File

@ -12,27 +12,30 @@ registry.reset_global_registry()
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
'''Reporter description''' """Reporter description"""
class Meta: class Meta:
model = ReporterModel model = ReporterModel
class ArticleConnection(Connection): class ArticleConnection(Connection):
'''Article Connection''' """Article Connection"""
test = String() test = String()
def resolve_test(): def resolve_test():
return 'test' return "test"
class Meta: class Meta:
abstract = True abstract = True
class Article(DjangoObjectType): class Article(DjangoObjectType):
'''Article description''' """Article description"""
class Meta: class Meta:
model = ArticleModel model = ArticleModel
interfaces = (Node, ) interfaces = (Node,)
connection_class = ArticleConnection connection_class = ArticleConnection
@ -48,7 +51,7 @@ def test_django_interface():
assert issubclass(Node, Node) assert issubclass(Node, Node)
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) @patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
def test_django_get_node(get): def test_django_get_node(get):
article = Article.get_node(None, 1) article = Article.get_node(None, 1)
get.assert_called_with(pk=1) get.assert_called_with(pk=1)
@ -58,27 +61,44 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields(): def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields fields = Reporter._meta.fields
fields = list(fields.keys()) fields = list(fields.keys())
assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice'] assert fields[:-2] == [
assert sorted(fields[-2:]) == ['articles', 'films'] "id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_django_objecttype_with_node_have_correct_fields(): def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields fields = Article._meta.fields
assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'reporter', 'editor', 'lang', 'importance'] assert list(fields.keys()) == [
"id",
"headline",
"pub_date",
"pub_date_time",
"reporter",
"editor",
"lang",
"importance",
]
def test_django_objecttype_with_custom_meta(): def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions): class ArticleTypeOptions(DjangoObjectTypeOptions):
'''Article Type Options''' """Article Type Options"""
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
abstract = True abstract = True
@classmethod @classmethod
def __init_subclass_with_meta__(cls, _meta=None, **options): def __init_subclass_with_meta__(cls, **options):
_meta = ArticleTypeOptions(cls) options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType): class Article(ArticleType):
class Meta: class Meta:
@ -96,7 +116,8 @@ schema {
type Article implements Node { type Article implements Node {
id: ID! id: ID!
headline: String! headline: String!
pubDate: DateTime! pubDate: Date!
pubDateTime: DateTime!
reporter: Reporter! reporter: Reporter!
editor: Reporter! editor: Reporter!
lang: ArticleLang! lang: ArticleLang!
@ -124,6 +145,8 @@ enum ArticleLang {
EN EN
} }
scalar Date
scalar DateTime scalar DateTime
interface Node { interface Node {
@ -144,6 +167,7 @@ type Reporter {
email: String! email: String!
pets: [Reporter] pets: [Reporter]
aChoice: ReporterAChoice! aChoice: ReporterAChoice!
reporterType: ReporterReporterType
articles(before: String, after: String, first: Int, last: Int): ArticleConnection articles(before: String, after: String, first: Int, last: Int): ArticleConnection
} }
@ -152,6 +176,11 @@ enum ReporterAChoice {
A_2 A_2
} }
enum ReporterReporterType {
A_1
A_2
}
type RootQuery { type RootQuery {
node(id: ID!): Node node(id: ID!): Node
} }
@ -171,6 +200,7 @@ def with_local_registry(func):
else: else:
registry.registry = old registry.registry = old
return retval return retval
return inner return inner
@ -179,11 +209,10 @@ def test_django_objecttype_only_fields():
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
only_fields = ('id', 'email', 'films') only_fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys()) fields = list(Reporter._meta.fields.keys())
assert fields == ['id', 'email', 'films'] assert fields == ["id", "email", "films"]
@with_local_registry @with_local_registry
@ -191,8 +220,7 @@ def test_django_objecttype_exclude_fields():
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude_fields = ('email') exclude_fields = "email"
fields = list(Reporter._meta.fields.keys()) fields = list(Reporter._meta.fields.keys())
assert 'email' not in fields assert "email" not in fields

View File

@ -8,15 +8,15 @@ except ImportError:
from urllib.parse import urlencode from urllib.parse import urlencode
def url_string(string='/graphql', **url_params): def url_string(string="/graphql", **url_params):
if url_params: if url_params:
string += '?' + urlencode(url_params) string += "?" + urlencode(url_params)
return string return string
def batch_url_string(**url_params): def batch_url_string(**url_params):
return url_string('/graphql/batch', **url_params) return url_string("/graphql/batch", **url_params)
def response_json(response): def response_json(response):
@ -28,405 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client): def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT='text/html') response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200 assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_graphiql(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_json(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_query_param(client): def test_allows_get_with_query_param(client):
response = client.get(url_string(query='{test}')) response = client.get(url_string(query="{test}"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_allows_get_with_variable_values(client): def test_allows_get_with_variable_values(client):
response = client.get(url_string( response = client.get(
query='query helloWho($who: String){ test(who: $who) }', url_string(
variables=json.dumps({'who': "Dolly"}) query="query helloWho($who: String){ test(who: $who) }",
)) variables=json.dumps({"who": "Dolly"}),
)
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_allows_get_with_operation_name(client): def test_allows_get_with_operation_name(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
operationName='helloWorld' operationName="helloWorld",
)) )
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': { "data": {"test": "Hello World", "shared": "Hello Everyone"}
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
def test_reports_validation_errors(client): def test_reports_validation_errors(client):
response = client.get(url_string( response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
query='{ test, unknownOne, unknownTwo }'
))
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {
'message': 'Cannot query field "unknownOne" on type "QueryRoot".', "message": 'Cannot query field "unknownOne" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 9}] "locations": [{"line": 1, "column": 9}],
}, },
{ {
'message': 'Cannot query field "unknownTwo" on type "QueryRoot".', "message": 'Cannot query field "unknownTwo" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 21}] "locations": [{"line": 1, "column": 21}],
} },
] ]
} }
def test_errors_when_missing_operation_name(client): def test_errors_when_missing_operation_name(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''' """
)) )
)
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {
'message': 'Must provide operation name if query contains multiple operations.' "message": "Must provide operation name if query contains multiple operations."
} }
] ]
} }
def test_errors_when_sending_a_mutation_via_get(client): def test_errors_when_sending_a_mutation_via_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''' """
)) )
)
assert response.status_code == 405 assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {"message": "Can only perform a mutation operation from a POST request."}
'message': 'Can only perform a mutation operation from a POST request.'
}
] ]
} }
def test_errors_when_selecting_a_mutation_within_a_get(client): def test_errors_when_selecting_a_mutation_within_a_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''', """,
operationName='TestMutation' operationName="TestMutation",
)) )
)
assert response.status_code == 405 assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
'errors': [ "errors": [
{ {"message": "Can only perform a mutation operation from a POST request."}
'message': 'Can only perform a mutation operation from a POST request.'
}
] ]
} }
def test_allows_mutation_to_exist_within_a_get(client): def test_allows_mutation_to_exist_within_a_get(client):
response = client.get(url_string( response = client.get(
query=''' url_string(
query="""
query TestQuery { test } query TestQuery { test }
mutation TestMutation { writeTest { test } } mutation TestMutation { writeTest { test } }
''', """,
operationName='TestQuery' operationName="TestQuery",
)) )
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_allows_post_with_json_encoding(client): def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query='{test}'), 'application/json') response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_batch_allows_post_with_json_encoding(client): def test_batch_allows_post_with_json_encoding(client):
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json') response = client.post(
batch_url_string(), jl(id=1, query="{test}"), "application/json"
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello World"}, "status": 200}
'data': {'test': "Hello World"}, ]
'status': 200,
}]
def test_batch_fails_if_is_empty(client): def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), '[]', 'application/json') response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Received an empty list in the batch request.'}] "errors": [{"message": "Received an empty list in the batch request."}]
} }
def test_allows_sending_a_mutation_via_post(client): def test_allows_sending_a_mutation_via_post(client):
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json') response = client.post(
url_string(),
j(query="mutation TestMutation { writeTest { test } }"),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
'data': {'writeTest': {'test': 'Hello World'}}
}
def test_allows_post_with_url_encoding(client): def test_allows_post_with_url_encoding(client):
response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded') response = client.post(
url_string(),
urlencode(dict(query="{test}")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello World"}}
'data': {'test': "Hello World"}
}
def test_supports_post_json_query_with_string_variables(client): def test_supports_post_json_query_with_string_variables(client):
response = client.post(url_string(), j( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables=json.dumps({'who': "Dolly"}) j(
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_batch_supports_post_json_query_with_string_variables(client): def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query='query helloWho($who: String){ test(who: $who) }', jl(
variables=json.dumps({'who': "Dolly"}) id=1,
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
'data': {'test': "Hello Dolly"}, ]
'status': 200,
}]
def test_supports_post_json_query_with_json_variables(client): def test_supports_post_json_query_with_json_variables(client):
response = client.post(url_string(), j( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables={'who': "Dolly"} j(
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_batch_supports_post_json_query_with_json_variables(client): def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query='query helloWho($who: String){ test(who: $who) }', jl(
variables={'who': "Dolly"} id=1,
), 'application/json') query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
'data': {'test': "Hello Dolly"}, ]
'status': 200,
}]
def test_supports_post_url_encoded_query_with_string_variables(client): def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(url_string(), urlencode(dict( response = client.post(
query='query helloWho($who: String){ test(who: $who) }', url_string(),
variables=json.dumps({'who': "Dolly"}) urlencode(
)), 'application/x-www-form-urlencoded') dict(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_supports_post_json_quey_with_get_variable_values(client): def test_supports_post_json_quey_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), j( j(query="query helloWho($who: String){ test(who: $who) }"),
query='query helloWho($who: String){ test(who: $who) }', "application/json",
), 'application/json') )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_post_url_encoded_query_with_get_variable_values(client): def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), urlencode(dict( urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
query='query helloWho($who: String){ test(who: $who) }', "application/x-www-form-urlencoded",
)), 'application/x-www-form-urlencoded') )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"test": "Hello Dolly"}}
'data': {'test': "Hello Dolly"}
}
def test_supports_post_raw_text_query_with_get_variable_values(client): def test_supports_post_raw_text_query_with_get_variable_values(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), "query helloWho($who: String){ test(who: $who) }",
'query helloWho($who: String){ test(who: $who) }', "application/graphql",
'application/graphql' )
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_post_with_operation_name(client):
response = client.post(
url_string(),
j(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
) )
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': {'test': "Hello Dolly"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_allows_post_with_operation_name(client):
response = client.post(url_string(), j(
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
def test_batch_allows_post_with_operation_name(client): def test_batch_allows_post_with_operation_name(client):
response = client.post(batch_url_string(), jl( response = client.post(
id=1, batch_url_string(),
query=''' jl(
id=1,
query="""
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
operationName='helloWorld' operationName="helloWorld",
), 'application/json') ),
"application/json",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == [{ assert response_json(response) == [
'id': 1, {
'data': { "id": 1,
'test': 'Hello World', "data": {"test": "Hello World", "shared": "Hello Everyone"},
'shared': 'Hello Everyone' "status": 200,
}, }
'status': 200, ]
}]
def test_allows_post_with_get_operation_name(client): def test_allows_post_with_get_operation_name(client):
response = client.post(url_string( response = client.post(
operationName='helloWorld' url_string(operationName="helloWorld"),
), ''' """
query helloYou { test(who: "You"), ...shared } query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared } query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared } query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot { fragment shared on QueryRoot {
shared: test(who: "Everyone") shared: test(who: "Everyone")
} }
''', """,
'application/graphql') "application/graphql",
)
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': { "data": {"test": "Hello World", "shared": "Hello Everyone"}
'test': 'Hello World',
'shared': 'Hello Everyone'
}
} }
@pytest.mark.urls('graphene_django.tests.urls_pretty') @pytest.mark.urls("graphene_django.tests.urls_inherited")
def test_inherited_class_with_attributes_works(client):
inherited_url = "/graphql/inherited/"
# Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query="{test}"))
assert response.content.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
# Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200
@pytest.mark.urls("graphene_django.tests.urls_pretty")
def test_supports_pretty_printing(client): def test_supports_pretty_printing(client):
response = client.get(url_string(query='{test}')) response = client.get(url_string(query="{test}"))
assert response.content.decode() == ( assert response.content.decode() == (
'{\n' "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
) )
def test_supports_pretty_printing_by_request(client): def test_supports_pretty_printing_by_request(client):
response = client.get(url_string(query='{test}', pretty='1')) response = client.get(url_string(query="{test}", pretty="1"))
assert response.content.decode() == ( assert response.content.decode() == (
'{\n' "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
) )
def test_handles_field_errors_caught_by_graphql(client): def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query='{thrower}')) response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': None, "data": None,
'errors': [{'locations': [{'column': 2, 'line': 1}], 'message': 'Throws!'}] "errors": [
{
"locations": [{"column": 2, "line": 1}],
"path": ["thrower"],
"message": "Throws!",
}
],
} }
def test_handles_syntax_errors_caught_by_graphql(client): def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query='syntaxerror')) response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'locations': [{'column': 1, 'line': 1}], "errors": [
'message': 'Syntax Error GraphQL request (1:1) ' {
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] "locations": [{"column": 1, "line": 1}],
"message": "Syntax Error GraphQL (1:1) "
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n',
}
]
} }
@ -435,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}] "errors": [{"message": "Must provide query string."}]
} }
def test_handles_not_expected_json_bodies(client): def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), '[]', 'application/json') response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'The received data is not a valid JSON query.'}] "errors": [{"message": "The received data is not a valid JSON query."}]
} }
def test_handles_invalid_json_bodies(client): def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), '[oh}', 'application/json') response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -463,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch):
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo='bar')) valid_json = json.dumps(dict(foo="bar"))
response = client.post(url_string(), valid_json, 'application/json') response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
'errors': [{'message': 'foo-bar'}]
}
def test_handles_incomplete_json_bodies(client): def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', 'application/json') response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
def test_handles_plain_post_text(client): def test_handles_plain_post_text(client):
response = client.post(url_string( response = client.post(
variables=json.dumps({'who': "Dolly"}) url_string(variables=json.dumps({"who": "Dolly"})),
), "query helloWho($who: String){ test(who: $who) }",
'query helloWho($who: String){ test(who: $who) }', "text/plain",
'text/plain'
) )
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}] "errors": [{"message": "Must provide query string."}]
} }
def test_handles_poorly_formed_variables(client): def test_handles_poorly_formed_variables(client):
response = client.get(url_string( response = client.get(
query='query helloWho($who: String){ test(who: $who) }', url_string(
variables='who:You' query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
)) )
)
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'Variables are invalid JSON.'}] "errors": [{"message": "Variables are invalid JSON."}]
} }
def test_handles_unsupported_http_methods(client): def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query='{test}')) response = client.put(url_string(query="{test}"))
assert response.status_code == 405 assert response.status_code == 405
assert response['Allow'] == 'GET, POST' assert response["Allow"] == "GET, POST"
assert response_json(response) == { assert response_json(response) == {
'errors': [{'message': 'GraphQL only supports GET and POST requests.'}] "errors": [{"message": "GraphQL only supports GET and POST requests."}]
} }
def test_passes_request_into_context_request(client): def test_passes_request_into_context_request(client):
response = client.get(url_string(query='{request}', q='testing')) response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {"data": {"request": "testing"}}
'data': {
'request': 'testing'
}
}

View File

@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views import GraphQLView from ..views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r'^graphql/batch', GraphQLView.as_view(batch=True)), url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r'^graphql', GraphQLView.as_view(graphiql=True)), url(r"^graphql", GraphQLView.as_view(graphiql=True)),
] ]

View File

@ -0,0 +1,13 @@
from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
class CustomGraphQLView(GraphQLView):
schema = schema
graphiql = True
pretty = True
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]

View File

@ -3,6 +3,4 @@ from django.conf.urls import url
from ..views import GraphQLView from ..views import GraphQLView
from .schema_view import schema from .schema_view import schema
urlpatterns = [ urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]
url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)),
]

View File

@ -1,5 +1,7 @@
import six
from collections import OrderedDict from collections import OrderedDict
from django.db.models import Model
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject
from graphene import Field from graphene import Field
from graphene.relay import Connection, Node from graphene.relay import Connection, Node
@ -8,8 +10,11 @@ from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry from .registry import Registry, get_global_registry
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields, from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model
is_valid_django_model)
if six.PY3:
from typing import Type
def construct_fields(model, registry, only_fields, exclude_fields): def construct_fields(model, registry, only_fields, exclude_fields):
@ -21,7 +26,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# is_already_created = name in options.fields # is_already_created = name in options.fields
is_excluded = name in exclude_fields # or is_already_created is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+') is_no_backref = str(name).endswith("+")
if is_not_in_only or is_excluded or is_no_backref: if is_not_in_only or is_excluded or is_no_backref:
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields. # in there. Or when we exclude this field in exclude_fields.
@ -43,9 +48,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions):
class DjangoObjectType(ObjectType): class DjangoObjectType(ObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, def __init_subclass_with_meta__(
only_fields=(), exclude_fields=(), filter_fields=None, connection=None, cls,
connection_class=None, use_connection=None, interfaces=(), _meta=None, **options): model=None,
registry=None,
skip_registry=False,
only_fields=(),
exclude_fields=(),
filter_fields=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
_meta=None,
**options
):
assert is_valid_django_model(model), ( assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".' 'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model) ).format(cls.__name__, model)
@ -54,7 +71,7 @@ class DjangoObjectType(ObjectType):
registry = get_global_registry() registry = get_global_registry()
assert isinstance(registry, Registry), ( assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of ' "The attribute registry in {} needs to be an instance of "
'Registry, received "{}".' 'Registry, received "{}".'
).format(cls.__name__, registry) ).format(cls.__name__, registry)
@ -62,12 +79,13 @@ class DjangoObjectType(ObjectType):
raise Exception("Can only set filter_fields if Django-Filter is installed") raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs( django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields), construct_fields(model, registry, only_fields, exclude_fields), _as=Field
_as=Field,
) )
if use_connection is None and interfaces: if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces)) use_connection = any(
(issubclass(interface, Node) for interface in interfaces)
)
if use_connection and not connection: if use_connection and not connection:
# We create the connection automatically # We create the connection automatically
@ -75,7 +93,8 @@ class DjangoObjectType(ObjectType):
connection_class = Connection connection_class = Connection
connection = connection_class.create_type( connection = connection_class.create_type(
'{}Connection'.format(cls.__name__), node=cls) "{}Connection".format(cls.__name__), node=cls
)
if connection is not None: if connection is not None:
assert issubclass(connection, Connection), ( assert issubclass(connection, Connection), (
@ -91,7 +110,9 @@ class DjangoObjectType(ObjectType):
_meta.fields = django_fields _meta.fields = django_fields
_meta.connection = connection _meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options) super(DjangoObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
if not skip_registry: if not skip_registry:
registry.register(cls) registry.register(cls)
@ -107,10 +128,9 @@ class DjangoObjectType(ObjectType):
if isinstance(root, cls): if isinstance(root, cls):
return True return True
if not is_valid_django_model(type(root)): if not is_valid_django_model(type(root)):
raise Exception(( raise Exception(('Received incompatible instance "{}".').format(root))
'Received incompatible instance "{}".'
).format(root)) model = root._meta.model._meta.concrete_model
model = root._meta.model
return model == cls._meta.model return model == cls._meta.model
@classmethod @classmethod

View File

@ -13,6 +13,7 @@ class LazyList(object):
try: try:
import django_filters # noqa import django_filters # noqa
DJANGO_FILTER_INSTALLED = True DJANGO_FILTER_INSTALLED = True
except ImportError: except ImportError:
DJANGO_FILTER_INSTALLED = False DJANGO_FILTER_INSTALLED = False
@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names):
continue continue
# Django =>1.9 uses 'rel', django <1.9 uses 'related' # Django =>1.9 uses 'rel', django <1.9 uses 'related'
related = getattr(attr, 'rel', None) or \ related = getattr(attr, "rel", None) or getattr(attr, "related", None)
getattr(attr, 'related', None)
if isinstance(related, models.ManyToOneRel): if isinstance(related, models.ManyToOneRel):
yield (name, related) yield (name, related)
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
@ -42,9 +42,9 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
local_fields = [ local_fields = [
(field.name, field) (field.name, field)
for field for field in sorted(
in sorted(list(model._meta.fields) + list(model._meta.fields) + list(model._meta.local_many_to_many)
list(model._meta.local_many_to_many)) )
] ]
# Make sure we don't duplicate local fields with "reverse" version # Make sure we don't duplicate local fields with "reverse" version

View File

@ -10,18 +10,16 @@ from django.utils.decorators import method_decorator
from django.views.generic import View from django.views.generic import View
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from graphql import Source, execute, parse, validate from graphql import get_default_backend
from graphql.error import format_error as format_graphql_error from graphql.error import format_error as format_graphql_error
from graphql.error import GraphQLError from graphql.error import GraphQLError
from graphql.execution import ExecutionResult from graphql.execution import ExecutionResult
from graphql.type.schema import GraphQLSchema from graphql.type.schema import GraphQLSchema
from graphql.utils.get_operation_ast import get_operation_ast
from .settings import graphene_settings from .settings import graphene_settings
class HttpError(Exception): class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs): def __init__(self, response, message=None, *args, **kwargs):
self.response = response self.response = response
self.message = message = message or response.content.decode() self.message = message = message or response.content.decode()
@ -30,18 +28,18 @@ class HttpError(Exception):
def get_accepted_content_types(request): def get_accepted_content_types(request):
def qualify(x): def qualify(x):
parts = x.split(';', 1) parts = x.split(";", 1)
if len(parts) == 2: if len(parts) == 2:
match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)', match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
parts[1])
if match: if match:
return parts[0], float(match.group(2)) return parts[0].strip(), float(match.group(2))
return parts[0], 1 return parts[0].strip(), 1
raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',') raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types) qualified_content_types = map(qualify, raw_content_types)
return list(x[0] for x in sorted(qualified_content_types, return list(
key=lambda x: x[1], reverse=True)) x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
)
def instantiate_middleware(middlewares): def instantiate_middleware(middlewares):
@ -53,38 +51,52 @@ def instantiate_middleware(middlewares):
class GraphQLView(View): class GraphQLView(View):
graphiql_version = '0.10.2' graphiql_version = "0.11.10"
graphiql_template = 'graphene/graphiql.html' graphiql_template = "graphene/graphiql.html"
schema = None schema = None
graphiql = False graphiql = False
executor = None executor = None
backend = None
middleware = None middleware = None
root_value = None root_value = None
pretty = False pretty = False
batch = False batch = False
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, def __init__(
batch=False): self,
schema=None,
executor=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
backend=None,
):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
if backend is None:
backend = get_default_backend()
if middleware is None: if middleware is None:
middleware = graphene_settings.MIDDLEWARE middleware = graphene_settings.MIDDLEWARE
self.schema = schema self.schema = self.schema or schema
if middleware is not None: if middleware is not None:
self.middleware = list(instantiate_middleware(middleware)) self.middleware = list(instantiate_middleware(middleware))
self.executor = executor self.executor = executor
self.root_value = root_value self.root_value = root_value
self.pretty = pretty self.pretty = self.pretty or pretty
self.graphiql = graphiql self.graphiql = self.graphiql or graphiql
self.batch = batch self.batch = self.batch or batch
self.backend = backend
assert isinstance( assert isinstance(
self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' self.schema, GraphQLSchema
assert not all((graphiql, batch) ), "A Schema is required to be provided to GraphQLView."
), 'Use either graphiql or batch processing' assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def get_root_value(self, request): def get_root_value(self, request):
@ -96,62 +108,58 @@ class GraphQLView(View):
def get_context(self, request): def get_context(self, request):
return request return request
def get_backend(self, request):
return self.backend
@method_decorator(ensure_csrf_cookie) @method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
try: try:
if request.method.lower() not in ('get', 'post'): if request.method.lower() not in ("get", "post"):
raise HttpError(HttpResponseNotAllowed( raise HttpError(
['GET', 'POST'], 'GraphQL only supports GET and POST requests.')) HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request) data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql( show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
request, data)
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses]))
status_code = responses and max(responses, key=lambda response: response[1])[1] or 200
else:
result, status_code = self.get_response(
request, data, show_graphiql)
if show_graphiql: if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(
request, data)
return self.render_graphiql( return self.render_graphiql(
request, request,
graphiql_version=self.graphiql_version, graphiql_version=self.graphiql_version,
query=query or '',
variables=json.dumps(variables) or '',
operation_name=operation_name or '',
result=result or ''
) )
if self.batch:
responses = [self.get_response(request, entry) for entry in data]
result = "[{}]".format(
",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else:
result, status_code = self.get_response(request, data, show_graphiql)
return HttpResponse( return HttpResponse(
status=status_code, status=status_code, content=result, content_type="application/json"
content=result,
content_type='application/json'
) )
except HttpError as e: except HttpError as e:
response = e.response response = e.response
response['Content-Type'] = 'application/json' response["Content-Type"] = "application/json"
response.content = self.json_encode(request, { response.content = self.json_encode(
'errors': [self.format_error(e)] request, {"errors": [self.format_error(e)]}
}) )
return response return response
def get_response(self, request, data, show_graphiql=False): def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params( query, variables, operation_name, id = self.get_graphql_params(request, data)
request, data)
execution_result = self.execute_graphql_request( execution_result = self.execute_graphql_request(
request, request, data, query, variables, operation_name, show_graphiql
data,
query,
variables,
operation_name,
show_graphiql
) )
status_code = 200 status_code = 200
@ -159,17 +167,18 @@ class GraphQLView(View):
response = {} response = {}
if execution_result.errors: if execution_result.errors:
response['errors'] = [self.format_error( response["errors"] = [
e) for e in execution_result.errors] self.format_error(e) for e in execution_result.errors
]
if execution_result.invalid: if execution_result.invalid:
status_code = 400 status_code = 400
else: else:
response['data'] = execution_result.data response["data"] = execution_result.data
if self.batch: if self.batch:
response['id'] = id response["id"] = id
response['status'] = status_code response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql) result = self.json_encode(request, response, pretty=show_graphiql)
else: else:
@ -181,22 +190,21 @@ class GraphQLView(View):
return render(request, self.graphiql_template, data) return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False): def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get('pretty'): if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(',', ':')) return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True, return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
indent=2, separators=(',', ': '))
def parse_body(self, request): def parse_body(self, request):
content_type = self.get_content_type(request) content_type = self.get_content_type(request)
if content_type == 'application/graphql': if content_type == "application/graphql":
return {'query': request.body.decode()} return {"query": request.body.decode()}
elif content_type == 'application/json': elif content_type == "application/json":
# noinspection PyBroadException # noinspection PyBroadException
try: try:
body = request.body.decode('utf-8') body = request.body.decode("utf-8")
except Exception as e: except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e))) raise HttpError(HttpResponseBadRequest(str(e)))
@ -204,102 +212,113 @@ class GraphQLView(View):
request_json = json.loads(body) request_json = json.loads(body)
if self.batch: if self.batch:
assert isinstance(request_json, list), ( assert isinstance(request_json, list), (
'Batch requests should receive a list, but received {}.' "Batch requests should receive a list, but received {}."
).format(repr(request_json)) ).format(repr(request_json))
assert len(request_json) > 0, ( assert (
'Received an empty list in the batch request.' len(request_json) > 0
) ), "Received an empty list in the batch request."
else: else:
assert isinstance(request_json, dict), ( assert isinstance(
'The received data is not a valid JSON query.' request_json, dict
) ), "The received data is not a valid JSON query."
return request_json return request_json
except AssertionError as e: except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e))) raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError): except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
'POST body sent invalid JSON.'))
elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']: elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST return request.POST
return {} return {}
def execute(self, *args, **kwargs): def execute_graphql_request(
return execute(self.schema, *args, **kwargs) self, request, data, query, variables, operation_name, show_graphiql=False
):
def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
if not query: if not query:
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("Must provide query string."))
'Must provide query string.'))
source = Source(query, name='GraphQL request')
try: try:
document_ast = parse(source) backend = self.get_backend(request)
validation_errors = validate(self.schema, document_ast) document = backend.document_from_string(self.schema, query)
if validation_errors:
return ExecutionResult(
errors=validation_errors,
invalid=True,
)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)
if request.method.lower() == 'get': if request.method.lower() == "get":
operation_ast = get_operation_ast(document_ast, operation_name) operation_type = document.get_operation_type(operation_name)
if operation_ast and operation_ast.operation != 'query': if operation_type and operation_type != "query":
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseNotAllowed( raise HttpError(
['POST'], 'Can only perform a {} operation from a POST request.'.format( HttpResponseNotAllowed(
operation_ast.operation) ["POST"],
)) "Can only perform a {} operation from a POST request.".format(
operation_type
),
)
)
try: try:
return self.execute( extra_options = {}
document_ast, if self.executor:
root_value=self.get_root_value(request), # We only include it optionally since
variable_values=variables, # executor is not a valid argument in all backends
extra_options["executor"] = self.executor
return document.execute(
root=self.get_root_value(request),
variables=variables,
operation_name=operation_name, operation_name=operation_name,
context_value=self.get_context(request), context=self.get_context(request),
middleware=self.get_middleware(request), middleware=self.get_middleware(request),
executor=self.executor, **extra_options
) )
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)
@classmethod @classmethod
def can_display_graphiql(cls, request, data): def can_display_graphiql(cls, request, data):
raw = 'raw' in request.GET or 'raw' in data raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request) return not raw and cls.request_wants_html(request)
@classmethod @classmethod
def request_wants_html(cls, request): def request_wants_html(cls, request):
accepted = get_accepted_content_types(request) accepted = get_accepted_content_types(request)
html_index = accepted.count('text/html') accepted_length = len(accepted)
json_index = accepted.count('application/json') # the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = (
accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_index > json_index return html_priority > json_priority
@staticmethod @staticmethod
def get_graphql_params(request, data): def get_graphql_params(request, data):
query = request.GET.get('query') or data.get('query') query = request.GET.get("query") or data.get("query")
variables = request.GET.get('variables') or data.get('variables') variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get('id') or data.get('id') id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, six.text_type): if variables and isinstance(variables, six.text_type):
try: try:
variables = json.loads(variables) variables = json.loads(variables)
except Exception: except Exception:
raise HttpError(HttpResponseBadRequest( raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
'Variables are invalid JSON.'))
operation_name = request.GET.get( operation_name = request.GET.get("operationName") or data.get("operationName")
'operationName') or data.get('operationName')
if operation_name == "null": if operation_name == "null":
operation_name = None operation_name = None
@ -310,11 +329,10 @@ class GraphQLView(View):
if isinstance(error, GraphQLError): if isinstance(error, GraphQLError):
return format_graphql_error(error) return format_graphql_error(error)
return {'message': six.text_type(error)} return {"message": six.text_type(error)}
@staticmethod @staticmethod
def get_content_type(request): def get_content_type(request):
meta = request.META meta = request.META
content_type = meta.get( content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) return content_type.split(";", 1)[0].lower()
return content_type.split(';', 1)[0].lower()

View File

@ -3,77 +3,63 @@ import sys
import ast import ast
import re import re
_version_re = re.compile(r'__version__\s+=\s+(.*)') _version_re = re.compile(r"__version__\s+=\s+(.*)")
with open('graphene_django/__init__.py', 'rb') as f: with open("graphene_django/__init__.py", "rb") as f:
version = str(ast.literal_eval(_version_re.search( version = str(
f.read().decode('utf-8')).group(1))) ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
)
rest_framework_require = [ rest_framework_require = ["djangorestframework>=3.6.3"]
'djangorestframework>=3.6.3',
]
tests_require = [ tests_require = [
'pytest>=2.7.2', "pytest>=3.6.3",
'pytest-cov', "pytest-cov",
'coveralls', "coveralls",
'mock', "mock",
'pytz', "pytz",
'django-filter', "django-filter<2;python_version<'3'",
'pytest-django==2.9.1', "django-filter>=2;python_version>='3'",
"pytest-django>=3.3.2",
] + rest_framework_require ] + rest_framework_require
django_version = 'Django>=1.8.0,<2' if sys.version_info[0] < 3 else 'Django>=1.8.0'
setup( setup(
name='graphene-django', name="graphene-django",
version=version, version=version,
description="Graphene Django integration",
description='Graphene Django integration', long_description=open("README.rst").read(),
long_description=open('README.rst').read(), url="https://github.com/graphql-python/graphene-django",
author="Syrus Akbary",
url='https://github.com/graphql-python/graphene-django', author_email="me@syrusakbary.com",
license="MIT",
author='Syrus Akbary',
author_email='me@syrusakbary.com',
license='MIT',
classifiers=[ classifiers=[
'Development Status :: 3 - Alpha', "Development Status :: 3 - Alpha",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'Topic :: Software Development :: Libraries', "Topic :: Software Development :: Libraries",
'Programming Language :: Python :: 2', "Programming Language :: Python :: 2",
'Programming Language :: Python :: 2.7', "Programming Language :: Python :: 2.7",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'Programming Language :: Python :: 3.3', "Programming Language :: Python :: 3.4",
'Programming Language :: Python :: 3.4', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: Implementation :: PyPy', "Programming Language :: Python :: Implementation :: PyPy",
], ],
keywords="api graphql protocol rest relay graphene",
keywords='api graphql protocol rest relay graphene', packages=find_packages(exclude=["tests"]),
packages=find_packages(exclude=['tests']),
install_requires=[ install_requires=[
'six>=1.10.0', "six>=1.10.0",
'graphene>=2.0,<3', "graphene>=2.1.3,<3",
django_version, "graphql-core>=2.1.0,<3",
'iso8601', "Django>=1.11",
'singledispatch>=3.4.0.3', "singledispatch>=3.4.0.3",
'promise>=2.1', "promise>=2.1",
],
setup_requires=[
'pytest-runner',
], ],
setup_requires=["pytest-runner"],
tests_require=tests_require, tests_require=tests_require,
rest_framework_require=rest_framework_require, rest_framework_require=rest_framework_require,
extras_require={ extras_require={"test": tests_require, "rest_framework": rest_framework_require},
'test': tests_require,
'rest_framework': rest_framework_require,
},
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
platforms='any', platforms="any",
) )