Merge branch 'master' into master

This commit is contained in:
Carlos Martinez 2019-01-14 16:58:25 -05:00 committed by GitHub
commit 9bbd9ff9cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
88 changed files with 3030 additions and 1598 deletions

View File

@ -11,9 +11,6 @@ install:
pip install -e .[test]
pip install psycopg2 # Required for Django postgres fields testing
pip install django==$DJANGO_VERSION
if [ $DJANGO_VERSION = 1.8 ]; then # DRF dropped 1.8 support at 3.7.0
pip install djangorestframework==3.6.4
fi
python setup.py develop
elif [ "$TEST_TYPE" = lint ]; then
pip install flake8
@ -38,13 +35,19 @@ env:
matrix:
fast_finish: true
include:
- python: '3.4'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.5'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.6'
env: TEST_TYPE=build DJANGO_VERSION=2.0
- python: '3.5'
env: TEST_TYPE=build DJANGO_VERSION=2.1
- python: '3.6'
env: TEST_TYPE=build DJANGO_VERSION=2.1
- python: '2.7'
env: TEST_TYPE=build DJANGO_VERSION=1.8
- python: '2.7'
env: TEST_TYPE=build DJANGO_VERSION=1.9
- python: '2.7'
env: TEST_TYPE=build DJANGO_VERSION=1.10
- python: '2.7'
env: TEST_TYPE=lint
- python: '3.6'
env: TEST_TYPE=lint
deploy:
provider: pypi

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016-Present Syrus Akbary
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +1,2 @@
include README.md
include README.md LICENSE
recursive-include graphene_django/templates *

View File

@ -9,10 +9,10 @@ A [Django](https://www.djangoproject.com/) integration for [Graphene](http://gra
## Installation
For instaling graphene, just run this command in your shell
For installing graphene, just run this command in your shell
```bash
pip install "graphene-django>=2.0.dev"
pip install "graphene-django>=2.0"
```
### Settings
@ -67,8 +67,7 @@ class User(DjangoObjectType):
class Query(graphene.ObjectType):
users = graphene.List(User)
@graphene.resolve_only_args
def resolve_users(self):
def resolve_users(self, info):
return UserModel.objects.all()
schema = graphene.Schema(query=Query)

View File

@ -13,11 +13,11 @@ A `Django <https://www.djangoproject.com/>`__ integration for
Installation
------------
For instaling graphene, just run this command in your shell
For installing graphene, just run this command in your shell
.. code:: bash
pip install "graphene-django>=2.0.dev"
pip install "graphene-django>=2.0"
Settings
~~~~~~~~

View File

@ -8,6 +8,7 @@ SECRET_KEY = 1
INSTALLED_APPS = [
'graphene_django',
'graphene_django.rest_framework',
'graphene_django.tests',
'starwars',
]

View File

@ -20,7 +20,7 @@ Let's use a simple example model.
Limiting Field Access
---------------------
This is easy, simply use the ``only_fields`` meta attribute.
To limit fields in a GraphQL query simply use the ``only_fields`` meta attribute.
.. code:: python
@ -34,7 +34,7 @@ This is easy, simply use the ``only_fields`` meta attribute.
only_fields = ('title', 'content')
interfaces = (relay.Node, )
conversely you can use ``exclude_fields`` meta atrribute.
conversely you can use ``exclude_fields`` meta attribute.
.. code:: python
@ -61,10 +61,11 @@ define a resolve method for that field and return the desired queryset.
from .models import Post
class Query(ObjectType):
all_posts = DjangoFilterConnectionField(CategoryNode)
all_posts = DjangoFilterConnectionField(PostNode)
def resolve_all_posts(self, info):
return Post.objects.filter(published=True)
def resolve_all_posts(self, args, info):
return Post.objects.filter(published=True)
User-based Queryset Filtering
-----------------------------
@ -79,7 +80,7 @@ with the context argument.
from .models import Post
class Query(ObjectType):
my_posts = DjangoFilterConnectionField(CategoryNode)
my_posts = DjangoFilterConnectionField(PostNode)
def resolve_my_posts(self, info):
# context will reference to the Django request
@ -95,7 +96,7 @@ schema is simple.
result = schema.execute(query, context_value=request)
Filtering ID-based node access
Filtering ID-based Node Access
------------------------------
In order to add authorization to id-based node access, we need to add a
@ -113,13 +114,13 @@ method to your ``DjangoObjectType``.
interfaces = (relay.Node, )
@classmethod
def get_node(cls, id, context, info):
def get_node(cls, id, info):
try:
post = cls._meta.model.objects.get(id=id)
except cls._meta.model.DoesNotExist:
return None
if post.published or context.user == post.owner:
if post.published or info.context.user == post.owner:
return post
return None
@ -206,14 +207,14 @@ Connection example:
all_reporters = MyAuthDjangoConnectionField(ReporterType)
Adding login required
Adding Login Required
---------------------
If you want to use the standard Django LoginRequiredMixin_ you can create your own view, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``:
To restrict users from accessing the GraphQL API page the standard Django LoginRequiredMixin_ can be used to create your own standard Django Class Based View, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``.:
.. code:: python
#views.py
from django.contrib.auth.mixins import LoginRequiredMixin
from graphene_django.views import GraphQLView
@ -221,7 +222,9 @@ If you want to use the standard Django LoginRequiredMixin_ you can create your o
class PrivateGraphQLView(LoginRequiredMixin, GraphQLView):
pass
After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``:
After this, you can use the new ``PrivateGraphQLView`` in the project's URL Configuration file ``url.py``:
For Django 1.9 and below:
.. code:: python
@ -229,5 +232,14 @@ After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``:
# some other urls
url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
]
For Django 2.0 and above:
.. code:: python
urlpatterns = [
# some other urls
path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
]
.. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin

View File

@ -2,9 +2,9 @@ Filtering
=========
Graphene integrates with
`django-filter <https://django-filter.readthedocs.org>`__ to provide
filtering of results. See the `usage
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
`django-filter <https://django-filter.readthedocs.io/en/master/>`__ (2.x for
Python 3 or 1.x for Python 2) to provide filtering of results. See the `usage
documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``.
@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows:
.. code:: bash
# You'll need to django-filter
pip install django-filter
pip install django-filter>=2
Note: The techniques below are demoed in the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__.
@ -26,7 +26,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering
documentation <https://django-filter.readthedocs.io/en/latest/guide/usage.html#the-filter>`__
documentation <https://django-filter.readthedocs.io/en/master/guide/usage.html#the-filter>`__
for full details on the range of options available.
For example:
@ -126,3 +126,23 @@ create your own ``Filterset`` as follows:
# We specify our custom AnimalFilter using the filterset_class param
all_animals = DjangoFilterConnectionField(AnimalNode,
filterset_class=AnimalFilter)
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/master/guide/usage.html#request-based-filtering>`__
in a ``django_filters.FilterSet`` instance. You can use this to customize your
filters to be context-dependent. We could modify the ``AnimalFilter`` above to
pre-filter animals owned by the authenticated user (set in ``context.user``).
.. code:: python
class AnimalFilter(django_filters.FilterSet):
# Do case-insensitive lookups on 'name'
name = django_filters.CharFilter(lookup_type='iexact')
class Meta:
model = Animal
fields = ['name', 'genus', 'is_domesticated']
@property
def qs(self):
# The query context can be found in self.request.
return super(AnimalFilter, self).qs.filter(owner=self.request.user)

68
docs/form-mutations.rst Normal file
View File

@ -0,0 +1,68 @@
Integration with Django forms
=============================
Graphene-Django comes with mutation classes that will convert the fields on Django forms into inputs on a mutation.
*Note: the API is experimental and will likely change in the future.*
FormMutation
------------
.. code:: python
class MyForm(forms.Form):
name = forms.CharField()
class MyMutation(FormMutation):
class Meta:
form_class = MyForm
``MyMutation`` will automatically receive an ``input`` argument. This argument should be a ``dict`` where the key is ``name`` and the value is a string.
ModelFormMutation
-----------------
``ModelFormMutation`` will pull the fields from a ``ModelForm``.
.. code:: python
class Pet(models.Model):
name = models.CharField()
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = ('name',)
# This will get returned when the mutation completes successfully
class PetType(DjangoObjectType):
class Meta:
model = Pet
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
``PetMutation`` will grab the fields from ``PetForm`` and turn them into inputs. If the form is valid then the mutation
will lookup the ``DjangoObjectType`` for the ``Pet`` model and return that under the key ``pet``. Otherwise it will
return a list of errors.
You can change the input name (default is ``input``) and the return field name (default is the model name lowercase).
.. code:: python
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
input_field_name = 'data'
return_field_name = 'my_pet'
Form validation
---------------
Form mutations will call ``is_valid()`` on your forms.
If the form is valid then ``form_valid(form, info)`` is called on the mutation. Override this method to change how
the form is saved or to return a different Graphene object type.
If the form is *not* valid then a list of errors will be returned. These errors have two fields: ``field``, a string
containing the name of the invalid form field, and ``messages``, a list of strings with the validation messages.

View File

@ -12,4 +12,5 @@ Contents:
authorization
debug
rest-framework
form-mutations
introspection

View File

@ -1,3 +1,3 @@
sphinx
# Docs template
https://github.com/graphql-python/graphene-python.org/archive/docs.zip
http://graphene-python.org/sphinx_graphene_theme.zip

View File

@ -19,3 +19,46 @@ You can create a Mutation based on a serializer by using the
class Meta:
serializer_class = MySerializer
Create/Update Operations
---------------------
By default ModelSerializers accept create and update operations. To
customize this use the `model_operations` attribute. The update
operation looks up models by the primary key by default. You can
customize the look up with the lookup attribute.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['create', 'update']
lookup_field = 'id'
Overriding Update Queries
-------------------------
Use the method `get_serializer_kwargs` to override how
updates are applied.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
if 'id' in input:
instance = Post.objects.filter(id=input['id'], owner=info.context.user).first()
if instance:
return {'instance': instance, 'data': input, 'partial': True}
else:
raise http.Http404
return {'data': input, 'partial': True}

View File

@ -68,7 +68,8 @@ Let's get started with these models:
class Ingredient(models.Model):
name = models.CharField(max_length=100)
notes = models.TextField()
category = models.ForeignKey(Category, related_name='ingredients')
category = models.ForeignKey(
Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self):
return self.name
@ -80,9 +81,10 @@ Add ingredients as INSTALLED_APPS:
INSTALLED_APPS = [
...
# Install the ingredients app
'ingredients',
'cookbook.ingredients',
]
Don't forget to create & run migrations:
.. code:: bash
@ -111,6 +113,18 @@ Alternatively you can use the Django admin interface to create some data
yourself. You'll need to run the development server (see below), and
create a login for yourself too (``./manage.py createsuperuser``).
Register models with admin panel:
.. code:: python
# cookbook/ingredients/admin.py
from django.contrib import admin
from cookbook.ingredients.models import Category, Ingredient
admin.site.register(Category)
admin.site.register(Ingredient)
Hello GraphQL - Schema and Object Types
---------------------------------------
@ -153,7 +167,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
model = Ingredient
class Query(graphene.AbstractType):
class Query(object):
all_categories = graphene.List(CategoryType)
all_ingredients = graphene.List(IngredientType)
@ -165,9 +179,9 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
return Ingredient.objects.select_related('category').all()
Note that the above ``Query`` class is marked as 'abstract'. This is
because we will now create a project-level query which will combine all
our app-level queries.
Note that the above ``Query`` class is a mixin, inheriting from
``object``. This is because we will now create a project-level query
class which will combine all our app-level mixins.
Create the parent project-level ``cookbook/schema.py``:
@ -426,7 +440,7 @@ We can update our schema to support that, by adding new query for ``ingredient``
model = Ingredient
class Query(graphene.AbstractType):
class Query(object):
category = graphene.Field(CategoryType,
id=graphene.Int(),
name=graphene.String())
@ -445,8 +459,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return Ingredient.objects.all()
def resolve_category(self, info, **kwargs):
id = kargs.get('id')
name = kargs.get('name')
id = kwargs.get('id')
name = kwargs.get('name')
if id is not None:
return Category.objects.get(pk=id)
@ -457,8 +471,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return None
def resolve_ingredient(self, info, **kwargs):
id = kargs.get('id')
name = kargs.get('name')
id = kwargs.get('id')
name = kwargs.get('name')
if id is not None:
return Ingredient.objects.get(pk=id)

View File

@ -10,7 +10,7 @@ app <https://github.com/graphql-python/graphene-django/tree/master/examples/cook
A good idea is to check the following things first:
* `Graphene Relay documentation <http://docs.graphene-python.org/en/latest/relay/>`__
* `GraphQL Relay Specification <https://facebook.github.io/relay/docs/graphql-relay-specification.html>`__
* `GraphQL Relay Specification <https://facebook.github.io/relay/docs/en/graphql-server-specification.html>`__
Setup the Django project
------------------------
@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
.. code:: python
# cookbook/ingredients/schema.py
from graphene import relay, ObjectType, AbstractType
from graphene import relay, ObjectType
from graphene_django import DjangoObjectType
from graphene_django.filter import DjangoFilterConnectionField
@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, )
class Query(AbstractType):
class Query(object):
category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -3,7 +3,7 @@ Cookbook Example Django Project
This example project demos integration between Graphene and Django.
The project contains two apps, one named `ingredients` and another
named `recepies`.
named `recipes`.
Getting started
---------------
@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema)
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries)

View File

@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType):
model = Ingredient
class Query(graphene.AbstractType):
class Query(object):
category = graphene.Field(CategoryType,
id=graphene.Int(),
name=graphene.String())

View File

@ -0,0 +1,2 @@
# Create your tests here.

View File

@ -0,0 +1,2 @@
# Create your views here.

View File

@ -6,6 +6,7 @@ from cookbook.ingredients.models import Ingredient
class Recipe(models.Model):
title = models.CharField(max_length=100)
instructions = models.TextField()
__unicode__ = lambda self: self.title
class RecipeIngredient(models.Model):

View File

@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType):
model = RecipeIngredient
class Query(graphene.AbstractType):
class Query(object):
recipe = graphene.Field(RecipeType,
id=graphene.Int(),
title=graphene.String())

View File

@ -0,0 +1,2 @@
# Create your tests here.

View File

@ -0,0 +1,2 @@
# Create your views here.

View File

@ -1,4 +1,4 @@
graphene
graphene-django
graphql-core
graphql-core>=2.1rc1
django==1.9

View File

@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema)
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries)

View File

@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.ingredients.models import Category, Ingredient
@admin.register(Ingredient)
class IngredientAdmin(admin.ModelAdmin):
list_display = ("id","name","category")
list_editable = ("name","category")
list_display = ('id', 'name', 'category')
list_editable = ('name', 'category')
admin.site.register(Category)

View File

@ -10,7 +10,7 @@ class Category(models.Model):
class Ingredient(models.Model):
name = models.CharField(max_length=100)
notes = models.TextField(null=True,blank=True)
notes = models.TextField(null=True, blank=True)
category = models.ForeignKey(Category, related_name='ingredients')
def __str__(self):

View File

@ -1,5 +1,5 @@
from cookbook.ingredients.models import Category, Ingredient
from graphene import AbstractType, Node
from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType
@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType):
}
class Query(AbstractType):
class Query(object):
category = Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.recipes.models import Recipe, RecipeIngredient
class RecipeIngredientInline(admin.TabularInline):
model = RecipeIngredient
model = RecipeIngredient
@admin.register(Recipe)
class RecipeAdmin(admin.ModelAdmin):
inlines = [RecipeIngredientInline]
inlines = [RecipeIngredientInline]

View File

@ -8,6 +8,7 @@ class Recipe(models.Model):
instructions = models.TextField()
__unicode__ = lambda self: self.title
class RecipeIngredient(models.Model):
recipe = models.ForeignKey(Recipe, related_name='amounts')
ingredient = models.ForeignKey(Ingredient, related_name='used_by')

View File

@ -1,5 +1,5 @@
from cookbook.recipes.models import Recipe, RecipeIngredient
from graphene import AbstractType, Node
from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType
@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType):
}
class Query(AbstractType):
class Query(object):
recipe = Node.Field(RecipeNode)
all_recipes = DjangoFilterConnectionField(RecipeNode)

View File

@ -5,7 +5,9 @@ import graphene
from graphene_django.debug import DjangoDebug
class Query(cookbook.recipes.schema.Query, cookbook.ingredients.schema.Query, graphene.ObjectType):
class Query(cookbook.ingredients.schema.Query,
cookbook.recipes.schema.Query,
graphene.ObjectType):
debug = graphene.Field(DjangoDebug, name='__debug')

View File

@ -1,3 +1,4 @@
# flake8: noqa
"""
Django settings for cookbook project.

View File

@ -3,6 +3,7 @@ from django.contrib import admin
from graphene_django.views import GraphQLView
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^graphql', GraphQLView.as_view(graphiql=True)),

View File

@ -1,5 +1,5 @@
graphene
graphene-django
graphql-core
graphql-core>=2.1rc1
django==1.9
django-filter==0.11.0
django-filter>=2

View File

@ -0,0 +1,2 @@
[flake8]
exclude=migrations,.git,__pycache__

View File

@ -5,7 +5,7 @@ from django.db import models
class Character(models.Model):
name = models.CharField(max_length=50)
ship = models.ForeignKey('Ship', blank=True, null=True, related_name='characters')
ship = models.ForeignKey('Ship', on_delete=models.CASCADE, blank=True, null=True, related_name='characters')
def __str__(self):
return self.name
@ -13,7 +13,7 @@ class Character(models.Model):
class Faction(models.Model):
name = models.CharField(max_length=50)
hero = models.ForeignKey(Character)
hero = models.ForeignKey(Character, on_delete=models.CASCADE)
def __str__(self):
return self.name
@ -21,7 +21,7 @@ class Faction(models.Model):
class Ship(models.Model):
name = models.CharField(max_length=50)
faction = models.ForeignKey(Faction, related_name='ships')
faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name='ships')
def __str__(self):
return self.name

View File

@ -1,14 +1,6 @@
from .types import (
DjangoObjectType,
)
from .fields import (
DjangoConnectionField,
)
from .types import DjangoObjectType
from .fields import DjangoConnectionField
__version__ = '2.0.0'
__version__ = "2.2.0"
__all__ = [
'__version__',
'DjangoObjectType',
'DjangoConnectionField'
]
__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]

View File

@ -5,13 +5,7 @@ class MissingType(object):
try:
# Postgres fields are only available in Django with psycopg2 installed
# and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
from django.contrib.postgres.fields import (ArrayField, HStoreField,
JSONField, RangeField)
except ImportError:
ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
try:
# Postgres fields are only available in Django 1.9+
from django.contrib.postgres.fields import JSONField
except ImportError:
JSONField = MissingType
ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4

View File

@ -1,9 +1,22 @@
from django.db import models
from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String, UUID)
from graphene.types.datetime import DateTime, Time
from graphene import (
ID,
Boolean,
Dynamic,
Enum,
Field,
Float,
Int,
List,
NonNull,
String,
UUID,
DateTime,
Date,
Time,
)
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
@ -33,37 +46,44 @@ def get_choices(choices):
else:
name = convert_choice_name(value)
while name in converted_names:
name += '_' + str(len(converted_names))
name += "_" + str(len(converted_names))
converted_names.append(name)
description = help_text
yield name, value, description
def convert_django_field_with_choices(field, registry=None):
choices = getattr(field, 'choices', None)
if registry is not None:
converted = registry.get_converted_field(field)
if converted:
return converted
choices = getattr(field, "choices", None)
if choices:
meta = field.model._meta
name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object):
@property
def description(self):
return named_choices_descriptions[self.name]
enum = Enum(name, list(named_choices), type=EnumWithDescriptionsType)
return enum(description=field.help_text, required=not field.null)
return convert_django_field(field, registry)
converted = enum(description=field.help_text, required=not field.null)
else:
converted = convert_django_field(field, registry)
if registry is not None:
registry.register_converted_field(field, converted)
return converted
@singledispatch
def convert_django_field(field, registry=None):
raise Exception(
"Don't know how to convert the Django field %s (%s)" %
(field, field.__class__))
"Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
)
@convert_django_field.register(models.CharField)
@ -73,6 +93,7 @@ def convert_django_field(field, registry=None):
@convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField)
@convert_django_field.register(models.FilePathField)
def convert_field_to_string(field, registry=None):
return String(description=field.help_text, required=not field.null)
@ -113,9 +134,14 @@ def convert_field_to_float(field, registry=None):
return Float(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateTimeField)
def convert_datetime_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null)
return Date(description=field.help_text, required=not field.null)
@convert_django_field.register(models.TimeField)
@ -134,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance
null = getattr(field, 'null', True)
null = getattr(field, "null", True)
return Field(_type, required=not null)
return Dynamic(dynamic_type)
@ -158,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)

View File

@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]

View File

@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object):
def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None)
django_debug = getattr(context, "django_debug", None)
if not django_debug:
if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts')
raise Exception("DjangoDebug cannot be executed in None contexts")
try:
context.django_debug = DjangoDebugContext()
except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
context.__class__.__name__
))
if info.schema.get_type('DjangoDebug') == info.return_type:
raise Exception(
"DjangoDebug need the context to be writable, context received: {}.".format(
context.__class__.__name__
)
)
if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, info, **args)
context.django_debug.add_promise(promise)

View File

@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local):
def __init__(self):
self.enabled = True
@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel):
if not hasattr(connection, '_graphene_cursor'):
if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor
def cursor():
@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection):
if hasattr(connection, '_graphene_cursor'):
if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor
del connection._graphene_cursor
@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params:
return params
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.items())
return dict((key, self._quote_expr(value)) for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
return '(encoded string)'
return "(encoded string)"
def _record(self, method, sql, params):
start_time = time()
@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params)
finally:
stop_time = time()
duration = (stop_time - start_time)
_params = ''
duration = stop_time - start_time
_params = ""
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
alias = getattr(self.db, 'alias', 'default')
alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, 'vendor', 'unknown')
vendor = getattr(conn, "vendor", "unknown")
params = {
'vendor': vendor,
'alias': alias,
'sql': self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'start_time': start_time,
'stop_time': stop_time,
'is_slow': duration > 10,
'is_select': sql.lower().strip().startswith('select'),
"vendor": vendor,
"alias": alias,
"sql": self.db.ops.last_executed_query(
self.cursor, sql, self._quote_params(params)
),
"duration": duration,
"raw_sql": sql,
"params": _params,
"start_time": start_time,
"stop_time": stop_time,
"is_slow": duration > 10,
"is_select": sql.lower().strip().startswith("select"),
}
if vendor == 'postgresql':
if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = 'unknown'
params.update({
'trans_id': self.logger.get_transaction_id(alias),
'trans_status': conn.get_transaction_status(),
'iso_level': iso_level,
'encoding': conn.encoding,
})
iso_level = "unknown"
params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility

View File

@ -2,19 +2,53 @@ from graphene import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
vendor = String()
alias = String()
sql = String()
duration = Float()
raw_sql = String()
params = String()
start_time = Float()
stop_time = Float()
is_slow = Boolean()
is_select = Boolean()
class Meta:
description = (
"Represents a single database query made to a Django managed DB."
)
vendor = String(
required=True,
description=(
"The type of database being used (e.g. postrgesql, mysql, sqlite)."
),
)
alias = String(
required=True,
description="The Django database alias (e.g. 'default').",
)
sql = String(description="The actual SQL sent to this database.")
duration = Float(
required=True,
description="Duration of this database query in seconds.",
)
raw_sql = String(
required=True,
description="The raw SQL of this query, without params.",
)
params = String(
required=True,
description="JSON encoded database query parameters.",
)
start_time = Float(
required=True,
description="Start time of this database query.",
)
stop_time = Float(
required=True,
description="Stop time of this database query.",
)
is_slow = Boolean(
required=True,
description="Whether this database query took more than 10 seconds.",
)
is_select = Boolean(
required=True,
description="Whether this database query was a SELECT.",
)
# Postgres
trans_id = String()
trans_status = String()
iso_level = String()
encoding = String()
trans_id = String(description="Postgres transaction ID if available.")
trans_status = String(description="Postgres transaction status if available.")
iso_level = String(description="Postgres isolation level if available.")
encoding = String(description="Postgres connection encoding if available.")

View File

@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object):
pass
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_field():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
query = '''
query = """
query ReporterQuery {
reporter {
lastName
@ -47,43 +47,40 @@ def test_should_query_field():
}
}
}
'''
"""
expected = {
'reporter': {
'lastName': 'ABA',
"reporter": {"lastName": "ABA"},
"__debug": {
"sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
},
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
}]
}
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_list():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters {
lastName
@ -94,45 +91,38 @@ def test_should_query_list():
}
}
}
'''
"""
expected = {
'allReporters': [{
'lastName': 'ABA',
}, {
'lastName': 'Griffin',
}],
'__debug': {
'sql': [{
'rawSql': str(Reporter.objects.all().query)
}]
}
"allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
"__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data == expected
def test_should_query_connection():
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -147,48 +137,41 @@ def test_should_query_connection():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField
r1 = Reporter(last_name='ABA')
r1 = Reporter(last_name="ABA")
r1.save()
r2 = Reporter(last_name='Griffin')
r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug')
debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
query = '''
query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'lastName': 'ABA',
}
}]
},
}
"""
expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
result = schema.execute(
query, context_value=context(), middleware=[DjangoDebugMiddleware()]
)
assert not result.errors
assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
assert result.data["allReporters"] == expected["allReporters"]
assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
assert result.data['__debug']['sql'][1]['rawSql'] == query
assert result.data["__debug"]["sql"][1]["rawSql"] == query

View File

@ -4,4 +4,10 @@ from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType):
sql = List(DjangoDebugSQL)
class Meta:
description = "Debugging information for the current query."
sql = List(
DjangoDebugSQL,
description="Executed SQL queries for this API query.",
)

View File

@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
self.on = kwargs.pop('on', False)
self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
'max_limit',
graphene_settings.RELAY_CONNECTION_MAX_LIMIT
"max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
)
self.enforce_first_or_last = kwargs.pop(
'enforce_first_or_last',
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
"enforce_first_or_last",
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
assert issubclass(
_type, DjangoObjectType
), "DjangoConnectionField only accepts DjangoObjectType types"
assert _type._meta.connection, "The type {} doesn't have a connection".format(
_type.__name__
)
return _type._meta.connection
@property
@ -67,6 +69,10 @@ class DjangoConnectionField(ConnectionField):
@classmethod
def merge_querysets(cls, default_queryset, queryset):
if default_queryset.query.distinct and not queryset.query.distinct:
queryset = queryset.distinct()
elif queryset.query.distinct and not default_queryset.query.distinct:
default_queryset = default_queryset.distinct()
return queryset & default_queryset
@classmethod
@ -96,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, info, **args):
first = args.get('first')
last = args.get('last')
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
root,
info,
**args
):
first = args.get("first")
last = args.get("last")
if enforce_first_or_last:
assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
"You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
"Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit)
args["first"] = min(first, max_limit)
if last:
assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit)
"Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
).format(last, info.field_name, max_limit)
args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@ -134,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type,
self.get_manager(),
self.max_limit,
self.enforce_first_or_last
self.enforce_first_or_last,
)

View File

@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
"be installed. You can do so using `pip install django-filter`", ImportWarning
"be installed. You can do so using `pip install django-filter`",
ImportWarning,
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
__all__ = ['DjangoFilterConnectionField',
'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
__all__ = [
"DjangoFilterConnectionField",
"GlobalIDFilter",
"GlobalIDMultipleChoiceFilter",
]

View File

@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField):
def __init__(self, type, fields=None, order_by=None,
extra_filter_meta=None, filterset_class=None,
*args, **kwargs):
def __init__(
self,
type,
fields=None,
order_by=None,
extra_filter_meta=None,
filterset_class=None,
*args,
**kwargs
):
self._fields = fields
self._provided_filterset_class = filterset_class
self._filterset_class = None
@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self):
if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model,
fields=fields)
meta = dict(model=self.model, fields=fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
self._filterset_class = get_filterset_class(
self._provided_filterset_class, **meta
)
return self._filterset_class
@ -43,8 +50,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
@staticmethod
def merge_querysets(default_queryset, queryset):
@classmethod
def merge_querysets(cls, default_queryset, queryset):
# There could be the case where the default queryset (returned from the filterclass)
# and the resolver queryset have some limits on it.
# We only would be able to apply one of those, but not both
@ -52,27 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
'Received two sliced querysets (low mark) in the connection, please slice only in one.'
)
assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
)
assert not (
default_queryset.query.low_mark and queryset.query.low_mark
), "Received two sliced querysets (low mark) in the connection, please slice only in one."
assert not (
default_queryset.query.high_mark and queryset.query.high_mark
), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits()
queryset = default_queryset & queryset
queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
default_queryset, queryset
)
queryset.query.set_limits(low, high)
return queryset
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args,
root, info, **args):
def connection_resolver(
cls,
resolver,
connection,
default_manager,
max_limit,
enforce_first_or_last,
filterset_class,
filtering_args,
root,
info,
**args
):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
data=filter_kwargs,
queryset=default_manager.get_queryset()
queryset=default_manager.get_queryset(),
request=info.context,
).qs
return super(DjangoFilterConnectionField, cls).connection_resolver(
@ -95,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit,
self.enforce_first_or_last,
self.filterset_class,
self.filtering_args
self.filtering_args,
)

View File

@ -1,8 +1,7 @@
import itertools
from django.db import models
from django.utils.text import capfirst
from django_filters import Filter, MultipleChoiceFilter
from django_filters import Filter, MultipleChoiceFilter, VERSION
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
@ -15,7 +14,10 @@ class GlobalIDFilter(Filter):
field_class = GlobalIDFormField
def filter(self, qs, value):
_type, _id = from_global_id(value)
""" Convert the filter value to a primary key before filtering """
_id = None
if value is not None:
_, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
@ -28,71 +30,76 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: {
'filter_class': GlobalIDFilter,
},
models.OneToOneField: {
'filter_class': GlobalIDFilter,
},
models.ForeignKey: {
'filter_class': GlobalIDFilter,
},
models.ManyToManyField: {
'filter_class': GlobalIDMultipleChoiceFilter,
}
models.AutoField: {"filter_class": GlobalIDFilter},
models.OneToOneField: {"filter_class": GlobalIDFilter},
models.ForeignKey: {"filter_class": GlobalIDFilter},
models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter},
models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter},
}
class GrapheneFilterSetMixin(BaseFilterSet):
FILTER_DEFAULTS = dict(itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(),
GRAPHENE_FILTER_SET_OVERRIDES.items()
))
""" A django_filters.filterset.BaseFilterSet with default filter overrides
to handle global IDs """
@classmethod
def filter_for_reverse_field(cls, f, name):
"""Handles retrieving filters for reverse relationships
FILTER_DEFAULTS = dict(
itertools.chain(
FILTER_FOR_DBFIELD_DEFAULTS.items(),
GRAPHENE_FILTER_SET_OVERRIDES.items()
)
)
We override the default implementation so that we can handle
Global IDs (the default implementation expects database
primary keys)
"""
rel = f.field.rel
default = {
'name': name,
'label': capfirst(rel.related_name)
}
if rel.multiple:
# For to-many relationships
return GlobalIDMultipleChoiceFilter(**default)
else:
# For to-one relationships
return GlobalIDFilter(**default)
# To support a Django 1.11 + Python 2.7 combination django-filter must be
# < 2.x.x. To support the earlier version of django-filter, the
# filter_for_reverse_field method must be present on GrapheneFilterSetMixin and
# must not be present for later versions of django-filter.
if VERSION[0] < 2:
from django.utils.text import capfirst
class GrapheneFilterSetMixinPython2(GrapheneFilterSetMixin):
@classmethod
def filter_for_reverse_field(cls, f, name):
"""Handles retrieving filters for reverse relationships
We override the default implementation so that we can handle
Global IDs (the default implementation expects database
primary keys)
"""
try:
rel = f.field.remote_field
except AttributeError:
rel = f.field.rel
default = {"name": name, "label": capfirst(rel.related_name)}
if rel.multiple:
# For to-many relationships
return GlobalIDMultipleChoiceFilter(**default)
else:
# For to-one relationships
return GlobalIDFilter(**default)
GrapheneFilterSetMixin = GrapheneFilterSetMixinPython2
def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
return type(
'Graphene{}'.format(filterset_class.__name__),
"Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
{},
)
def custom_filterset_factory(model, filterset_base_class=FilterSet,
**meta):
def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
meta.update({
'model': model,
})
meta_class = type(str('Meta'), (object,), meta)
meta.update({"model": model})
meta_class = type(str("Meta"), (object,), meta)
filterset = type(
str('%sFilterSet' % model._meta.object_name),
str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin),
{
'Meta': meta_class
}
{"Meta": meta_class},
)
return filterset

View File

@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = {
'headline': ['exact', 'icontains'],
'pub_date': ['gt', 'lt', 'exact'],
'reporter': ['exact'],
"headline": ["exact", "icontains"],
"pub_date": ["gt", "lt", "exact"],
"reporter": ["exact"],
}
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['first_name', 'last_name', 'email', 'pets']
fields = ["first_name", "last_name", "email", "pets"]
order_by = OrderingFilter(fields=('pub_date',))
order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet):
class Meta:
model = Pet
fields = ['name']
fields = ["name"]

View File

@ -2,50 +2,60 @@ from datetime import datetime
import pytest
from graphene import Field, ObjectType, Schema, Argument, Float
from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField,
GlobalIDMultipleChoiceField)
from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
# for annotation test
from django.db.models import TextField, Value
from django.db.models.functions import Concat
pytestmark = []
if DJANGO_FILTER_INSTALLED:
import django_filters
from django_filters import FilterSet, NumberFilter
from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter)
from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
from graphene_django.filter import (
GlobalIDFilter,
DjangoFilterConnectionField,
GlobalIDMultipleChoiceFilter,
)
from graphene_django.filter.tests.filters import (
ArticleFilter,
PetFilter,
ReporterFilter,
)
else:
pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
pytestmark.append(
pytest.mark.skipif(
True, reason="django_filters not installed or not compatible"
)
)
pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED:
class ArticleNode(DjangoObjectType):
class ArticleNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ('headline', )
interfaces = (Node,)
filter_fields = ("headline",)
class ReporterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class PetNode(DjangoObjectType):
class Meta:
model = Pet
interfaces = (Node, )
interfaces = (Node,)
# schema = Schema()
@ -55,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments):
ignore = ('after', 'before', 'first', 'last', 'order_by')
ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field)
actual = [
name
for name in args
if name not in ignore and not name.startswith('_')
]
assert set(arguments) == set(actual), \
'Expected arguments ({}) did not match actual ({})'.format(
arguments,
actual
)
actual = [name for name in args if name not in ignore and not name.startswith("_")]
assert set(arguments) == set(
actual
), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
def assert_orderable(field):
args = get_args(field)
assert 'order_by' in args, \
'Field cannot be ordered'
assert "order_by" in args, "Field cannot be ordered"
def assert_not_orderable(field):
args = get_args(field)
assert 'order_by' not in args, \
'Field can be ordered'
assert "order_by" not in args, "Field can be ordered"
def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
assert_arguments(field,
'headline', 'headline__icontains',
'pub_date', 'pub_date__gt', 'pub_date__lt',
'reporter',
)
assert_arguments(
field,
"headline",
"headline__icontains",
"pub_date",
"pub_date__gt",
"pub_date__lt",
"reporter",
)
def test_filter_shortcut_filterset_arguments_list():
field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
assert_arguments(field,
'pub_date',
'reporter',
)
field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
assert_arguments(field, "pub_date", "reporter")
def test_filter_shortcut_filterset_arguments_dict():
field = DjangoFilterConnectionField(ArticleNode, fields={
'headline': ['exact', 'icontains'],
'reporter': ['exact'],
})
assert_arguments(field,
'headline', 'headline__icontains',
'reporter',
)
field = DjangoFilterConnectionField(
ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
)
assert_arguments(field, "headline", "headline__icontains", "reporter")
def test_filter_explicit_filterset_orderable():
@ -130,39 +129,91 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta():
field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
'exclude': ('headline', )
})
assert 'headline' not in field.filterset_class.get_fields()
field = DjangoFilterConnectionField(
ArticleNode, extra_filter_meta={"exclude": ("headline",)}
)
assert "headline" not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet):
class Meta:
model = Article
exclude = set()
@property
def qs(self):
qs = super(ArticleContextFilter, self).qs
return qs.filter(reporter=self.request.reporter)
class Query(ObjectType):
context_articles = DjangoFilterConnectionField(
ArticleNode, filterset_class=ArticleContextFilter
)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
editor=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
editor=r2,
)
class context(object):
reporter = r2
query = """
query {
contextArticles {
edges {
node {
headline
}
}
}
}
"""
schema = Schema(query=Query)
result = schema.execute(query, context_value=context())
assert not result.errors
assert len(result.data["contextArticles"]["edges"]) == 1
assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode)
assert_arguments(field, 'first_name', 'articles')
assert_arguments(field, "first_name", "articles")
assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ['headline', 'reporter']
interfaces = (Node,)
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -171,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode)
schema = Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
assert_arguments(articles_field, 'headline', 'reporter')
articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field)
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ['first_name', 'articles']
interfaces = (Node,)
filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
class Meta:
interfaces = (Node, )
interfaces = (Node,)
model = Article
filter_fields = ['headline', 'reporter']
filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@ -197,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2)
r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
Article.objects.create(
headline="a1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r1,
)
Article.objects.create(
headline="a2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r2,
)
query = '''
query = """
query {
allReporters {
edges {
@ -218,123 +277,134 @@ def test_filter_filterset_related_results():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get back a single article for each reporter
assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
assert (
len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
)
assert (
len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
)
def test_global_id_field_implicit():
field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = ['id']
fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['id']
id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_filterset_descriptions():
class ArticleIdFilter(django_filters.FilterSet):
class Meta:
model = Article
fields = ['id']
fields = ["id"]
max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time")
max_time = django_filters.NumberFilter(
method="filter_max_time", label="The maximum time"
)
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
max_time = field.args['max_time']
max_time = field.args["max_time"]
assert isinstance(max_time, Argument)
assert max_time.type == Float
assert max_time.description == 'The maximum time'
assert max_time.description == "The maximum time"
def test_global_id_field_relation():
field = DjangoFilterConnectionField(ArticleNode, fields=['reporter'])
field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
filterset_class = field.filterset_class
id_filter = filterset_class.base_filters['reporter']
id_filter = filterset_class.base_filters["reporter"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_multiple_field_implicit():
field = DjangoFilterConnectionField(ReporterNode, fields=['pets'])
field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets']
multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['pets']
fields = ["pets"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['pets']
multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_implicit_reverse():
field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles']
multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet):
class Meta:
model = Reporter
fields = ['articles']
fields = ["articles"]
field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
field = DjangoFilterConnectionField(
ReporterNode, filterset_class=ReporterPetsFilter
)
filterset_class = field.filterset_class
multiple_filter = filterset_class.base_filters['articles']
multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = {
'first_name': ['icontains']
}
interfaces = (Node,)
filter_fields = {"first_name": ["icontains"]}
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com')
r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com')
r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com')
r1 = Reporter.objects.create(
first_name="A test user", last_name="Last Name", email="test1@test.com"
)
r2 = Reporter.objects.create(
first_name="Other test user",
last_name="Other Last Name",
email="test2@test.com",
)
r3 = Reporter.objects.create(
first_name="Random", last_name="RandomLast", email="random@test.com"
)
query = '''
query = """
query {
allReporters(firstName_Icontains: "test") {
edges {
@ -344,12 +414,12 @@ def test_filter_filterset_related_results():
}
}
}
'''
"""
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get two reporters
assert len(result.data['allReporters']['edges']) == 2
assert len(result.data["allReporters"]["edges"]) == 2
def test_recursive_filter_connection():
@ -361,77 +431,73 @@ def test_recursive_filter_connection():
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode
assert (
ReporterFilterNode._meta.fields["child_reporters"].node_type
== ReporterFilterNode
)
def test_should_query_filter_node_limit():
class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit')
limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
fields = ['first_name', ]
fields = ["first_name"]
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
filter_fields = ('lang', )
interfaces = (Node,)
filter_fields = ("lang",)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType,
filterset_class=ReporterFilter
ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')
return Reporter.objects.order_by("a_choice")
Reporter.objects.create(
first_name='Bob',
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
Article.objects.create(
headline='Article Node 1',
headline="Article Node 1",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r,
editor=r,
lang='es'
lang="es",
)
Article.objects.create(
headline='Article Node 2',
headline="Article Node 2",
pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r,
editor=r,
lang='en'
lang="en",
)
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@ -450,24 +516,23 @@ def test_should_query_filter_node_limit():
}
}
}
'''
"""
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjI=',
'firstName': 'John',
'articles': {
'edges': [{
'node': {
'id': 'QXJ0aWNsZVR5cGU6MQ==',
'lang': 'ES'
}
}]
"allReporters": {
"edges": [
{
"node": {
"id": "UmVwb3J0ZXJUeXBlOjI=",
"firstName": "John",
"articles": {
"edges": [
{"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
]
},
}
}
}]
]
}
}
@ -478,45 +543,37 @@ def test_should_query_filter_node_limit():
def test_should_query_filter_node_double_limit_raises():
class ReporterFilter(FilterSet):
limit = NumberFilter(method='filter_limit')
limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
fields = ['first_name', ]
fields = ["first_name"]
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType,
filterset_class=ReporterFilter
ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2]
return Reporter.objects.order_by("a_choice")[:2]
Reporter.objects.create(
first_name='Bob',
last_name='Doe',
email='bobdoe@example.com',
a_choice=2
first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
schema = Schema(query=Query)
query = '''
query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@ -527,10 +584,116 @@ def test_should_query_filter_node_double_limit_raises():
}
}
}
'''
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert str(result.errors[0]) == (
'Received two sliced querysets (high mark) in the connection, please slice only in one.'
"Received two sliced querysets (high mark) in the connection, please slice only in one."
)
def test_order_by_is_perserved():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
filter_fields = ()
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
ReporterType, reverse_order=Boolean()
)
def resolve_all_reporters(self, info, reverse_order=False, **args):
reporters = Reporter.objects.order_by("first_name")
if reverse_order:
return reporters.reverse()
return reporters
Reporter.objects.create(first_name="b")
r = Reporter.objects.create(first_name="a")
schema = Schema(query=Query)
query = """
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
node {
firstName
}
}
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
reverse_query = """
query NodeFilteringQuery {
allReporters(first: 1, reverseOrder: true) {
edges {
node {
firstName
}
}
}
}
"""
reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
reverse_result = schema.execute(reverse_query)
assert not reverse_result.errors
assert reverse_result.data == reverse_expected
def test_annotation_is_perserved():
class ReporterType(DjangoObjectType):
full_name = String()
def resolve_full_name(instance, info, **args):
return instance.full_name
class Meta:
model = Reporter
interfaces = (Node,)
filter_fields = ()
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(
full_name=Concat(
"first_name", Value(" "), "last_name", output_field=TextField()
)
)
Reporter.objects.create(first_name="John", last_name="Doe")
schema = Schema(query=Query)
query = """
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
node {
fullName
}
}
}
}
"""
expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -8,7 +8,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
a Graphene Field. These arguments will be available to
filter against in the GraphQL
"""
from ..form_converter import convert_form_field
from ..forms.converter import convert_form_field
args = {}
for name, filter_field in six.iteritems(filterset_class.base_filters):

View File

@ -0,0 +1 @@
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa

View File

@ -1,30 +1,24 @@
from django import forms
from django.forms.fields import BaseTemporalField
from django.core.exceptions import ImproperlyConfigured
from graphene import ID, Boolean, Float, Int, List, String, UUID
from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from .utils import import_single_dispatch
from ..utils import import_single_dispatch
singledispatch = import_single_dispatch()
try:
UUIDField = forms.UUIDField
except AttributeError:
class UUIDField(object):
pass
@singledispatch
def convert_form_field(field):
raise Exception(
raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) "
"to Graphene type" %
(field, field.__class__)
"to Graphene type" % (field, field.__class__)
)
@convert_form_field.register(BaseTemporalField)
@convert_form_field.register(forms.fields.BaseTemporalField)
@convert_form_field.register(forms.CharField)
@convert_form_field.register(forms.EmailField)
@convert_form_field.register(forms.SlugField)
@ -36,7 +30,7 @@ def convert_form_field_to_string(field):
return String(description=field.help_text, required=field.required)
@convert_form_field.register(UUIDField)
@convert_form_field.register(forms.UUIDField)
def convert_form_field_to_uuid(field):
return UUID(description=field.help_text, required=field.required)
@ -69,6 +63,21 @@ def convert_form_field_to_list(field):
return List(ID, required=field.required)
@convert_form_field.register(forms.DateField)
def convert_form_field_to_date(field):
return Date(description=field.help_text, required=field.required)
@convert_form_field.register(forms.DateTimeField)
def convert_form_field_to_datetime(field):
return DateTime(description=field.help_text, required=field.required)
@convert_form_field.register(forms.TimeField)
def convert_form_field_to_time(field):
return Time(description=field.help_text, required=field.required)
@convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_id(field):

View File

@ -8,9 +8,7 @@ from graphql_relay import from_global_id
class GlobalIDFormField(Field):
default_error_messages = {
'invalid': _('Invalid ID specified.'),
}
default_error_messages = {"invalid": _("Invalid ID specified.")}
def clean(self, value):
if not value and not self.required:
@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
try:
_type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
raise ValidationError(self.error_messages['invalid'])
raise ValidationError(self.error_messages["invalid"])
try:
CharField().clean(_id)
CharField().clean(_type)
except ValidationError:
raise ValidationError(self.error_messages['invalid'])
raise ValidationError(self.error_messages["invalid"])
return value
class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = {
'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'),
'invalid_list': _('Enter a list of values.'),
"invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
"invalid_list": _("Enter a list of values."),
}
def valid_value(self, value):

View File

@ -0,0 +1,192 @@
# from django import forms
from collections import OrderedDict
import graphene
from graphene import Field, InputField
from graphene.relay.mutation import ClientIDMutation
from graphene.types.mutation import MutationOptions
# from graphene.types.inputobjecttype import (
# InputObjectTypeOptions,
# InputObjectType,
# )
from graphene.types.utils import yank_fields_from_attrs
from graphene_django.registry import get_global_registry
from .converter import convert_form_field
from .types import ErrorType
def fields_for_form(form, only_fields, exclude_fields):
fields = OrderedDict()
for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name
in exclude_fields # or
# name in already_created_fields
)
if is_not_in_only or is_excluded:
continue
fields[name] = convert_form_field(field)
return fields
class BaseDjangoFormMutation(ClientIDMutation):
class Meta:
abstract = True
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
form = cls.get_form(root, info, **input)
if form.is_valid():
return cls.perform_mutate(form, info)
else:
errors = [
ErrorType(field=key, messages=value)
for key, value in form.errors.items()
]
return cls(errors=errors)
@classmethod
def get_form(cls, root, info, **input):
form_kwargs = cls.get_form_kwargs(root, info, **input)
return cls._meta.form_class(**form_kwargs)
@classmethod
def get_form_kwargs(cls, root, info, **input):
kwargs = {"data": input}
pk = input.pop("id", None)
if pk:
instance = cls._meta.model._default_manager.get(pk=pk)
kwargs["instance"] = instance
return kwargs
# class DjangoFormInputObjectTypeOptions(InputObjectTypeOptions):
# form_class = None
# class DjangoFormInputObjectType(InputObjectType):
# class Meta:
# abstract = True
# @classmethod
# def __init_subclass_with_meta__(cls, form_class=None,
# only_fields=(), exclude_fields=(), _meta=None, **options):
# if not _meta:
# _meta = DjangoFormInputObjectTypeOptions(cls)
# assert isinstance(form_class, forms.Form), (
# 'form_class must be an instance of django.forms.Form'
# )
# _meta.form_class = form_class
# form = form_class()
# fields = fields_for_form(form, only_fields, exclude_fields)
# super(DjangoFormInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, fields=fields, **options)
class DjangoFormMutationOptions(MutationOptions):
form_class = None
class DjangoFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls, form_class=None, only_fields=(), exclude_fields=(), **options
):
if not form_class:
raise Exception("form_class is required for DjangoFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
output_fields = fields_for_form(form, only_fields, exclude_fields)
_meta = DjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def perform_mutate(cls, form, info):
form.save()
return cls(errors=[])
class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions):
model = None
return_field_name = None
class DjangoModelFormMutation(BaseDjangoFormMutation):
class Meta:
abstract = True
errors = graphene.List(ErrorType)
@classmethod
def __init_subclass_with_meta__(
cls,
form_class=None,
model=None,
return_field_name=None,
only_fields=(),
exclude_fields=(),
**options
):
if not form_class:
raise Exception("form_class is required for DjangoModelFormMutation")
if not model:
model = form_class._meta.model
if not model:
raise Exception("model is required for DjangoModelFormMutation")
form = form_class()
input_fields = fields_for_form(form, only_fields, exclude_fields)
if "id" not in exclude_fields:
input_fields["id"] = graphene.ID()
registry = get_global_registry()
model_type = registry.get_type_for_model(model)
return_field_name = return_field_name
if not return_field_name:
model_name = model.__name__
return_field_name = model_name[:1].lower() + model_name[1:]
output_fields = OrderedDict()
output_fields[return_field_name] = graphene.Field(model_type)
_meta = DjangoModelDjangoFormMutationOptions(cls)
_meta.form_class = form_class
_meta.model = model
_meta.return_field_name = return_field_name
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
@classmethod
def perform_mutate(cls, form, info):
obj = form.save()
kwargs = {cls._meta.return_field_name: obj}
return cls(errors=[], **kwargs)

View File

View File

@ -0,0 +1,114 @@
from django import forms
from py.test import raises
import graphene
from graphene import (
String,
Int,
Boolean,
Float,
ID,
UUID,
List,
NonNull,
DateTime,
Date,
Time,
)
from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text="Custom Help Text")
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == "Custom Help Text"
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert "Don't know how to convert the Django form field" in str(excinfo.value)
def test_should_date_convert_date():
assert_conversion(forms.DateField, Date)
def test_should_time_convert_time():
assert_conversion(forms.TimeField, Time)
def test_should_date_time_convert_date_time():
assert_conversion(forms.DateTimeField, DateTime)
def test_should_char_convert_string():
assert_conversion(forms.CharField, String)
def test_should_email_convert_string():
assert_conversion(forms.EmailField, String)
def test_should_slug_convert_string():
assert_conversion(forms.SlugField, String)
def test_should_url_convert_string():
assert_conversion(forms.URLField, String)
def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, String)
def test_should_base_field_convert_string():
assert_conversion(forms.Field, String)
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, String, "[0-9]+")
def test_should_uuid_convert_string():
if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID)
def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, Int)
def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, Boolean)
assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, Boolean)
assert not isinstance(field.type, NonNull)
def test_should_float_convert_float():
assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, Float)
def test_should_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == ID
def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(queryset=None)
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, ID)

View File

@ -0,0 +1,141 @@
from django import forms
from django.test import TestCase
from py.test import raises
from graphene_django.tests.models import Pet, Film, FilmDetails
from ..mutation import DjangoFormMutation, DjangoModelFormMutation
class MyForm(forms.Form):
text = forms.CharField()
class PetForm(forms.ModelForm):
class Meta:
model = Pet
fields = '__all__'
def test_needs_form_class():
with raises(Exception) as exc:
class MyMutation(DjangoFormMutation):
pass
assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
def test_has_output_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
class MyMutation(DjangoFormMutation):
class Meta:
form_class = MyForm
assert "text" in MyMutation.Input._meta.fields
class ModelFormMutationTests(TestCase):
def test_default_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("pet", PetMutation._meta.fields)
def test_default_input_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("name", PetMutation.Input._meta.fields)
self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
self.assertIn("id", PetMutation.Input._meta.fields)
def test_exclude_fields_input_meta_fields(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
exclude_fields = ['id']
self.assertEqual(PetMutation._meta.model, Pet)
self.assertEqual(PetMutation._meta.return_field_name, "pet")
self.assertIn("name", PetMutation.Input._meta.fields)
self.assertIn("age", PetMutation.Input._meta.fields)
self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
self.assertNotIn("id", PetMutation.Input._meta.fields)
def test_return_field_name_is_camelcased(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
model = FilmDetails
self.assertEqual(PetMutation._meta.model, FilmDetails)
self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
def test_custom_return_field_name(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
model = Film
return_field_name = "animal"
self.assertEqual(PetMutation._meta.model, Film)
self.assertEqual(PetMutation._meta.return_field_name, "animal")
self.assertIn("animal", PetMutation._meta.fields)
def test_model_form_mutation_mutate(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
pet = Pet.objects.create(name="Axel", age=10)
result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia", age=10)
self.assertEqual(Pet.objects.count(), 1)
pet.refresh_from_db()
self.assertEqual(pet.name, "Mia")
self.assertEqual(result.errors, [])
def test_model_form_mutation_updates_existing_(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None, name="Mia", age=10)
self.assertEqual(Pet.objects.count(), 1)
pet = Pet.objects.get()
self.assertEqual(pet.name, "Mia")
self.assertEqual(pet.age, 10)
self.assertEqual(result.errors, [])
def test_model_form_mutation_mutate_invalid_form(self):
class PetMutation(DjangoModelFormMutation):
class Meta:
form_class = PetForm
result = PetMutation.mutate_and_get_payload(None, None)
# A pet was not created
self.assertEqual(Pet.objects.count(), 0)
fields_w_error = [e.field for e in result.errors]
self.assertEqual(len(result.errors), 2)
self.assertIn("name", fields_w_error)
self.assertEqual(result.errors[0].messages, ["This field is required."])
self.assertIn("age", fields_w_error)
self.assertEqual(result.errors[1].messages, ["This field is required."])

View File

@ -0,0 +1,6 @@
import graphene
class ErrorType(graphene.ObjectType):
field = graphene.String()
messages = graphene.List(graphene.String)

View File

@ -1,79 +1,51 @@
import importlib
import json
from distutils.version import StrictVersion
from optparse import make_option
from django import get_version as get_django_version
from django.core.management.base import BaseCommand, CommandError
from graphene_django.settings import graphene_settings
LT_DJANGO_1_8 = StrictVersion(get_django_version()) < StrictVersion('1.8')
if LT_DJANGO_1_8:
class CommandArguments(BaseCommand):
option_list = BaseCommand.option_list + (
make_option(
'--schema',
type=str,
dest='schema',
default='',
help='Django app containing schema to dump, e.g. myproject.core.schema.schema',
),
make_option(
'--out',
type=str,
dest='out',
default='',
help='Output file (default: schema.json)'
),
make_option(
'--indent',
type=int,
dest='indent',
default=None,
help='Output file indent (default: None)'
),
class CommandArguments(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
"--schema",
type=str,
dest="schema",
default=graphene_settings.SCHEMA,
help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
else:
class CommandArguments(BaseCommand):
def add_arguments(self, parser):
parser.add_argument(
'--schema',
type=str,
dest='schema',
default=graphene_settings.SCHEMA,
help='Django app containing schema to dump, e.g. myproject.core.schema.schema')
parser.add_argument(
"--out",
type=str,
dest="out",
default=graphene_settings.SCHEMA_OUTPUT,
help="Output file, --out=- prints to stdout (default: schema.json)",
)
parser.add_argument(
'--out',
type=str,
dest='out',
default=graphene_settings.SCHEMA_OUTPUT,
help='Output file (default: schema.json)')
parser.add_argument(
'--indent',
type=int,
dest='indent',
default=graphene_settings.SCHEMA_INDENT,
help='Output file indent (default: None)')
parser.add_argument(
"--indent",
type=int,
dest="indent",
default=graphene_settings.SCHEMA_INDENT,
help="Output file indent (default: None)",
)
class Command(CommandArguments):
help = 'Dump Graphene schema JSON to file'
help = "Dump Graphene schema JSON to file"
can_import_settings = True
def save_file(self, out, schema_dict, indent):
with open(out, 'w') as outfile:
with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent)
def handle(self, *args, **options):
options_schema = options.get('schema')
options_schema = options.get("schema")
if options_schema and type(options_schema) is str:
module_str, schema_name = options_schema.rsplit('.', 1)
module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name)
@ -83,16 +55,21 @@ class Command(CommandArguments):
else:
schema = graphene_settings.SCHEMA
out = options.get('out') or graphene_settings.SCHEMA_OUTPUT
out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema:
raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema')
raise CommandError(
"Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
)
indent = options.get('indent')
schema_dict = {'data': schema.introspect()}
self.save_file(out, schema_dict, indent)
indent = options.get("indent")
schema_dict = {"data": schema.introspect()}
if out == '-':
self.stdout.write(json.dumps(schema_dict, indent=indent))
else:
self.save_file(out, schema_dict, indent)
style = getattr(self, 'style', None)
success = getattr(style, 'SUCCESS', lambda x: x)
style = getattr(self, "style", None)
success = getattr(style, "SUCCESS", lambda x: x)
self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out))
self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))

View File

@ -1,25 +1,32 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._registry_models = {}
self._field_registry = {}
def register(self, cls):
from .types import DjangoObjectType
assert issubclass(
cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.'
cls, DjangoObjectType
), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# )
if not getattr(cls._meta, 'skip_registry', False):
if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):
return self._registry.get(model)
def register_converted_field(self, field, converted):
self._field_registry[field] = converted
def get_converted_field(self, field):
return self._field_registry.get(field)
registry = None

View File

@ -0,0 +1,6 @@
from django.db import models
class MyFakeModel(models.Model):
cool_name = models.CharField(max_length=50)
created = models.DateTimeField(auto_now_add=True)

View File

@ -1,20 +1,21 @@
from collections import OrderedDict
from django.shortcuts import get_object_or_404
import graphene
from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation
from graphene.types.objecttype import (
yank_fields_from_attrs
)
from graphene.types.objecttype import yank_fields_from_attrs
from .serializer_converter import (
convert_serializer_field
)
from .serializer_converter import convert_serializer_field
from .types import ErrorType
class SerializerMutationOptions(MutationOptions):
lookup_field = None
model_class = None
model_operations = ["create", "update"]
serializer_class = None
@ -23,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
name in exclude_fields # or
name
in exclude_fields # or
# name in already_created_fields
)
@ -39,37 +41,86 @@ class SerializerMutation(ClientIDMutation):
abstract = True
errors = graphene.List(
ErrorType,
description='May contain more than one error for same field.'
ErrorType, description="May contain more than one error for same field."
)
@classmethod
def __init_subclass_with_meta__(cls, serializer_class=None,
only_fields=(), exclude_fields=(), **options):
def __init_subclass_with_meta__(
cls,
lookup_field=None,
serializer_class=None,
model_class=None,
model_operations=["create", "update"],
only_fields=(),
exclude_fields=(),
**options
):
if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation')
raise Exception("serializer_class is required for the SerializerMutation")
if "update" not in model_operations and "create" not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class()
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
if model_class is None:
serializer_meta = getattr(serializer_class, "Meta", None)
if serializer_meta:
model_class = getattr(serializer_meta, "model", None)
if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=True
)
output_fields = fields_for_serializer(
serializer, only_fields, exclude_fields, is_input=False
)
_meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field
_meta.model_operations = model_operations
_meta.serializer_class = serializer_class
_meta.fields = yank_fields_from_attrs(
output_fields,
_as=Field,
_meta.model_class = model_class
_meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
super(SerializerMutation, cls).__init_subclass_with_meta__(
_meta=_meta, input_fields=input_fields, **options
)
input_fields = yank_fields_from_attrs(
input_fields,
_as=InputField,
)
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
lookup_field = cls._meta.lookup_field
model_class = cls._meta.model_class
if model_class:
if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(
model_class, **{lookup_field: input[lookup_field]}
)
elif "create" in cls._meta.model_operations:
instance = None
else:
raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field
)
)
return {
"instance": instance,
"data": input,
"context": {"request": info.context},
}
return {"data": input, "context": {"request": info.context}}
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
serializer = cls._meta.serializer_class(data=input)
kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
if serializer.is_valid():
return cls.perform_mutate(serializer, info)
@ -84,4 +135,9 @@ class SerializerMutation(ClientIDMutation):
@classmethod
def perform_mutate(cls, serializer, info):
obj = serializer.save()
return cls(errors=None, **obj)
kwargs = {}
for f, field in serializer.fields.items():
kwargs[f] = field.get_attribute(obj)
return cls(errors=None, **kwargs)

View File

@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
graphql_type = get_graphene_type_from_serializer_field(field)
args = []
kwargs = {
'description': field.help_text,
'required': is_input and field.required,
}
kwargs = {"description": field.help_text, "required": is_input and field.required}
# if it is a tuple or a list it means that we are returning
# the graphql type and the child type
if isinstance(graphql_type, (list, tuple)):
kwargs['of_type'] = graphql_type[1]
kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer):
@ -46,6 +43,15 @@ def convert_serializer_field(field, is_input=True):
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
elif isinstance(field, serializers.ListSerializer):
field = field.child
if is_input:
kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
else:
del kwargs["of_type"]
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
return graphql_type(*args, **kwargs)
@ -59,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
}
return type(
'{}Input'.format(serializer.__class__.__name__),
"{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,),
items
items,
)
@ -75,6 +81,12 @@ def convert_serializer_to_field(field):
return graphene.Field
@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
def convert_list_serializer_to_field(field):
child_type = get_graphene_type_from_serializer_field(field.child)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.IntegerField)
def convert_serializer_field_to_int(field):
return graphene.Int
@ -92,9 +104,13 @@ def convert_serializer_field_to_float(field):
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime
@get_graphene_type_from_serializer_field.register(serializers.DateField)
def convert_serializer_field_to_date_time(field):
return graphene.types.datetime.DateTime
return graphene.types.datetime.Date
@get_graphene_type_from_serializer_field.register(serializers.TimeField)

View File

@ -1,8 +1,10 @@
import copy
from rest_framework import serializers
from py.test import raises
import graphene
from django.db import models
from graphene import InputObjectType
from py.test import raises
from rest_framework import serializers
from ..serializer_converter import convert_serializer_field
from ..types import DictType
@ -14,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
# Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute
if 'child' in kwargs:
kwargs['child'] = copy.deepcopy(kwargs['child'])
if "child" in kwargs:
kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs)
@ -23,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs)
graphene_type = _get_type(
rest_framework_field, help_text="Custom Help Text", **kwargs
)
assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type(
rest_framework_field, help_text='Custom Help Text', required=True, **kwargs
rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
)
assert isinstance(graphene_type_required, graphene_field)
@ -37,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo:
convert_serializer_field(None)
assert 'Don\'t know how to convert the serializer field' in str(excinfo.value)
assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string():
@ -65,20 +69,19 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string():
assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+')
assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string():
if hasattr(serializers, 'UUIDField'):
if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String)
def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = None
fields = '__all__'
fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
@ -87,8 +90,8 @@ def test_should_date_time_convert_datetime():
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
def test_should_date_convert_datetime():
assert_conversion(serializers.DateField, graphene.types.datetime.DateTime)
def test_should_date_convert_date():
assert_conversion(serializers.DateField, graphene.types.datetime.Date)
def test_should_time_convert_time():
@ -108,7 +111,9 @@ def test_should_float_convert_float():
def test_should_decimal_convert_float():
assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2)
assert_conversion(
serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
)
def test_should_list_convert_to_list():
@ -118,7 +123,7 @@ def test_should_list_convert_to_list():
field_a = assert_conversion(
serializers.ListField,
graphene.List,
child=serializers.IntegerField(min_value=0, max_value=100)
child=serializers.IntegerField(min_value=0, max_value=100),
)
assert field_a.of_type == graphene.Int
@ -128,6 +133,34 @@ def test_should_list_convert_to_list():
assert field_b.of_type == graphene.String
def test_should_list_serializer_convert_to_list():
class FooModel(models.Model):
pass
class ChildSerializer(serializers.ModelSerializer):
class Meta:
model = FooModel
fields = "__all__"
class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True)
class Meta:
model = FooModel
fields = "__all__"
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=True
)
assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(
ParentSerializer().get_fields()["child"], is_input=False
)
assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None
def test_should_dict_convert_dict():
assert_conversion(serializers.DictField, DictType)
@ -141,7 +174,7 @@ def test_should_file_convert_string():
def test_should_filepath_convert_string():
assert_conversion(serializers.FilePathField, graphene.String, path='/')
assert_conversion(serializers.FilePathField, graphene.String, path="/")
def test_should_ip_convert_string():
@ -157,6 +190,8 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string():
field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3])
field = assert_conversion(
serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
)
assert field.of_type == graphene.String

View File

@ -1,21 +1,40 @@
from django.db import models
from graphene import Field
import datetime
from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType
from py.test import raises
from py.test import mark
from rest_framework import serializers
from ...types import DjangoObjectType
from ..models import MyFakeModel
from ..mutation import SerializerMutation
class MyFakeModel(models.Model):
cool_name = models.CharField(max_length=50)
def mock_info():
return ResolveInfo(
None,
None,
None,
None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None,
)
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModel
fields = '__all__'
fields = "__all__"
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer):
@ -28,10 +47,11 @@ class MySerializer(serializers.Serializer):
def test_needs_serializer_class():
with raises(Exception) as exc:
class MyMutation(SerializerMutation):
pass
assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields():
@ -39,9 +59,9 @@ def test_has_fields():
class Meta:
serializer_class = MySerializer
assert 'text' in MyMutation._meta.fields
assert 'model' in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields
assert "text" in MyMutation._meta.fields
assert "model" in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
@ -49,12 +69,24 @@ def test_has_input_fields():
class Meta:
serializer_class = MySerializer
assert 'text' in MyMutation.Input._meta.fields
assert 'model' in MyMutation.Input._meta.fields
assert "text" in MyMutation.Input._meta.fields
assert "model" in MyMutation.Input._meta.fields
def test_exclude_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
exclude_fields = ["created"]
assert "cool_name" in MyMutation._meta.fields
assert "created" not in MyMutation._meta.fields
assert "errors" in MyMutation._meta.fields
assert "cool_name" in MyMutation.Input._meta.fields
assert "created" not in MyMutation.Input._meta.fields
def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType):
class Meta:
model = MyFakeModel
@ -63,37 +95,85 @@ def test_nested_model():
class Meta:
serializer_class = MySerializer
model_field = MyMutation._meta.fields['model']
model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType
model_input = MyMutation.Input._meta.fields['model']
model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields
assert "cool_name" in model_input_type._meta.fields
assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(None, None, **{
'text': 'value',
'model': {
'cool_name': 'other_value'
}
})
result = MyMutation.mutate_and_get_payload(
None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
)
assert result.errors is None
def test_mutate_and_get_payload_error():
@mark.django_db
def test_model_add_mutate_and_get_payload_success():
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert result.errors is None
assert result.cool_name == "Narf"
assert isinstance(result.created, datetime.datetime)
@mark.django_db
def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(
None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
)
assert result.errors is None
assert result.cool_name == "New Narf"
@mark.django_db
def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["update"]
with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"}
)
assert '"id" required' in str(exc.value)
def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
# missing required fields
result = MyMutation.mutate_and_get_payload(None, None, **{})
assert len(result.errors) > 0
result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error():
# missing required fields
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_invalid_serializer_operations():
with raises(Exception) as exc:
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ["Add"]
assert "model_operations" in str(exc.value)

View File

@ -3,8 +3,8 @@ from graphene.types.unmountedtype import UnmountedType
class ErrorType(graphene.ObjectType):
field = graphene.String()
messages = graphene.List(graphene.String)
field = graphene.String(required=True)
messages = graphene.List(graphene.NonNull(graphene.String), required=True)
class DictType(UnmountedType):

View File

@ -26,27 +26,22 @@ except ImportError:
# Copied shamelessly from Django REST Framework
DEFAULTS = {
'SCHEMA': None,
'SCHEMA_OUTPUT': 'schema.json',
'SCHEMA_INDENT': None,
'MIDDLEWARE': (),
"SCHEMA": None,
"SCHEMA_OUTPUT": "schema.json",
"SCHEMA_INDENT": None,
"MIDDLEWARE": (),
# Set to True if the connection fields must have
# either the first or last argument
'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False,
"RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields
'RELAY_CONNECTION_MAX_LIMIT': 100,
"RELAY_CONNECTION_MAX_LIMIT": 100,
}
if settings.DEBUG:
DEFAULTS['MIDDLEWARE'] += (
'graphene_django.debug.DjangoDebugMiddleware',
)
DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
'MIDDLEWARE',
'SCHEMA',
)
IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
def perform_import(val, setting_name):
@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
"""
try:
# Nod to tastypie's use of importlib.
parts = val.split('.')
module_path, class_name = '.'.join(parts[:-1]), parts[-1]
parts = val.split(".")
module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg)
@ -96,8 +96,8 @@ class GrapheneSettings(object):
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'GRAPHENE', {})
if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings
def __getattr__(self, attr):
@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs):
global graphene_settings
setting, value = kwargs['setting'], kwargs['value']
if setting == 'GRAPHENE':
setting, value = kwargs["setting"], kwargs["value"]
if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)

View File

@ -16,11 +16,11 @@ add "&raw" to the end of the URL within a browser.
width: 100%;
}
</style>
<link href="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react-dom.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.min.js"></script>
<link href="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/npm/whatwg-fetch@2.0.3/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/npm/react@16.2.0/umd/react.production.min.js"></script>
<script src="//cdn.jsdelivr.net/npm/react-dom@16.2.0/umd/react-dom.production.min.js"></script>
<script src="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"></script>
</head>
<body>
<script>

View File

@ -3,51 +3,103 @@ from __future__ import absolute_import
from django.db import models
from django.utils.translation import ugettext_lazy as _
CHOICES = (
(1, 'this'),
(2, _('that'))
)
CHOICES = ((1, "this"), (2, _("that")))
class Pet(models.Model):
name = models.CharField(max_length=30)
age = models.PositiveIntegerField()
class FilmDetails(models.Model):
location = models.CharField(max_length=30)
film = models.OneToOneField('Film', related_name='details')
film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details"
)
class Film(models.Model):
reporters = models.ManyToManyField('Reporter',
related_name='films')
genre = models.CharField(
max_length=2,
help_text="Genre",
choices=[("do", "Documentary"), ("ot", "Other")],
default="ot",
)
reporters = models.ManyToManyField("Reporter", related_name="films")
class DoeReporterManager(models.Manager):
def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model):
first_name = models.CharField(max_length=30)
last_name = models.CharField(max_length=30)
email = models.EmailField()
pets = models.ManyToManyField('self')
pets = models.ManyToManyField("self")
a_choice = models.CharField(max_length=30, choices=CHOICES)
objects = models.Manager()
doe_objects = DoeReporterManager()
def __str__(self): # __unicode__ on Python 2
reporter_type = models.IntegerField(
"Reporter Type",
null=True,
blank=True,
choices=[(1, u"Regular"), (2, u"CNN Reporter")],
)
def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs):
"""
Override the init method so that during runtime, Django
can know that this object can be a CNNReporter by casting
it to the proxy model. Otherwise, as far as Django knows,
when a CNNReporter is pulled from the database, it is still
of type Reporter. This was added to test proxy model support.
"""
super(Reporter, self).__init__(*args, **kwargs)
if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter
class CNNReporter(Reporter):
"""
This class is a proxy model for Reporter, used for testing
proxy model support
"""
class Meta:
proxy = True
class Article(models.Model):
headline = models.CharField(max_length=100)
pub_date = models.DateField()
reporter = models.ForeignKey(Reporter, related_name='articles')
editor = models.ForeignKey(Reporter, related_name='edited_articles_+')
lang = models.CharField(max_length=2, help_text='Language', choices=[
('es', 'Spanish'),
('en', 'English')
], default='es')
importance = models.IntegerField('Importance', null=True, blank=True,
choices=[(1, u'Very important'), (2, u'Not as important')])
pub_date_time = models.DateTimeField()
reporter = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="articles"
)
editor = models.ForeignKey(
Reporter, on_delete=models.CASCADE, related_name="edited_articles_+"
)
lang = models.CharField(
max_length=2,
help_text="Language",
choices=[("es", "Spanish"), ("en", "English")],
default="es",
)
importance = models.IntegerField(
"Importance",
null=True,
blank=True,
choices=[(1, u"Very important"), (2, u"Not as important")],
)
def __str__(self): # __unicode__ on Python 2
def __str__(self): # __unicode__ on Python 2
return self.headline
class Meta:
ordering = ('headline',)
ordering = ("headline",)

View File

@ -6,10 +6,9 @@ from .models import Article, Reporter
class Character(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node, )
interfaces = (relay.Node,)
def get_node(self, info, id):
pass
@ -20,7 +19,7 @@ class Human(DjangoObjectType):
class Meta:
model = Article
interfaces = (relay.Node, )
interfaces = (relay.Node,)
def resolve_raises(self, info):
raise Exception("This field should raise exception")

View File

@ -12,10 +12,10 @@ class QueryRoot(ObjectType):
raise Exception("Throws!")
def resolve_request(self, info):
return info.context.GET.get('q')
return info.context.GET.get("q")
def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World')
return "Hello %s" % (who or "World")
class MutationRoot(ObjectType):

View File

@ -3,8 +3,8 @@ from mock import patch
from six import StringIO
@patch('graphene_django.management.commands.graphql_schema.Command.save_file')
@patch("graphene_django.management.commands.graphql_schema.Command.save_file")
def test_generate_file_on_call_graphql_schema(savefile_mock, settings):
out = StringIO()
management.call_command('graphql_schema', schema='', stdout=out)
management.call_command("graphql_schema", schema="", stdout=out)
assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue()

View File

@ -5,7 +5,7 @@ from py.test import raises
import graphene
from graphene.relay import ConnectionField, Node
from graphene.types.datetime import DateTime, Time
from graphene.types.datetime import DateTime, Date, Time
from graphene.types.json import JSONString
from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType
@ -19,11 +19,11 @@ from .models import Article, Film, FilmDetails, Reporter
def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs)
field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs)
graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
assert field.description == "Custom Help Text"
nonnull_field = django_field(null=False, *args, **kwargs)
if not nonnull_field.null:
nonnull_graphene_type = convert_django_field(nonnull_field)
@ -36,11 +36,15 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_django_field(None)
assert 'Don\'t know how to convert the Django field' in str(excinfo.value)
assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime)
def test_should_date_convert_string():
assert_conversion(models.DateField, DateTime)
assert_conversion(models.DateField, Date)
def test_should_time_convert_string():
@ -79,6 +83,10 @@ def test_should_image_convert_string():
assert_conversion(models.ImageField, graphene.String)
def test_should_url_convert_string():
assert_conversion(models.FilePathField, graphene.String)
def test_should_auto_convert_id():
assert_conversion(models.AutoField, graphene.ID, primary_key=True)
@ -121,70 +129,69 @@ def test_should_nullboolean_convert_boolean():
def test_field_with_choices_convert_enum():
field = models.CharField(help_text='Language', choices=(
('es', 'Spanish'),
('en', 'English')
))
field = models.CharField(
help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
)
class TranslatedModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
graphene_type = convert_django_field_with_choices(field)
assert isinstance(graphene_type, graphene.Enum)
assert graphene_type._meta.name == 'TranslatedModelLanguage'
assert graphene_type._meta.enum.__members__['ES'].value == 'es'
assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish'
assert graphene_type._meta.enum.__members__['EN'].value == 'en'
assert graphene_type._meta.enum.__members__['EN'].description == 'English'
assert graphene_type._meta.name == "TranslatedModelLanguage"
assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_grouped_choices():
field = models.CharField(help_text='Language', choices=(
('Europe', (
('es', 'Spanish'),
('en', 'English'),
)),
))
field = models.CharField(
help_text="Language",
choices=(("Europe", (("es", "Spanish"), ("en", "English"))),),
)
class GroupedChoicesModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_gettext():
field = models.CharField(help_text='Language', choices=(
('es', _('Spanish')),
('en', _('English'))
))
field = models.CharField(
help_text="Language", choices=(("es", _("Spanish")), ("en", _("English")))
)
class TranslatedChoicesModel(models.Model):
language = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
def test_field_with_choices_collision():
field = models.CharField(help_text='Timezone', choices=(
('Etc/GMT+1+2', 'Fake choice to produce double collision'),
('Etc/GMT+1', 'Greenwich Mean Time +1'),
('Etc/GMT-1', 'Greenwich Mean Time -1'),
))
field = models.CharField(
help_text="Timezone",
choices=(
("Etc/GMT+1+2", "Fake choice to produce double collision"),
("Etc/GMT+1", "Greenwich Mean Time +1"),
("Etc/GMT-1", "Greenwich Mean Time -1"),
),
)
class CollisionChoicesModel(models.Model):
timezone = field
class Meta:
app_label = 'test'
app_label = "test"
convert_django_field_with_choices(field)
@ -201,11 +208,12 @@ def test_should_manytomany_convert_connectionorlist():
def test_should_manytomany_convert_connectionorlist_list():
class A(DjangoObjectType):
class Meta:
model = Reporter
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
@ -215,12 +223,13 @@ def test_should_manytomany_convert_connectionorlist_list():
def test_should_manytomany_convert_connectionorlist_connection():
class A(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
interfaces = (Node,)
graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry)
graphene_field = convert_django_field(
Reporter._meta.local_many_to_many[0], A._meta.registry
)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, ConnectionField)
@ -228,16 +237,12 @@ def test_should_manytomany_convert_connectionorlist_connection():
def test_should_manytoone_convert_connectionorlist():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Reporter.articles, 'rel', None) or \
getattr(Reporter.articles, 'related')
class A(DjangoObjectType):
class Meta:
model = Article
graphene_field = convert_django_field(related, A._meta.registry)
graphene_field = convert_django_field(Reporter.articles.rel,
A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
@ -246,57 +251,53 @@ def test_should_manytoone_convert_connectionorlist():
def test_should_onetoone_reverse_convert_model():
# Django 1.9 uses 'rel', <1.9 uses 'related
related = getattr(Film.details, 'rel', None) or \
getattr(Film.details, 'related')
class A(DjangoObjectType):
class Meta:
model = FilmDetails
graphene_field = convert_django_field(related, A._meta.registry)
graphene_field = convert_django_field(Film.details.related,
A._meta.registry)
assert isinstance(graphene_field, graphene.Dynamic)
dynamic_field = graphene_field.get_type()
assert isinstance(dynamic_field, graphene.Field)
assert dynamic_field.type == A
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_convert_list():
field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100))
field = assert_conversion(
ArrayField, graphene.List, models.CharField(max_length=100)
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert field.type.of_type.of_type == graphene.String
@pytest.mark.skipif(ArrayField is MissingType,
reason="ArrayField should exist")
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_should_postgres_array_multiple_convert_list():
field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)))
field = assert_conversion(
ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))
)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)
assert isinstance(field.type.of_type.of_type, graphene.List)
assert field.type.of_type.of_type.of_type == graphene.String
@pytest.mark.skipif(HStoreField is MissingType,
reason="HStoreField should exist")
@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist")
def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(JSONField is MissingType,
reason="JSONField should exist")
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
def test_should_postgres_json_convert_string():
assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType,
reason="RangeField should exist")
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField
field = assert_conversion(IntegerRangeField, graphene.List)
assert isinstance(field.type, graphene.NonNull)
assert isinstance(field.type.of_type, graphene.List)

View File

@ -1,103 +0,0 @@
from django import forms
from py.test import raises
import graphene
from graphene import ID, List, NonNull
from ..form_converter import convert_form_field
from .models import Reporter
def assert_conversion(django_field, graphene_field, *args):
field = django_field(*args, help_text='Custom Help Text')
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
return field
def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo:
convert_form_field(None)
assert 'Don\'t know how to convert the Django form field' in str(excinfo.value)
def test_should_date_convert_string():
assert_conversion(forms.DateField, graphene.String)
def test_should_time_convert_string():
assert_conversion(forms.TimeField, graphene.String)
def test_should_date_time_convert_string():
assert_conversion(forms.DateTimeField, graphene.String)
def test_should_char_convert_string():
assert_conversion(forms.CharField, graphene.String)
def test_should_email_convert_string():
assert_conversion(forms.EmailField, graphene.String)
def test_should_slug_convert_string():
assert_conversion(forms.SlugField, graphene.String)
def test_should_url_convert_string():
assert_conversion(forms.URLField, graphene.String)
def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, graphene.String)
def test_should_base_field_convert_string():
assert_conversion(forms.Field, graphene.String)
def test_should_regex_convert_string():
assert_conversion(forms.RegexField, graphene.String, '[0-9]+')
def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'):
assert_conversion(forms.UUIDField, graphene.UUID)
def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, graphene.Int)
def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, graphene.Boolean)
assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, graphene.Boolean)
assert not isinstance(field.type, NonNull)
def test_should_float_convert_float():
assert_conversion(forms.FloatField, graphene.Float)
def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, graphene.Float)
def test_should_multiple_choice_convert_connectionorlist():
field = forms.ModelMultipleChoiceField(Reporter.objects.all())
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, List)
assert graphene_type.of_type == ID
def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(Reporter.objects.all())
graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene.ID)

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ValidationError
from py.test import raises
from ..forms import GlobalIDFormField
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
@ -9,13 +9,24 @@ from ..forms import GlobalIDFormField
def test_global_id_valid():
field = GlobalIDFormField()
field.clean('TXlUeXBlOmFiYw==')
field.clean("TXlUeXBlOmFiYw==")
def test_global_id_invalid():
field = GlobalIDFormField()
with raises(ValidationError):
field.clean('badvalue')
field.clean("badvalue")
def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField()
field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="])
def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField()
with raises(ValidationError):
field.clean(["badvalue", "another bad avue"])
def test_global_id_none():

File diff suppressed because it is too large Load Diff

View File

@ -7,47 +7,47 @@ from .models import Reporter
def test_should_raise_if_no_model():
with raises(Exception) as excinfo:
class Character1(DjangoObjectType):
pass
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo:
class Character2(DjangoObjectType):
class Character2(DjangoObjectType):
class Meta:
model = 1
assert 'valid Django Model' in str(excinfo.value)
assert "valid Django Model" in str(excinfo.value)
def test_should_map_fields_correctly():
class ReporterType2(DjangoObjectType):
class Meta:
model = Reporter
registry = Registry()
fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [
'id',
'first_name',
'last_name',
'email',
'pets',
'a_choice',
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == [
'articles',
'films',
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_should_map_only_few_fields():
class Reporter2(DjangoObjectType):
class Meta:
model = Reporter
only_fields = ('id', 'email')
only_fields = ("id", "email")
assert list(Reporter2._meta.fields.keys()) == ['id', 'email']
assert list(Reporter2._meta.fields.keys()) == ["id", "email"]

View File

@ -1,10 +1,10 @@
from mock import patch
from graphene import Interface, ObjectType, Schema
from graphene import Interface, ObjectType, Schema, Connection, String
from graphene.relay import Node
from .. import registry
from ..types import DjangoObjectType
from ..types import DjangoObjectType, DjangoObjectTypeOptions
from .models import Article as ArticleModel
from .models import Reporter as ReporterModel
@ -12,16 +12,31 @@ registry.reset_global_registry()
class Reporter(DjangoObjectType):
'''Reporter description'''
"""Reporter description"""
class Meta:
model = ReporterModel
class ArticleConnection(Connection):
"""Article Connection"""
test = String()
def resolve_test():
return "test"
class Meta:
abstract = True
class Article(DjangoObjectType):
'''Article description'''
"""Article description"""
class Meta:
model = ArticleModel
interfaces = (Node, )
interfaces = (Node,)
connection_class = ArticleConnection
class RootQuery(ObjectType):
@ -36,7 +51,7 @@ def test_django_interface():
assert issubclass(Node, Node)
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1))
def test_django_get_node(get):
article = Article.get_node(None, 1)
get.assert_called_with(pk=1)
@ -46,13 +61,50 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields
fields = list(fields.keys())
assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice']
assert sorted(fields[-2:]) == ['articles', 'films']
assert fields[:-2] == [
"id",
"first_name",
"last_name",
"email",
"pets",
"a_choice",
"reporter_type",
]
assert sorted(fields[-2:]) == ["articles", "films"]
def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields
assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'reporter', 'editor', 'lang', 'importance']
assert list(fields.keys()) == [
"id",
"headline",
"pub_date",
"pub_date_time",
"reporter",
"editor",
"lang",
"importance",
]
def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions):
"""Article Type Options"""
class ArticleType(DjangoObjectType):
class Meta:
abstract = True
@classmethod
def __init_subclass_with_meta__(cls, **options):
options.setdefault("_meta", ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType):
class Meta:
model = ArticleModel
assert isinstance(Article._meta, ArticleTypeOptions)
def test_schema_representation():
@ -64,7 +116,8 @@ schema {
type Article implements Node {
id: ID!
headline: String!
pubDate: DateTime!
pubDate: Date!
pubDateTime: DateTime!
reporter: Reporter!
editor: Reporter!
lang: ArticleLang!
@ -74,6 +127,7 @@ type Article implements Node {
type ArticleConnection {
pageInfo: PageInfo!
edges: [ArticleEdge]!
test: String
}
type ArticleEdge {
@ -91,6 +145,8 @@ enum ArticleLang {
EN
}
scalar Date
scalar DateTime
interface Node {
@ -111,6 +167,7 @@ type Reporter {
email: String!
pets: [Reporter]
aChoice: ReporterAChoice!
reporterType: ReporterReporterType
articles(before: String, after: String, first: Int, last: Int): ArticleConnection
}
@ -119,6 +176,11 @@ enum ReporterAChoice {
A_2
}
enum ReporterReporterType {
A_1
A_2
}
type RootQuery {
node(id: ID!): Node
}
@ -138,6 +200,7 @@ def with_local_registry(func):
else:
registry.registry = old
return retval
return inner
@ -146,11 +209,10 @@ def test_django_objecttype_only_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
only_fields = ('id', 'email', 'films')
only_fields = ("id", "email", "films")
fields = list(Reporter._meta.fields.keys())
assert fields == ['id', 'email', 'films']
assert fields == ["id", "email", "films"]
@with_local_registry
@ -158,8 +220,7 @@ def test_django_objecttype_exclude_fields():
class Reporter(DjangoObjectType):
class Meta:
model = ReporterModel
exclude_fields = ('email')
exclude_fields = "email"
fields = list(Reporter._meta.fields.keys())
assert 'email' not in fields
assert "email" not in fields

View File

@ -8,15 +8,15 @@ except ImportError:
from urllib.parse import urlencode
def url_string(string='/graphql', **url_params):
def url_string(string="/graphql", **url_params):
if url_params:
string += '?' + urlencode(url_params)
string += "?" + urlencode(url_params)
return string
def batch_url_string(**url_params):
return url_string('/graphql/batch', **url_params)
return url_string("/graphql/batch", **url_params)
def response_json(response):
@ -28,405 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT='text/html')
response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_graphiql(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html"
def test_qfactor_json(client):
response = client.get(
url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
)
assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_query_param(client):
response = client.get(url_string(query='{test}'))
response = client.get(url_string(query="{test}"))
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_get_with_variable_values(client):
response = client.get(url_string(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
))
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_get_with_operation_name(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
))
""",
operationName="helloWorld",
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_reports_validation_errors(client):
response = client.get(url_string(
query='{ test, unknownOne, unknownTwo }'
))
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
assert response.status_code == 400
assert response_json(response) == {
'errors': [
"errors": [
{
'message': 'Cannot query field "unknownOne" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 9}]
"message": 'Cannot query field "unknownOne" on type "QueryRoot".',
"locations": [{"line": 1, "column": 9}],
},
{
'message': 'Cannot query field "unknownTwo" on type "QueryRoot".',
'locations': [{'line': 1, 'column': 21}]
}
"message": 'Cannot query field "unknownTwo" on type "QueryRoot".',
"locations": [{"line": 1, "column": 21}],
},
]
}
def test_errors_when_missing_operation_name(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
'''
))
"""
)
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [
"errors": [
{
'message': 'Must provide operation name if query contains multiple operations.'
"message": "Must provide operation name if query contains multiple operations."
}
]
}
def test_errors_when_sending_a_mutation_via_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
mutation TestMutation { writeTest { test } }
'''
))
"""
)
)
assert response.status_code == 405
assert response_json(response) == {
'errors': [
{
'message': 'Can only perform a mutation operation from a POST request.'
}
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_errors_when_selecting_a_mutation_within_a_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
''',
operationName='TestMutation'
))
""",
operationName="TestMutation",
)
)
assert response.status_code == 405
assert response_json(response) == {
'errors': [
{
'message': 'Can only perform a mutation operation from a POST request.'
}
"errors": [
{"message": "Can only perform a mutation operation from a POST request."}
]
}
def test_allows_mutation_to_exist_within_a_get(client):
response = client.get(url_string(
query='''
response = client.get(
url_string(
query="""
query TestQuery { test }
mutation TestMutation { writeTest { test } }
''',
operationName='TestQuery'
))
""",
operationName="TestQuery",
)
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query='{test}'), 'application/json')
response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_batch_allows_post_with_json_encoding(client):
response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json')
response = client.post(
batch_url_string(), jl(id=1, query="{test}"), "application/json"
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello World"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello World"}, "status": 200}
]
def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), '[]', 'application/json')
response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Received an empty list in the batch request.'}]
"errors": [{"message": "Received an empty list in the batch request."}]
}
def test_allows_sending_a_mutation_via_post(client):
response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json')
response = client.post(
url_string(),
j(query="mutation TestMutation { writeTest { test } }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'writeTest': {'test': 'Hello World'}}
}
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
def test_allows_post_with_url_encoding(client):
response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded')
response = client.post(
url_string(),
urlencode(dict(query="{test}")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello World"}
}
assert response_json(response) == {"data": {"test": "Hello World"}}
def test_supports_post_json_query_with_string_variables(client):
response = client.post(url_string(), j(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
), 'application/json')
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_string_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
), 'application/json')
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello Dolly"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_json_query_with_json_variables(client):
response = client.post(url_string(), j(
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
), 'application/json')
response = client.post(
url_string(),
j(
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_batch_supports_post_json_query_with_json_variables(client):
response = client.post(batch_url_string(), jl(
id=1,
query='query helloWho($who: String){ test(who: $who) }',
variables={'who': "Dolly"}
), 'application/json')
response = client.post(
batch_url_string(),
jl(
id=1,
query="query helloWho($who: String){ test(who: $who) }",
variables={"who": "Dolly"},
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {'test': "Hello Dolly"},
'status': 200,
}]
assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
]
def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post(url_string(), urlencode(dict(
query='query helloWho($who: String){ test(who: $who) }',
variables=json.dumps({'who': "Dolly"})
)), 'application/x-www-form-urlencoded')
response = client.post(
url_string(),
urlencode(
dict(
query="query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}),
)
),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_json_quey_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
), j(
query='query helloWho($who: String){ test(who: $who) }',
), 'application/json')
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
j(query="query helloWho($who: String){ test(who: $who) }"),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
), urlencode(dict(
query='query helloWho($who: String){ test(who: $who) }',
)), 'application/x-www-form-urlencoded')
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")),
"application/x-www-form-urlencoded",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_supports_post_raw_text_query_with_get_variable_values(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
),
'query helloWho($who: String){ test(who: $who) }',
'application/graphql'
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_allows_post_with_operation_name(client):
response = client.post(
url_string(),
j(
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {'test': "Hello Dolly"}
}
def test_allows_post_with_operation_name(client):
response = client.post(url_string(), j(
query='''
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
def test_batch_allows_post_with_operation_name(client):
response = client.post(batch_url_string(), jl(
id=1,
query='''
response = client.post(
batch_url_string(),
jl(
id=1,
query="""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
operationName='helloWorld'
), 'application/json')
""",
operationName="helloWorld",
),
"application/json",
)
assert response.status_code == 200
assert response_json(response) == [{
'id': 1,
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
},
'status': 200,
}]
assert response_json(response) == [
{
"id": 1,
"data": {"test": "Hello World", "shared": "Hello Everyone"},
"status": 200,
}
]
def test_allows_post_with_get_operation_name(client):
response = client.post(url_string(
operationName='helloWorld'
), '''
response = client.post(
url_string(operationName="helloWorld"),
"""
query helloYou { test(who: "You"), ...shared }
query helloWorld { test(who: "World"), ...shared }
query helloDolly { test(who: "Dolly"), ...shared }
fragment shared on QueryRoot {
shared: test(who: "Everyone")
}
''',
'application/graphql')
""",
"application/graphql",
)
assert response.status_code == 200
assert response_json(response) == {
'data': {
'test': 'Hello World',
'shared': 'Hello Everyone'
}
"data": {"test": "Hello World", "shared": "Hello Everyone"}
}
@pytest.mark.urls('graphene_django.tests.urls_pretty')
@pytest.mark.urls("graphene_django.tests.urls_inherited")
def test_inherited_class_with_attributes_works(client):
inherited_url = "/graphql/inherited/"
# Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query="{test}"))
assert response.content.decode() == (
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
# Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200
@pytest.mark.urls("graphene_django.tests.urls_pretty")
def test_supports_pretty_printing(client):
response = client.get(url_string(query='{test}'))
response = client.get(url_string(query="{test}"))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_supports_pretty_printing_by_request(client):
response = client.get(url_string(query='{test}', pretty='1'))
response = client.get(url_string(query="{test}", pretty="1"))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
"{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}"
)
def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query='{thrower}'))
response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200
assert response_json(response) == {
'data': None,
'errors': [{'locations': [{'column': 2, 'line': 1}], 'message': 'Throws!'}]
"data": None,
"errors": [
{
"locations": [{"column": 2, "line": 1}],
"path": ["thrower"],
"message": "Throws!",
}
],
}
def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query='syntaxerror'))
response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'locations': [{'column': 1, 'line': 1}],
'message': 'Syntax Error GraphQL request (1:1) '
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}]
"errors": [
{
"locations": [{"column": 1, "line": 1}],
"message": "Syntax Error GraphQL (1:1) "
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n',
}
]
}
@ -435,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}]
"errors": [{"message": "Must provide query string."}]
}
def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), '[]', 'application/json')
response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'The received data is not a valid JSON query.'}]
"errors": [{"message": "The received data is not a valid JSON query."}]
}
def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), '[oh}', 'application/json')
response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}]
"errors": [{"message": "POST body sent invalid JSON."}]
}
@ -463,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch):
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo='bar'))
response = client.post(url_string(), valid_json, 'application/json')
valid_json = json.dumps(dict(foo="bar"))
response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'foo-bar'}]
}
assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', 'application/json')
response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'POST body sent invalid JSON.'}]
"errors": [{"message": "POST body sent invalid JSON."}]
}
def test_handles_plain_post_text(client):
response = client.post(url_string(
variables=json.dumps({'who': "Dolly"})
),
'query helloWho($who: String){ test(who: $who) }',
'text/plain'
response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})),
"query helloWho($who: String){ test(who: $who) }",
"text/plain",
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Must provide query string.'}]
"errors": [{"message": "Must provide query string."}]
}
def test_handles_poorly_formed_variables(client):
response = client.get(url_string(
query='query helloWho($who: String){ test(who: $who) }',
variables='who:You'
))
response = client.get(
url_string(
query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
)
)
assert response.status_code == 400
assert response_json(response) == {
'errors': [{'message': 'Variables are invalid JSON.'}]
"errors": [{"message": "Variables are invalid JSON."}]
}
def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query='{test}'))
response = client.put(url_string(query="{test}"))
assert response.status_code == 405
assert response['Allow'] == 'GET, POST'
assert response["Allow"] == "GET, POST"
assert response_json(response) == {
'errors': [{'message': 'GraphQL only supports GET and POST requests.'}]
"errors": [{"message": "GraphQL only supports GET and POST requests."}]
}
def test_passes_request_into_context_request(client):
response = client.get(url_string(query='{request}', q='testing'))
response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200
assert response_json(response) == {
'data': {
'request': 'testing'
}
}
assert response_json(response) == {"data": {"request": "testing"}}

View File

@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views import GraphQLView
urlpatterns = [
url(r'^graphql/batch', GraphQLView.as_view(batch=True)),
url(r'^graphql', GraphQLView.as_view(graphiql=True)),
url(r"^graphql/batch", GraphQLView.as_view(batch=True)),
url(r"^graphql", GraphQLView.as_view(graphiql=True)),
]

View File

@ -0,0 +1,13 @@
from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
class CustomGraphQLView(GraphQLView):
schema = schema
graphiql = True
pretty = True
urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())]

View File

@ -3,6 +3,4 @@ from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
urlpatterns = [
url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)),
]
urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))]

View File

@ -8,8 +8,7 @@ from graphene.types.utils import yank_fields_from_attrs
from .converter import convert_django_field_with_choices
from .registry import Registry, get_global_registry
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
is_valid_django_model)
from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model
def construct_fields(model, registry, only_fields, exclude_fields):
@ -21,7 +20,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# is_already_created = name in options.fields
is_excluded = name in exclude_fields # or is_already_created
# https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name
is_no_backref = str(name).endswith('+')
is_no_backref = str(name).endswith("+")
if is_not_in_only or is_excluded or is_no_backref:
# We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields.
@ -43,9 +42,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions):
class DjangoObjectType(ObjectType):
@classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
use_connection=None, interfaces=(), **options):
def __init_subclass_with_meta__(
cls,
model=None,
registry=None,
skip_registry=False,
only_fields=(),
exclude_fields=(),
filter_fields=None,
connection=None,
connection_class=None,
use_connection=None,
interfaces=(),
_meta=None,
**options
):
assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model)
@ -54,7 +65,7 @@ class DjangoObjectType(ObjectType):
registry = get_global_registry()
assert isinstance(registry, Registry), (
'The attribute registry in {} needs to be an instance of '
"The attribute registry in {} needs to be an instance of "
'Registry, received "{}".'
).format(cls.__name__, registry)
@ -62,30 +73,40 @@ class DjangoObjectType(ObjectType):
raise Exception("Can only set filter_fields if Django-Filter is installed")
django_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields),
_as=Field,
construct_fields(model, registry, only_fields, exclude_fields), _as=Field
)
if use_connection is None and interfaces:
use_connection = any((issubclass(interface, Node) for interface in interfaces))
use_connection = any(
(issubclass(interface, Node) for interface in interfaces)
)
if use_connection and not connection:
# We create the connection automatically
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
if not connection_class:
connection_class = Connection
connection = connection_class.create_type(
"{}Connection".format(cls.__name__), node=cls
)
if connection is not None:
assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}"
).format(connection.__name__)
_meta = DjangoObjectTypeOptions(cls)
if not _meta:
_meta = DjangoObjectTypeOptions(cls)
_meta.model = model
_meta.registry = registry
_meta.filter_fields = filter_fields
_meta.fields = django_fields
_meta.connection = connection
super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)
super(DjangoObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
if not skip_registry:
registry.register(cls)
@ -101,10 +122,9 @@ class DjangoObjectType(ObjectType):
if isinstance(root, cls):
return True
if not is_valid_django_model(type(root)):
raise Exception((
'Received incompatible instance "{}".'
).format(root))
model = root._meta.model
raise Exception(('Received incompatible instance "{}".').format(root))
model = root._meta.model._meta.concrete_model
return model == cls._meta.model
@classmethod

View File

@ -13,6 +13,7 @@ class LazyList(object):
try:
import django_filters # noqa
DJANGO_FILTER_INSTALLED = True
except ImportError:
DJANGO_FILTER_INSTALLED = False
@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names):
continue
# Django =>1.9 uses 'rel', django <1.9 uses 'related'
related = getattr(attr, 'rel', None) or \
getattr(attr, 'related', None)
related = getattr(attr, "rel", None) or getattr(attr, "related", None)
if isinstance(related, models.ManyToOneRel):
yield (name, related)
elif isinstance(related, models.ManyToManyRel) and not related.symmetrical:
@ -42,9 +42,9 @@ def maybe_queryset(value):
def get_model_fields(model):
local_fields = [
(field.name, field)
for field
in sorted(list(model._meta.fields) +
list(model._meta.local_many_to_many))
for field in sorted(
list(model._meta.fields) + list(model._meta.local_many_to_many)
)
]
# Make sure we don't duplicate local fields with "reverse" version

View File

@ -10,18 +10,16 @@ from django.utils.decorators import method_decorator
from django.views.generic import View
from django.views.decorators.csrf import ensure_csrf_cookie
from graphql import Source, execute, parse, validate
from graphql import get_default_backend
from graphql.error import format_error as format_graphql_error
from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphql.type.schema import GraphQLSchema
from graphql.utils.get_operation_ast import get_operation_ast
from .settings import graphene_settings
class HttpError(Exception):
def __init__(self, response, message=None, *args, **kwargs):
self.response = response
self.message = message = message or response.content.decode()
@ -30,18 +28,18 @@ class HttpError(Exception):
def get_accepted_content_types(request):
def qualify(x):
parts = x.split(';', 1)
parts = x.split(";", 1)
if len(parts) == 2:
match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)',
parts[1])
match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
if match:
return parts[0], float(match.group(2))
return parts[0], 1
return parts[0].strip(), float(match.group(2))
return parts[0].strip(), 1
raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',')
raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types)
return list(x[0] for x in sorted(qualified_content_types,
key=lambda x: x[1], reverse=True))
return list(
x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
)
def instantiate_middleware(middlewares):
@ -53,38 +51,52 @@ def instantiate_middleware(middlewares):
class GraphQLView(View):
graphiql_version = '0.10.2'
graphiql_template = 'graphene/graphiql.html'
graphiql_version = "0.11.10"
graphiql_template = "graphene/graphiql.html"
schema = None
graphiql = False
executor = None
backend = None
middleware = None
root_value = None
pretty = False
batch = False
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
batch=False):
def __init__(
self,
schema=None,
executor=None,
middleware=None,
root_value=None,
graphiql=False,
pretty=False,
batch=False,
backend=None,
):
if not schema:
schema = graphene_settings.SCHEMA
if backend is None:
backend = get_default_backend()
if middleware is None:
middleware = graphene_settings.MIDDLEWARE
self.schema = schema
self.schema = self.schema or schema
if middleware is not None:
self.middleware = list(instantiate_middleware(middleware))
self.executor = executor
self.root_value = root_value
self.pretty = pretty
self.graphiql = graphiql
self.batch = batch
self.pretty = self.pretty or pretty
self.graphiql = self.graphiql or graphiql
self.batch = self.batch or batch
self.backend = backend
assert isinstance(
self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
assert not all((graphiql, batch)
), 'Use either graphiql or batch processing'
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
# noinspection PyUnusedLocal
def get_root_value(self, request):
@ -96,65 +108,65 @@ class GraphQLView(View):
def get_context(self, request):
return request
def get_backend(self, request):
return self.backend
@method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs):
try:
if request.method.lower() not in ('get', 'post'):
raise HttpError(HttpResponseNotAllowed(
['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))
if request.method.lower() not in ("get", "post"):
raise HttpError(
HttpResponseNotAllowed(
["GET", "POST"], "GraphQL only supports GET and POST requests."
)
)
data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(
request, data)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
if self.batch:
responses = [self.get_response(
request, entry) for entry in data]
result = '[{}]'.format(
','.join([response[0] for response in responses]))
status_code = max(
responses, key=lambda response: response[1])[1]
responses = [self.get_response(request, entry) for entry in data]
result = "[{}]".format(
",".join([response[0] for response in responses])
)
status_code = (
responses
and max(responses, key=lambda response: response[1])[1]
or 200
)
else:
result, status_code = self.get_response(
request, data, show_graphiql)
result, status_code = self.get_response(request, data, show_graphiql)
if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(
request, data)
request, data
)
return self.render_graphiql(
request,
graphiql_version=self.graphiql_version,
query=query or '',
variables=json.dumps(variables) or '',
operation_name=operation_name or '',
result=result or ''
query=query or "",
variables=json.dumps(variables) or "",
operation_name=operation_name or "",
result=result or "",
)
return HttpResponse(
status=status_code,
content=result,
content_type='application/json'
status=status_code, content=result, content_type="application/json"
)
except HttpError as e:
response = e.response
response['Content-Type'] = 'application/json'
response.content = self.json_encode(request, {
'errors': [self.format_error(e)]
})
response["Content-Type"] = "application/json"
response.content = self.json_encode(
request, {"errors": [self.format_error(e)]}
)
return response
def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(
request, data)
query, variables, operation_name, id = self.get_graphql_params(request, data)
execution_result = self.execute_graphql_request(
request,
data,
query,
variables,
operation_name,
show_graphiql
request, data, query, variables, operation_name, show_graphiql
)
status_code = 200
@ -162,17 +174,18 @@ class GraphQLView(View):
response = {}
if execution_result.errors:
response['errors'] = [self.format_error(
e) for e in execution_result.errors]
response["errors"] = [
self.format_error(e) for e in execution_result.errors
]
if execution_result.invalid:
status_code = 400
else:
response['data'] = execution_result.data
response["data"] = execution_result.data
if self.batch:
response['id'] = id
response['status'] = status_code
response["id"] = id
response["status"] = status_code
result = self.json_encode(request, response, pretty=show_graphiql)
else:
@ -184,22 +197,21 @@ class GraphQLView(View):
return render(request, self.graphiql_template, data)
def json_encode(self, request, d, pretty=False):
if not (self.pretty or pretty) and not request.GET.get('pretty'):
return json.dumps(d, separators=(',', ':'))
if not (self.pretty or pretty) and not request.GET.get("pretty"):
return json.dumps(d, separators=(",", ":"))
return json.dumps(d, sort_keys=True,
indent=2, separators=(',', ': '))
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
def parse_body(self, request):
content_type = self.get_content_type(request)
if content_type == 'application/graphql':
return {'query': request.body.decode()}
if content_type == "application/graphql":
return {"query": request.body.decode()}
elif content_type == 'application/json':
elif content_type == "application/json":
# noinspection PyBroadException
try:
body = request.body.decode('utf-8')
body = request.body.decode("utf-8")
except Exception as e:
raise HttpError(HttpResponseBadRequest(str(e)))
@ -207,102 +219,113 @@ class GraphQLView(View):
request_json = json.loads(body)
if self.batch:
assert isinstance(request_json, list), (
'Batch requests should receive a list, but received {}.'
"Batch requests should receive a list, but received {}."
).format(repr(request_json))
assert len(request_json) > 0, (
'Received an empty list in the batch request.'
)
assert (
len(request_json) > 0
), "Received an empty list in the batch request."
else:
assert isinstance(request_json, dict), (
'The received data is not a valid JSON query.'
)
assert isinstance(
request_json, dict
), "The received data is not a valid JSON query."
return request_json
except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest(
'POST body sent invalid JSON.'))
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
elif content_type in [
"application/x-www-form-urlencoded",
"multipart/form-data",
]:
return request.POST
return {}
def execute(self, *args, **kwargs):
return execute(self.schema, *args, **kwargs)
def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
def execute_graphql_request(
self, request, data, query, variables, operation_name, show_graphiql=False
):
if not query:
if show_graphiql:
return None
raise HttpError(HttpResponseBadRequest(
'Must provide query string.'))
source = Source(query, name='GraphQL request')
raise HttpError(HttpResponseBadRequest("Must provide query string."))
try:
document_ast = parse(source)
validation_errors = validate(self.schema, document_ast)
if validation_errors:
return ExecutionResult(
errors=validation_errors,
invalid=True,
)
backend = self.get_backend(request)
document = backend.document_from_string(self.schema, query)
except Exception as e:
return ExecutionResult(errors=[e], invalid=True)
if request.method.lower() == 'get':
operation_ast = get_operation_ast(document_ast, operation_name)
if operation_ast and operation_ast.operation != 'query':
if request.method.lower() == "get":
operation_type = document.get_operation_type(operation_name)
if operation_type and operation_type != "query":
if show_graphiql:
return None
raise HttpError(HttpResponseNotAllowed(
['POST'], 'Can only perform a {} operation from a POST request.'.format(
operation_ast.operation)
))
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_type
),
)
)
try:
return self.execute(
document_ast,
root_value=self.get_root_value(request),
variable_values=variables,
extra_options = {}
if self.executor:
# We only include it optionally since
# executor is not a valid argument in all backends
extra_options["executor"] = self.executor
return document.execute(
root=self.get_root_value(request),
variables=variables,
operation_name=operation_name,
context_value=self.get_context(request),
context=self.get_context(request),
middleware=self.get_middleware(request),
executor=self.executor,
**extra_options
)
except Exception as e:
return ExecutionResult(errors=[e], invalid=True)
@classmethod
def can_display_graphiql(cls, request, data):
raw = 'raw' in request.GET or 'raw' in data
raw = "raw" in request.GET or "raw" in data
return not raw and cls.request_wants_html(request)
@classmethod
def request_wants_html(cls, request):
accepted = get_accepted_content_types(request)
html_index = accepted.count('text/html')
json_index = accepted.count('application/json')
accepted_length = len(accepted)
# the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = (
accepted_length - accepted.index("text/html")
if "text/html" in accepted
else 0
)
json_priority = (
accepted_length - accepted.index("application/json")
if "application/json" in accepted
else 0
)
return html_index > json_index
return html_priority > json_priority
@staticmethod
def get_graphql_params(request, data):
query = request.GET.get('query') or data.get('query')
variables = request.GET.get('variables') or data.get('variables')
id = request.GET.get('id') or data.get('id')
query = request.GET.get("query") or data.get("query")
variables = request.GET.get("variables") or data.get("variables")
id = request.GET.get("id") or data.get("id")
if variables and isinstance(variables, six.text_type):
try:
variables = json.loads(variables)
except Exception:
raise HttpError(HttpResponseBadRequest(
'Variables are invalid JSON.'))
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
operation_name = request.GET.get(
'operationName') or data.get('operationName')
operation_name = request.GET.get("operationName") or data.get("operationName")
if operation_name == "null":
operation_name = None
@ -313,11 +336,10 @@ class GraphQLView(View):
if isinstance(error, GraphQLError):
return format_graphql_error(error)
return {'message': six.text_type(error)}
return {"message": six.text_type(error)}
@staticmethod
def get_content_type(request):
meta = request.META
content_type = meta.get(
'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
return content_type.split(';', 1)[0].lower()
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
return content_type.split(";", 1)[0].lower()

View File

@ -3,76 +3,63 @@ import sys
import ast
import re
_version_re = re.compile(r'__version__\s+=\s+(.*)')
_version_re = re.compile(r"__version__\s+=\s+(.*)")
with open('graphene_django/__init__.py', 'rb') as f:
version = str(ast.literal_eval(_version_re.search(
f.read().decode('utf-8')).group(1)))
with open("graphene_django/__init__.py", "rb") as f:
version = str(
ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
)
rest_framework_require = [
'djangorestframework>=3.6.3',
]
rest_framework_require = ["djangorestframework>=3.6.3"]
tests_require = [
'pytest>=2.7.2',
'pytest-cov',
'coveralls',
'mock',
'pytz',
'django-filter',
'pytest-django==2.9.1',
"pytest>=3.6.3",
"pytest-cov",
"coveralls",
"mock",
"pytz",
"django-filter<2;python_version<'3'",
"django-filter>=2;python_version>='3'",
"pytest-django>=3.3.2",
] + rest_framework_require
setup(
name='graphene-django',
name="graphene-django",
version=version,
description='Graphene Django integration',
long_description=open('README.rst').read(),
url='https://github.com/graphql-python/graphene-django',
author='Syrus Akbary',
author_email='me@syrusakbary.com',
license='MIT',
description="Graphene Django integration",
long_description=open("README.rst").read(),
url="https://github.com/graphql-python/graphene-django",
author="Syrus Akbary",
author_email="me@syrusakbary.com",
license="MIT",
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Software Development :: Libraries',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: Implementation :: PyPy',
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Topic :: Software Development :: Libraries",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: Implementation :: PyPy",
],
keywords='api graphql protocol rest relay graphene',
packages=find_packages(exclude=['tests']),
keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests"]),
install_requires=[
'six>=1.10.0',
'graphene>=2.0',
'Django>=1.8.0',
'iso8601',
'singledispatch>=3.4.0.3',
'promise>=2.1',
],
setup_requires=[
'pytest-runner',
"six>=1.10.0",
"graphene>=2.1.3,<3",
"graphql-core>=2.1.0,<3",
"Django>=1.11",
"singledispatch>=3.4.0.3",
"promise>=2.1",
],
setup_requires=["pytest-runner"],
tests_require=tests_require,
rest_framework_require=rest_framework_require,
extras_require={
'test': tests_require,
'rest_framework': rest_framework_require,
},
extras_require={"test": tests_require, "rest_framework": rest_framework_require},
include_package_data=True,
zip_safe=False,
platforms='any',
platforms="any",
)