diff --git a/.travis.yml b/.travis.yml
index ef8a3d6..a8375ee 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -11,9 +11,6 @@ install:
pip install -e .[test]
pip install psycopg2 # Required for Django postgres fields testing
pip install django==$DJANGO_VERSION
- if [ $DJANGO_VERSION = 1.8 ]; then # DRF dropped 1.8 support at 3.7.0
- pip install djangorestframework==3.6.4
- fi
python setup.py develop
elif [ "$TEST_TYPE" = lint ]; then
pip install flake8
@@ -38,13 +35,19 @@ env:
matrix:
fast_finish: true
include:
+ - python: '3.4'
+ env: TEST_TYPE=build DJANGO_VERSION=2.0
+ - python: '3.5'
+ env: TEST_TYPE=build DJANGO_VERSION=2.0
+ - python: '3.6'
+ env: TEST_TYPE=build DJANGO_VERSION=2.0
+ - python: '3.5'
+ env: TEST_TYPE=build DJANGO_VERSION=2.1
+ - python: '3.6'
+ env: TEST_TYPE=build DJANGO_VERSION=2.1
- python: '2.7'
- env: TEST_TYPE=build DJANGO_VERSION=1.8
- - python: '2.7'
- env: TEST_TYPE=build DJANGO_VERSION=1.9
- - python: '2.7'
- env: TEST_TYPE=build DJANGO_VERSION=1.10
- - python: '2.7'
+ env: TEST_TYPE=lint
+ - python: '3.6'
env: TEST_TYPE=lint
deploy:
provider: pypi
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..b4a4b70
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2016-Present Syrus Akbary
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/MANIFEST.in b/MANIFEST.in
index 8fffb71..3c3d4f9 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,2 @@
-include README.md
+include README.md LICENSE
recursive-include graphene_django/templates *
diff --git a/README.md b/README.md
index 62a36f0..4e0b01d 100644
--- a/README.md
+++ b/README.md
@@ -9,10 +9,10 @@ A [Django](https://www.djangoproject.com/) integration for [Graphene](http://gra
## Installation
-For instaling graphene, just run this command in your shell
+For installing graphene, just run this command in your shell
```bash
-pip install "graphene-django>=2.0.dev"
+pip install "graphene-django>=2.0"
```
### Settings
@@ -67,8 +67,7 @@ class User(DjangoObjectType):
class Query(graphene.ObjectType):
users = graphene.List(User)
- @graphene.resolve_only_args
- def resolve_users(self):
+ def resolve_users(self, info):
return UserModel.objects.all()
schema = graphene.Schema(query=Query)
diff --git a/README.rst b/README.rst
index 27cbdc0..a96e60f 100644
--- a/README.rst
+++ b/README.rst
@@ -13,11 +13,11 @@ A `Django `__ integration for
Installation
------------
-For instaling graphene, just run this command in your shell
+For installing graphene, just run this command in your shell
.. code:: bash
- pip install "graphene-django>=2.0.dev"
+ pip install "graphene-django>=2.0"
Settings
~~~~~~~~
diff --git a/django_test_settings.py b/django_test_settings.py
index 2e08272..9279a73 100644
--- a/django_test_settings.py
+++ b/django_test_settings.py
@@ -8,6 +8,7 @@ SECRET_KEY = 1
INSTALLED_APPS = [
'graphene_django',
+ 'graphene_django.rest_framework',
'graphene_django.tests',
'starwars',
]
diff --git a/docs/authorization.rst b/docs/authorization.rst
index 214cbc7..9d1b2c6 100644
--- a/docs/authorization.rst
+++ b/docs/authorization.rst
@@ -20,7 +20,7 @@ Let's use a simple example model.
Limiting Field Access
---------------------
-This is easy, simply use the ``only_fields`` meta attribute.
+To limit fields in a GraphQL query simply use the ``only_fields`` meta attribute.
.. code:: python
@@ -34,7 +34,7 @@ This is easy, simply use the ``only_fields`` meta attribute.
only_fields = ('title', 'content')
interfaces = (relay.Node, )
-conversely you can use ``exclude_fields`` meta atrribute.
+conversely you can use ``exclude_fields`` meta attribute.
.. code:: python
@@ -61,10 +61,11 @@ define a resolve method for that field and return the desired queryset.
from .models import Post
class Query(ObjectType):
- all_posts = DjangoFilterConnectionField(CategoryNode)
+ all_posts = DjangoFilterConnectionField(PostNode)
+
+ def resolve_all_posts(self, info):
+ return Post.objects.filter(published=True)
- def resolve_all_posts(self, args, info):
- return Post.objects.filter(published=True)
User-based Queryset Filtering
-----------------------------
@@ -79,7 +80,7 @@ with the context argument.
from .models import Post
class Query(ObjectType):
- my_posts = DjangoFilterConnectionField(CategoryNode)
+ my_posts = DjangoFilterConnectionField(PostNode)
def resolve_my_posts(self, info):
# context will reference to the Django request
@@ -95,7 +96,7 @@ schema is simple.
result = schema.execute(query, context_value=request)
-Filtering ID-based node access
+Filtering ID-based Node Access
------------------------------
In order to add authorization to id-based node access, we need to add a
@@ -113,13 +114,13 @@ method to your ``DjangoObjectType``.
interfaces = (relay.Node, )
@classmethod
- def get_node(cls, id, context, info):
+ def get_node(cls, id, info):
try:
post = cls._meta.model.objects.get(id=id)
except cls._meta.model.DoesNotExist:
return None
- if post.published or context.user == post.owner:
+ if post.published or info.context.user == post.owner:
return post
return None
@@ -206,14 +207,14 @@ Connection example:
all_reporters = MyAuthDjangoConnectionField(ReporterType)
-
-Adding login required
+Adding Login Required
---------------------
-If you want to use the standard Django LoginRequiredMixin_ you can create your own view, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``:
+To restrict users from accessing the GraphQL API page the standard Django LoginRequiredMixin_ can be used to create your own standard Django Class Based View, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``.:
.. code:: python
-
+ #views.py
+
from django.contrib.auth.mixins import LoginRequiredMixin
from graphene_django.views import GraphQLView
@@ -221,7 +222,9 @@ If you want to use the standard Django LoginRequiredMixin_ you can create your o
class PrivateGraphQLView(LoginRequiredMixin, GraphQLView):
pass
-After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``:
+After this, you can use the new ``PrivateGraphQLView`` in the project's URL Configuration file ``url.py``:
+
+For Django 1.9 and below:
.. code:: python
@@ -229,5 +232,14 @@ After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``:
# some other urls
url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
]
+
+For Django 2.0 and above:
+
+.. code:: python
+
+ urlpatterns = [
+ # some other urls
+ path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)),
+ ]
.. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin
diff --git a/docs/filtering.rst b/docs/filtering.rst
index edcb7e5..feafd40 100644
--- a/docs/filtering.rst
+++ b/docs/filtering.rst
@@ -2,9 +2,9 @@ Filtering
=========
Graphene integrates with
-`django-filter `__ to provide
-filtering of results. See the `usage
-documentation `__
+`django-filter `__ (2.x for
+Python 3 or 1.x for Python 2) to provide filtering of results. See the `usage
+documentation `__
for details on the format for ``filter_fields``.
This filtering is automatically available when implementing a ``relay.Node``.
@@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows:
.. code:: bash
# You'll need to django-filter
- pip install django-filter
+ pip install django-filter>=2
Note: The techniques below are demoed in the `cookbook example
app `__.
@@ -26,7 +26,7 @@ Filterable fields
The ``filter_fields`` parameter is used to specify the fields which can
be filtered upon. The value specified here is passed directly to
``django-filter``, so see the `filtering
-documentation `__
+documentation `__
for full details on the range of options available.
For example:
@@ -126,3 +126,23 @@ create your own ``Filterset`` as follows:
# We specify our custom AnimalFilter using the filterset_class param
all_animals = DjangoFilterConnectionField(AnimalNode,
filterset_class=AnimalFilter)
+
+The context argument is passed on as the `request argument `__
+in a ``django_filters.FilterSet`` instance. You can use this to customize your
+filters to be context-dependent. We could modify the ``AnimalFilter`` above to
+pre-filter animals owned by the authenticated user (set in ``context.user``).
+
+.. code:: python
+
+ class AnimalFilter(django_filters.FilterSet):
+ # Do case-insensitive lookups on 'name'
+ name = django_filters.CharFilter(lookup_type='iexact')
+
+ class Meta:
+ model = Animal
+ fields = ['name', 'genus', 'is_domesticated']
+
+ @property
+ def qs(self):
+ # The query context can be found in self.request.
+ return super(AnimalFilter, self).qs.filter(owner=self.request.user)
diff --git a/docs/form-mutations.rst b/docs/form-mutations.rst
new file mode 100644
index 0000000..e721a78
--- /dev/null
+++ b/docs/form-mutations.rst
@@ -0,0 +1,68 @@
+Integration with Django forms
+=============================
+
+Graphene-Django comes with mutation classes that will convert the fields on Django forms into inputs on a mutation.
+*Note: the API is experimental and will likely change in the future.*
+
+FormMutation
+------------
+
+.. code:: python
+
+ class MyForm(forms.Form):
+ name = forms.CharField()
+
+ class MyMutation(FormMutation):
+ class Meta:
+ form_class = MyForm
+
+``MyMutation`` will automatically receive an ``input`` argument. This argument should be a ``dict`` where the key is ``name`` and the value is a string.
+
+ModelFormMutation
+-----------------
+
+``ModelFormMutation`` will pull the fields from a ``ModelForm``.
+
+.. code:: python
+
+ class Pet(models.Model):
+ name = models.CharField()
+
+ class PetForm(forms.ModelForm):
+ class Meta:
+ model = Pet
+ fields = ('name',)
+
+ # This will get returned when the mutation completes successfully
+ class PetType(DjangoObjectType):
+ class Meta:
+ model = Pet
+
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+``PetMutation`` will grab the fields from ``PetForm`` and turn them into inputs. If the form is valid then the mutation
+will lookup the ``DjangoObjectType`` for the ``Pet`` model and return that under the key ``pet``. Otherwise it will
+return a list of errors.
+
+You can change the input name (default is ``input``) and the return field name (default is the model name lowercase).
+
+.. code:: python
+
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+ input_field_name = 'data'
+ return_field_name = 'my_pet'
+
+Form validation
+---------------
+
+Form mutations will call ``is_valid()`` on your forms.
+
+If the form is valid then ``form_valid(form, info)`` is called on the mutation. Override this method to change how
+the form is saved or to return a different Graphene object type.
+
+If the form is *not* valid then a list of errors will be returned. These errors have two fields: ``field``, a string
+containing the name of the invalid form field, and ``messages``, a list of strings with the validation messages.
diff --git a/docs/index.rst b/docs/index.rst
index 256da68..7c64ae7 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -12,4 +12,5 @@ Contents:
authorization
debug
rest-framework
+ form-mutations
introspection
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 2548604..220b7cf 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,3 +1,3 @@
sphinx
# Docs template
-https://github.com/graphql-python/graphene-python.org/archive/docs.zip
+http://graphene-python.org/sphinx_graphene_theme.zip
diff --git a/docs/rest-framework.rst b/docs/rest-framework.rst
index 5e5dd70..ce666de 100644
--- a/docs/rest-framework.rst
+++ b/docs/rest-framework.rst
@@ -19,3 +19,46 @@ You can create a Mutation based on a serializer by using the
class Meta:
serializer_class = MySerializer
+Create/Update Operations
+---------------------
+
+By default ModelSerializers accept create and update operations. To
+customize this use the `model_operations` attribute. The update
+operation looks up models by the primary key by default. You can
+customize the look up with the lookup attribute.
+
+.. code:: python
+
+ from graphene_django.rest_framework.mutation import SerializerMutation
+
+ class AwesomeModelMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
+ model_operations = ['create', 'update']
+ lookup_field = 'id'
+
+Overriding Update Queries
+-------------------------
+
+Use the method `get_serializer_kwargs` to override how
+updates are applied.
+
+.. code:: python
+
+ from graphene_django.rest_framework.mutation import SerializerMutation
+
+ class AwesomeModelMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
+
+ @classmethod
+ def get_serializer_kwargs(cls, root, info, **input):
+ if 'id' in input:
+ instance = Post.objects.filter(id=input['id'], owner=info.context.user).first()
+ if instance:
+ return {'instance': instance, 'data': input, 'partial': True}
+
+ else:
+ raise http.Http404
+
+ return {'data': input, 'partial': True}
diff --git a/docs/tutorial-plain.rst b/docs/tutorial-plain.rst
index 592f244..a87b011 100644
--- a/docs/tutorial-plain.rst
+++ b/docs/tutorial-plain.rst
@@ -68,7 +68,8 @@ Let's get started with these models:
class Ingredient(models.Model):
name = models.CharField(max_length=100)
notes = models.TextField()
- category = models.ForeignKey(Category, related_name='ingredients')
+ category = models.ForeignKey(
+ Category, related_name='ingredients', on_delete=models.CASCADE)
def __str__(self):
return self.name
@@ -80,9 +81,10 @@ Add ingredients as INSTALLED_APPS:
INSTALLED_APPS = [
...
# Install the ingredients app
- 'ingredients',
+ 'cookbook.ingredients',
]
+
Don't forget to create & run migrations:
.. code:: bash
@@ -111,6 +113,18 @@ Alternatively you can use the Django admin interface to create some data
yourself. You'll need to run the development server (see below), and
create a login for yourself too (``./manage.py createsuperuser``).
+Register models with admin panel:
+
+.. code:: python
+
+ # cookbook/ingredients/admin.py
+ from django.contrib import admin
+ from cookbook.ingredients.models import Category, Ingredient
+
+ admin.site.register(Category)
+ admin.site.register(Ingredient)
+
+
Hello GraphQL - Schema and Object Types
---------------------------------------
@@ -153,7 +167,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
model = Ingredient
- class Query(graphene.AbstractType):
+ class Query(object):
all_categories = graphene.List(CategoryType)
all_ingredients = graphene.List(IngredientType)
@@ -165,9 +179,9 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
return Ingredient.objects.select_related('category').all()
-Note that the above ``Query`` class is marked as 'abstract'. This is
-because we will now create a project-level query which will combine all
-our app-level queries.
+Note that the above ``Query`` class is a mixin, inheriting from
+``object``. This is because we will now create a project-level query
+class which will combine all our app-level mixins.
Create the parent project-level ``cookbook/schema.py``:
@@ -426,7 +440,7 @@ We can update our schema to support that, by adding new query for ``ingredient``
model = Ingredient
- class Query(graphene.AbstractType):
+ class Query(object):
category = graphene.Field(CategoryType,
id=graphene.Int(),
name=graphene.String())
@@ -445,8 +459,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return Ingredient.objects.all()
def resolve_category(self, info, **kwargs):
- id = kargs.get('id')
- name = kargs.get('name')
+ id = kwargs.get('id')
+ name = kwargs.get('name')
if id is not None:
return Category.objects.get(pk=id)
@@ -457,8 +471,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return None
def resolve_ingredient(self, info, **kwargs):
- id = kargs.get('id')
- name = kargs.get('name')
+ id = kwargs.get('id')
+ name = kwargs.get('name')
if id is not None:
return Ingredient.objects.get(pk=id)
diff --git a/docs/tutorial-relay.rst b/docs/tutorial-relay.rst
index 3ac4cec..f2502d7 100644
--- a/docs/tutorial-relay.rst
+++ b/docs/tutorial-relay.rst
@@ -10,7 +10,7 @@ app `__
-* `GraphQL Relay Specification `__
+* `GraphQL Relay Specification `__
Setup the Django project
------------------------
@@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
.. code:: python
# cookbook/ingredients/schema.py
- from graphene import relay, ObjectType, AbstractType
+ from graphene import relay, ObjectType
from graphene_django import DjangoObjectType
from graphene_django.filter import DjangoFilterConnectionField
@@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, )
- class Query(AbstractType):
+ class Query(object):
category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode)
diff --git a/examples/cookbook-plain/README.md b/examples/cookbook-plain/README.md
index 018c584..0ec906b 100644
--- a/examples/cookbook-plain/README.md
+++ b/examples/cookbook-plain/README.md
@@ -3,7 +3,7 @@ Cookbook Example Django Project
This example project demos integration between Graphene and Django.
The project contains two apps, one named `ingredients` and another
-named `recepies`.
+named `recipes`.
Getting started
---------------
@@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries!
-(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema)
+(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries)
diff --git a/examples/cookbook-plain/cookbook/ingredients/schema.py b/examples/cookbook-plain/cookbook/ingredients/schema.py
index 895f216..1f3bb18 100644
--- a/examples/cookbook-plain/cookbook/ingredients/schema.py
+++ b/examples/cookbook-plain/cookbook/ingredients/schema.py
@@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType):
model = Ingredient
-class Query(graphene.AbstractType):
+class Query(object):
category = graphene.Field(CategoryType,
id=graphene.Int(),
name=graphene.String())
diff --git a/examples/cookbook-plain/cookbook/ingredients/tests.py b/examples/cookbook-plain/cookbook/ingredients/tests.py
new file mode 100644
index 0000000..4929020
--- /dev/null
+++ b/examples/cookbook-plain/cookbook/ingredients/tests.py
@@ -0,0 +1,2 @@
+
+# Create your tests here.
diff --git a/examples/cookbook-plain/cookbook/ingredients/views.py b/examples/cookbook-plain/cookbook/ingredients/views.py
new file mode 100644
index 0000000..b8e4ee0
--- /dev/null
+++ b/examples/cookbook-plain/cookbook/ingredients/views.py
@@ -0,0 +1,2 @@
+
+# Create your views here.
diff --git a/examples/cookbook-plain/cookbook/recipes/models.py b/examples/cookbook-plain/cookbook/recipes/models.py
index e688044..ca12fac 100644
--- a/examples/cookbook-plain/cookbook/recipes/models.py
+++ b/examples/cookbook-plain/cookbook/recipes/models.py
@@ -6,6 +6,7 @@ from cookbook.ingredients.models import Ingredient
class Recipe(models.Model):
title = models.CharField(max_length=100)
instructions = models.TextField()
+ __unicode__ = lambda self: self.title
class RecipeIngredient(models.Model):
diff --git a/examples/cookbook-plain/cookbook/recipes/schema.py b/examples/cookbook-plain/cookbook/recipes/schema.py
index 8ea1ccd..040c985 100644
--- a/examples/cookbook-plain/cookbook/recipes/schema.py
+++ b/examples/cookbook-plain/cookbook/recipes/schema.py
@@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType):
model = RecipeIngredient
-class Query(graphene.AbstractType):
+class Query(object):
recipe = graphene.Field(RecipeType,
id=graphene.Int(),
title=graphene.String())
diff --git a/examples/cookbook-plain/cookbook/recipes/tests.py b/examples/cookbook-plain/cookbook/recipes/tests.py
new file mode 100644
index 0000000..4929020
--- /dev/null
+++ b/examples/cookbook-plain/cookbook/recipes/tests.py
@@ -0,0 +1,2 @@
+
+# Create your tests here.
diff --git a/examples/cookbook-plain/cookbook/recipes/views.py b/examples/cookbook-plain/cookbook/recipes/views.py
new file mode 100644
index 0000000..b8e4ee0
--- /dev/null
+++ b/examples/cookbook-plain/cookbook/recipes/views.py
@@ -0,0 +1,2 @@
+
+# Create your views here.
diff --git a/examples/cookbook-plain/requirements.txt b/examples/cookbook-plain/requirements.txt
index a693bd1..362a39a 100644
--- a/examples/cookbook-plain/requirements.txt
+++ b/examples/cookbook-plain/requirements.txt
@@ -1,4 +1,4 @@
graphene
graphene-django
-graphql-core
+graphql-core>=2.1rc1
django==1.9
diff --git a/examples/cookbook/README.md b/examples/cookbook/README.md
index 1d3fc31..0ec906b 100644
--- a/examples/cookbook/README.md
+++ b/examples/cookbook/README.md
@@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries!
-(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema)
+(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries)
diff --git a/examples/cookbook/cookbook/ingredients/admin.py b/examples/cookbook/cookbook/ingredients/admin.py
index 2b16cdc..b57cbc3 100644
--- a/examples/cookbook/cookbook/ingredients/admin.py
+++ b/examples/cookbook/cookbook/ingredients/admin.py
@@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.ingredients.models import Category, Ingredient
+
@admin.register(Ingredient)
class IngredientAdmin(admin.ModelAdmin):
- list_display = ("id","name","category")
- list_editable = ("name","category")
-
+ list_display = ('id', 'name', 'category')
+ list_editable = ('name', 'category')
+
+
admin.site.register(Category)
diff --git a/examples/cookbook/cookbook/ingredients/models.py b/examples/cookbook/cookbook/ingredients/models.py
index a072bcf..2f0eba3 100644
--- a/examples/cookbook/cookbook/ingredients/models.py
+++ b/examples/cookbook/cookbook/ingredients/models.py
@@ -10,7 +10,7 @@ class Category(models.Model):
class Ingredient(models.Model):
name = models.CharField(max_length=100)
- notes = models.TextField(null=True,blank=True)
+ notes = models.TextField(null=True, blank=True)
category = models.ForeignKey(Category, related_name='ingredients')
def __str__(self):
diff --git a/examples/cookbook/cookbook/ingredients/schema.py b/examples/cookbook/cookbook/ingredients/schema.py
index b8b3c12..5ad92e8 100644
--- a/examples/cookbook/cookbook/ingredients/schema.py
+++ b/examples/cookbook/cookbook/ingredients/schema.py
@@ -1,5 +1,5 @@
from cookbook.ingredients.models import Category, Ingredient
-from graphene import AbstractType, Node
+from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType
@@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType):
}
-class Query(AbstractType):
+class Query(object):
category = Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode)
diff --git a/examples/cookbook/cookbook/recipes/admin.py b/examples/cookbook/cookbook/recipes/admin.py
index 57e0418..10d568f 100644
--- a/examples/cookbook/cookbook/recipes/admin.py
+++ b/examples/cookbook/cookbook/recipes/admin.py
@@ -2,9 +2,11 @@ from django.contrib import admin
from cookbook.recipes.models import Recipe, RecipeIngredient
+
class RecipeIngredientInline(admin.TabularInline):
- model = RecipeIngredient
+ model = RecipeIngredient
+
@admin.register(Recipe)
class RecipeAdmin(admin.ModelAdmin):
- inlines = [RecipeIngredientInline]
+ inlines = [RecipeIngredientInline]
diff --git a/examples/cookbook/cookbook/recipes/models.py b/examples/cookbook/cookbook/recipes/models.py
index f666fe8..ca12fac 100644
--- a/examples/cookbook/cookbook/recipes/models.py
+++ b/examples/cookbook/cookbook/recipes/models.py
@@ -8,6 +8,7 @@ class Recipe(models.Model):
instructions = models.TextField()
__unicode__ = lambda self: self.title
+
class RecipeIngredient(models.Model):
recipe = models.ForeignKey(Recipe, related_name='amounts')
ingredient = models.ForeignKey(Ingredient, related_name='used_by')
diff --git a/examples/cookbook/cookbook/recipes/schema.py b/examples/cookbook/cookbook/recipes/schema.py
index 35a21de..8018322 100644
--- a/examples/cookbook/cookbook/recipes/schema.py
+++ b/examples/cookbook/cookbook/recipes/schema.py
@@ -1,5 +1,5 @@
from cookbook.recipes.models import Recipe, RecipeIngredient
-from graphene import AbstractType, Node
+from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType
@@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType):
}
-class Query(AbstractType):
+class Query(object):
recipe = Node.Field(RecipeNode)
all_recipes = DjangoFilterConnectionField(RecipeNode)
diff --git a/examples/cookbook/cookbook/schema.py b/examples/cookbook/cookbook/schema.py
index 910e259..f8606a7 100644
--- a/examples/cookbook/cookbook/schema.py
+++ b/examples/cookbook/cookbook/schema.py
@@ -5,7 +5,9 @@ import graphene
from graphene_django.debug import DjangoDebug
-class Query(cookbook.recipes.schema.Query, cookbook.ingredients.schema.Query, graphene.ObjectType):
+class Query(cookbook.ingredients.schema.Query,
+ cookbook.recipes.schema.Query,
+ graphene.ObjectType):
debug = graphene.Field(DjangoDebug, name='__debug')
diff --git a/examples/cookbook/cookbook/settings.py b/examples/cookbook/cookbook/settings.py
index 1916201..948292d 100644
--- a/examples/cookbook/cookbook/settings.py
+++ b/examples/cookbook/cookbook/settings.py
@@ -1,3 +1,4 @@
+# flake8: noqa
"""
Django settings for cookbook project.
diff --git a/examples/cookbook/cookbook/urls.py b/examples/cookbook/cookbook/urls.py
index 9410ca5..9f8755b 100644
--- a/examples/cookbook/cookbook/urls.py
+++ b/examples/cookbook/cookbook/urls.py
@@ -3,6 +3,7 @@ from django.contrib import admin
from graphene_django.views import GraphQLView
+
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^graphql', GraphQLView.as_view(graphiql=True)),
diff --git a/examples/cookbook/requirements.txt b/examples/cookbook/requirements.txt
index 66fa629..b2ace1f 100644
--- a/examples/cookbook/requirements.txt
+++ b/examples/cookbook/requirements.txt
@@ -1,5 +1,5 @@
graphene
graphene-django
-graphql-core
+graphql-core>=2.1rc1
django==1.9
-django-filter==0.11.0
+django-filter>=2
diff --git a/examples/cookbook/setup.cfg b/examples/cookbook/setup.cfg
new file mode 100644
index 0000000..8c6a6e8
--- /dev/null
+++ b/examples/cookbook/setup.cfg
@@ -0,0 +1,2 @@
+[flake8]
+exclude=migrations,.git,__pycache__
diff --git a/examples/starwars/models.py b/examples/starwars/models.py
index 2f80e27..45741da 100644
--- a/examples/starwars/models.py
+++ b/examples/starwars/models.py
@@ -5,7 +5,7 @@ from django.db import models
class Character(models.Model):
name = models.CharField(max_length=50)
- ship = models.ForeignKey('Ship', blank=True, null=True, related_name='characters')
+ ship = models.ForeignKey('Ship', on_delete=models.CASCADE, blank=True, null=True, related_name='characters')
def __str__(self):
return self.name
@@ -13,7 +13,7 @@ class Character(models.Model):
class Faction(models.Model):
name = models.CharField(max_length=50)
- hero = models.ForeignKey(Character)
+ hero = models.ForeignKey(Character, on_delete=models.CASCADE)
def __str__(self):
return self.name
@@ -21,7 +21,7 @@ class Faction(models.Model):
class Ship(models.Model):
name = models.CharField(max_length=50)
- faction = models.ForeignKey(Faction, related_name='ships')
+ faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name='ships')
def __str__(self):
return self.name
diff --git a/graphene_django/__init__.py b/graphene_django/__init__.py
index 5ba360f..4538cb3 100644
--- a/graphene_django/__init__.py
+++ b/graphene_django/__init__.py
@@ -1,14 +1,6 @@
-from .types import (
- DjangoObjectType,
-)
-from .fields import (
- DjangoConnectionField,
-)
+from .types import DjangoObjectType
+from .fields import DjangoConnectionField
-__version__ = '2.0.0'
+__version__ = "2.2.0"
-__all__ = [
- '__version__',
- 'DjangoObjectType',
- 'DjangoConnectionField'
-]
+__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"]
diff --git a/graphene_django/compat.py b/graphene_django/compat.py
index 0269e33..4a51de8 100644
--- a/graphene_django/compat.py
+++ b/graphene_django/compat.py
@@ -5,13 +5,7 @@ class MissingType(object):
try:
# Postgres fields are only available in Django with psycopg2 installed
# and we cannot have psycopg2 on PyPy
- from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField
+ from django.contrib.postgres.fields import (ArrayField, HStoreField,
+ JSONField, RangeField)
except ImportError:
- ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4
-
-
-try:
- # Postgres fields are only available in Django 1.9+
- from django.contrib.postgres.fields import JSONField
-except ImportError:
- JSONField = MissingType
+ ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4
diff --git a/graphene_django/converter.py b/graphene_django/converter.py
index 3f8511c..c40313d 100644
--- a/graphene_django/converter.py
+++ b/graphene_django/converter.py
@@ -1,9 +1,22 @@
from django.db import models
from django.utils.encoding import force_text
-from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
- NonNull, String, UUID)
-from graphene.types.datetime import DateTime, Time
+from graphene import (
+ ID,
+ Boolean,
+ Dynamic,
+ Enum,
+ Field,
+ Float,
+ Int,
+ List,
+ NonNull,
+ String,
+ UUID,
+ DateTime,
+ Date,
+ Time,
+)
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name
@@ -33,37 +46,44 @@ def get_choices(choices):
else:
name = convert_choice_name(value)
while name in converted_names:
- name += '_' + str(len(converted_names))
+ name += "_" + str(len(converted_names))
converted_names.append(name)
description = help_text
yield name, value, description
def convert_django_field_with_choices(field, registry=None):
- choices = getattr(field, 'choices', None)
+ if registry is not None:
+ converted = registry.get_converted_field(field)
+ if converted:
+ return converted
+ choices = getattr(field, "choices", None)
if choices:
meta = field.model._meta
- name = to_camel_case('{}_{}'.format(meta.object_name, field.name))
+ name = to_camel_case("{}_{}".format(meta.object_name, field.name))
choices = list(get_choices(choices))
named_choices = [(c[0], c[1]) for c in choices]
named_choices_descriptions = {c[0]: c[2] for c in choices}
class EnumWithDescriptionsType(object):
-
@property
def description(self):
return named_choices_descriptions[self.name]
enum = Enum(name, list(named_choices), type=EnumWithDescriptionsType)
- return enum(description=field.help_text, required=not field.null)
- return convert_django_field(field, registry)
+ converted = enum(description=field.help_text, required=not field.null)
+ else:
+ converted = convert_django_field(field, registry)
+ if registry is not None:
+ registry.register_converted_field(field, converted)
+ return converted
@singledispatch
def convert_django_field(field, registry=None):
raise Exception(
- "Don't know how to convert the Django field %s (%s)" %
- (field, field.__class__))
+ "Don't know how to convert the Django field %s (%s)" % (field, field.__class__)
+ )
@convert_django_field.register(models.CharField)
@@ -73,6 +93,7 @@ def convert_django_field(field, registry=None):
@convert_django_field.register(models.URLField)
@convert_django_field.register(models.GenericIPAddressField)
@convert_django_field.register(models.FileField)
+@convert_django_field.register(models.FilePathField)
def convert_field_to_string(field, registry=None):
return String(description=field.help_text, required=not field.null)
@@ -113,9 +134,14 @@ def convert_field_to_float(field, registry=None):
return Float(description=field.help_text, required=not field.null)
+@convert_django_field.register(models.DateTimeField)
+def convert_datetime_to_string(field, registry=None):
+ return DateTime(description=field.help_text, required=not field.null)
+
+
@convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None):
- return DateTime(description=field.help_text, required=not field.null)
+ return Date(description=field.help_text, required=not field.null)
@convert_django_field.register(models.TimeField)
@@ -134,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
# We do this for a bug in Django 1.8, where null attr
# is not available in the OneToOneRel instance
- null = getattr(field, 'null', True)
+ null = getattr(field, "null", True)
return Field(_type, required=not null)
return Dynamic(dynamic_type)
@@ -158,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None):
# defined filter_fields in the DjangoObjectType Meta
if _type._meta.filter_fields:
from .filter.fields import DjangoFilterConnectionField
+
return DjangoFilterConnectionField(_type)
return DjangoConnectionField(_type)
diff --git a/graphene_django/debug/__init__.py b/graphene_django/debug/__init__.py
index cd5015e..3e078da 100644
--- a/graphene_django/debug/__init__.py
+++ b/graphene_django/debug/__init__.py
@@ -1,4 +1,4 @@
from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug
-__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
+__all__ = ["DjangoDebugMiddleware", "DjangoDebug"]
diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py
index 2b11f7e..48d471f 100644
--- a/graphene_django/debug/middleware.py
+++ b/graphene_django/debug/middleware.py
@@ -7,7 +7,6 @@ from .types import DjangoDebug
class DjangoDebugContext(object):
-
def __init__(self):
self.debug_promise = None
self.promises = []
@@ -38,20 +37,21 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object):
-
def resolve(self, next, root, info, **args):
context = info.context
- django_debug = getattr(context, 'django_debug', None)
+ django_debug = getattr(context, "django_debug", None)
if not django_debug:
if context is None:
- raise Exception('DjangoDebug cannot be executed in None contexts')
+ raise Exception("DjangoDebug cannot be executed in None contexts")
try:
context.django_debug = DjangoDebugContext()
except Exception:
- raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
- context.__class__.__name__
- ))
- if info.schema.get_type('DjangoDebug') == info.return_type:
+ raise Exception(
+ "DjangoDebug need the context to be writable, context received: {}.".format(
+ context.__class__.__name__
+ )
+ )
+ if info.schema.get_type("DjangoDebug") == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, info, **args)
context.django_debug.add_promise(promise)
diff --git a/graphene_django/debug/sql/tracking.py b/graphene_django/debug/sql/tracking.py
index 9d14e4b..f96583b 100644
--- a/graphene_django/debug/sql/tracking.py
+++ b/graphene_django/debug/sql/tracking.py
@@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception):
class ThreadLocalState(local):
-
def __init__(self):
self.enabled = True
@@ -35,7 +34,7 @@ recording = state.recording # export function
def wrap_cursor(connection, panel):
- if not hasattr(connection, '_graphene_cursor'):
+ if not hasattr(connection, "_graphene_cursor"):
connection._graphene_cursor = connection.cursor
def cursor():
@@ -46,7 +45,7 @@ def wrap_cursor(connection, panel):
def unwrap_cursor(connection):
- if hasattr(connection, '_graphene_cursor'):
+ if hasattr(connection, "_graphene_cursor"):
previous_cursor = connection._graphene_cursor
connection.cursor = previous_cursor
del connection._graphene_cursor
@@ -87,15 +86,14 @@ class NormalCursorWrapper(object):
if not params:
return params
if isinstance(params, dict):
- return dict((key, self._quote_expr(value))
- for key, value in params.items())
+ return dict((key, self._quote_expr(value)) for key, value in params.items())
return list(map(self._quote_expr, params))
def _decode(self, param):
try:
return force_text(param, strings_only=True)
except UnicodeDecodeError:
- return '(encoded string)'
+ return "(encoded string)"
def _record(self, method, sql, params):
start_time = time()
@@ -103,45 +101,48 @@ class NormalCursorWrapper(object):
return method(sql, params)
finally:
stop_time = time()
- duration = (stop_time - start_time)
- _params = ''
+ duration = stop_time - start_time
+ _params = ""
try:
_params = json.dumps(list(map(self._decode, params)))
except Exception:
pass # object not JSON serializable
- alias = getattr(self.db, 'alias', 'default')
+ alias = getattr(self.db, "alias", "default")
conn = self.db.connection
- vendor = getattr(conn, 'vendor', 'unknown')
+ vendor = getattr(conn, "vendor", "unknown")
params = {
- 'vendor': vendor,
- 'alias': alias,
- 'sql': self.db.ops.last_executed_query(
- self.cursor, sql, self._quote_params(params)),
- 'duration': duration,
- 'raw_sql': sql,
- 'params': _params,
- 'start_time': start_time,
- 'stop_time': stop_time,
- 'is_slow': duration > 10,
- 'is_select': sql.lower().strip().startswith('select'),
+ "vendor": vendor,
+ "alias": alias,
+ "sql": self.db.ops.last_executed_query(
+ self.cursor, sql, self._quote_params(params)
+ ),
+ "duration": duration,
+ "raw_sql": sql,
+ "params": _params,
+ "start_time": start_time,
+ "stop_time": stop_time,
+ "is_slow": duration > 10,
+ "is_select": sql.lower().strip().startswith("select"),
}
- if vendor == 'postgresql':
+ if vendor == "postgresql":
# If an erroneous query was ran on the connection, it might
# be in a state where checking isolation_level raises an
# exception.
try:
iso_level = conn.isolation_level
except conn.InternalError:
- iso_level = 'unknown'
- params.update({
- 'trans_id': self.logger.get_transaction_id(alias),
- 'trans_status': conn.get_transaction_status(),
- 'iso_level': iso_level,
- 'encoding': conn.encoding,
- })
+ iso_level = "unknown"
+ params.update(
+ {
+ "trans_id": self.logger.get_transaction_id(alias),
+ "trans_status": conn.get_transaction_status(),
+ "iso_level": iso_level,
+ "encoding": conn.encoding,
+ }
+ )
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility
diff --git a/graphene_django/debug/sql/types.py b/graphene_django/debug/sql/types.py
index 6ae4d31..850ced4 100644
--- a/graphene_django/debug/sql/types.py
+++ b/graphene_django/debug/sql/types.py
@@ -2,19 +2,53 @@ from graphene import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType):
- vendor = String()
- alias = String()
- sql = String()
- duration = Float()
- raw_sql = String()
- params = String()
- start_time = Float()
- stop_time = Float()
- is_slow = Boolean()
- is_select = Boolean()
+ class Meta:
+ description = (
+ "Represents a single database query made to a Django managed DB."
+ )
+
+ vendor = String(
+ required=True,
+ description=(
+ "The type of database being used (e.g. postrgesql, mysql, sqlite)."
+ ),
+ )
+ alias = String(
+ required=True,
+ description="The Django database alias (e.g. 'default').",
+ )
+ sql = String(description="The actual SQL sent to this database.")
+ duration = Float(
+ required=True,
+ description="Duration of this database query in seconds.",
+ )
+ raw_sql = String(
+ required=True,
+ description="The raw SQL of this query, without params.",
+ )
+ params = String(
+ required=True,
+ description="JSON encoded database query parameters.",
+ )
+ start_time = Float(
+ required=True,
+ description="Start time of this database query.",
+ )
+ stop_time = Float(
+ required=True,
+ description="Stop time of this database query.",
+ )
+ is_slow = Boolean(
+ required=True,
+ description="Whether this database query took more than 10 seconds.",
+ )
+ is_select = Boolean(
+ required=True,
+ description="Whether this database query was a SELECT.",
+ )
# Postgres
- trans_id = String()
- trans_status = String()
- iso_level = String()
- encoding = String()
+ trans_id = String(description="Postgres transaction ID if available.")
+ trans_status = String(description="Postgres transaction status if available.")
+ iso_level = String(description="Postgres isolation level if available.")
+ encoding = String(description="Postgres connection encoding if available.")
diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py
index 72747b2..f2ef096 100644
--- a/graphene_django/debug/tests/test_query.py
+++ b/graphene_django/debug/tests/test_query.py
@@ -12,31 +12,31 @@ from ..types import DjangoDebug
class context(object):
pass
+
# from examples.starwars_django.models import Character
pytestmark = pytest.mark.django_db
def test_should_query_field():
- r1 = Reporter(last_name='ABA')
+ r1 = Reporter(last_name="ABA")
r1.save()
- r2 = Reporter(last_name='Griffin')
+ r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
- debug = graphene.Field(DjangoDebug, name='__debug')
+ debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_reporter(self, info, **args):
return Reporter.objects.first()
- query = '''
+ query = """
query ReporterQuery {
reporter {
lastName
@@ -47,43 +47,40 @@ def test_should_query_field():
}
}
}
- '''
+ """
expected = {
- 'reporter': {
- 'lastName': 'ABA',
+ "reporter": {"lastName": "ABA"},
+ "__debug": {
+ "sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}]
},
- '__debug': {
- 'sql': [{
- 'rawSql': str(Reporter.objects.order_by('pk')[:1].query)
- }]
- }
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
+ result = schema.execute(
+ query, context_value=context(), middleware=[DjangoDebugMiddleware()]
+ )
assert not result.errors
assert result.data == expected
def test_should_query_list():
- r1 = Reporter(last_name='ABA')
+ r1 = Reporter(last_name="ABA")
r1.save()
- r2 = Reporter(last_name='Griffin')
+ r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = graphene.List(ReporterType)
- debug = graphene.Field(DjangoDebug, name='__debug')
+ debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
- query = '''
+ query = """
query ReporterQuery {
allReporters {
lastName
@@ -94,45 +91,38 @@ def test_should_query_list():
}
}
}
- '''
+ """
expected = {
- 'allReporters': [{
- 'lastName': 'ABA',
- }, {
- 'lastName': 'Griffin',
- }],
- '__debug': {
- 'sql': [{
- 'rawSql': str(Reporter.objects.all().query)
- }]
- }
+ "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}],
+ "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]},
}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
+ result = schema.execute(
+ query, context_value=context(), middleware=[DjangoDebugMiddleware()]
+ )
assert not result.errors
assert result.data == expected
def test_should_query_connection():
- r1 = Reporter(last_name='ABA')
+ r1 = Reporter(last_name="ABA")
r1.save()
- r2 = Reporter(last_name='Griffin')
+ r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
- debug = graphene.Field(DjangoDebug, name='__debug')
+ debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
- query = '''
+ query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@@ -147,48 +137,41 @@ def test_should_query_connection():
}
}
}
- '''
- expected = {
- 'allReporters': {
- 'edges': [{
- 'node': {
- 'lastName': 'ABA',
- }
- }]
- },
- }
+ """
+ expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
+ result = schema.execute(
+ query, context_value=context(), middleware=[DjangoDebugMiddleware()]
+ )
assert not result.errors
- assert result.data['allReporters'] == expected['allReporters']
- assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
+ assert result.data["allReporters"] == expected["allReporters"]
+ assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
- assert result.data['__debug']['sql'][1]['rawSql'] == query
+ assert result.data["__debug"]["sql"][1]["rawSql"] == query
def test_should_query_connectionfilter():
from ...filter import DjangoFilterConnectionField
- r1 = Reporter(last_name='ABA')
+ r1 = Reporter(last_name="ABA")
r1.save()
- r2 = Reporter(last_name='Griffin')
+ r2 = Reporter(last_name="Griffin")
r2.save()
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(graphene.ObjectType):
- all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name'])
+ all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"])
s = graphene.String(resolver=lambda *_: "S")
- debug = graphene.Field(DjangoDebug, name='__debug')
+ debug = graphene.Field(DjangoDebug, name="__debug")
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
- query = '''
+ query = """
query ReporterQuery {
allReporters(first:1) {
edges {
@@ -203,20 +186,14 @@ def test_should_query_connectionfilter():
}
}
}
- '''
- expected = {
- 'allReporters': {
- 'edges': [{
- 'node': {
- 'lastName': 'ABA',
- }
- }]
- },
- }
+ """
+ expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}}
schema = graphene.Schema(query=Query)
- result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()])
+ result = schema.execute(
+ query, context_value=context(), middleware=[DjangoDebugMiddleware()]
+ )
assert not result.errors
- assert result.data['allReporters'] == expected['allReporters']
- assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
+ assert result.data["allReporters"] == expected["allReporters"]
+ assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"]
query = str(Reporter.objects.all()[:1].query)
- assert result.data['__debug']['sql'][1]['rawSql'] == query
+ assert result.data["__debug"]["sql"][1]["rawSql"] == query
diff --git a/graphene_django/debug/types.py b/graphene_django/debug/types.py
index 0d3701d..cda5725 100644
--- a/graphene_django/debug/types.py
+++ b/graphene_django/debug/types.py
@@ -4,4 +4,10 @@ from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType):
- sql = List(DjangoDebugSQL)
+ class Meta:
+ description = "Debugging information for the current query."
+
+ sql = List(
+ DjangoDebugSQL,
+ description="Executed SQL queries for this API query.",
+ )
diff --git a/graphene_django/fields.py b/graphene_django/fields.py
index aa7f124..1ecce45 100644
--- a/graphene_django/fields.py
+++ b/graphene_django/fields.py
@@ -13,7 +13,6 @@ from .utils import maybe_queryset
class DjangoListField(Field):
-
def __init__(self, _type, *args, **kwargs):
super(DjangoListField, self).__init__(List(_type), *args, **kwargs)
@@ -30,25 +29,28 @@ class DjangoListField(Field):
class DjangoConnectionField(ConnectionField):
-
def __init__(self, *args, **kwargs):
- self.on = kwargs.pop('on', False)
+ self.on = kwargs.pop("on", False)
self.max_limit = kwargs.pop(
- 'max_limit',
- graphene_settings.RELAY_CONNECTION_MAX_LIMIT
+ "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
)
self.enforce_first_or_last = kwargs.pop(
- 'enforce_first_or_last',
- graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST
+ "enforce_first_or_last",
+ graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
)
super(DjangoConnectionField, self).__init__(*args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
+
_type = super(ConnectionField, self).type
- assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types"
- assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
+ assert issubclass(
+ _type, DjangoObjectType
+ ), "DjangoConnectionField only accepts DjangoObjectType types"
+ assert _type._meta.connection, "The type {} doesn't have a connection".format(
+ _type.__name__
+ )
return _type._meta.connection
@property
@@ -67,6 +69,10 @@ class DjangoConnectionField(ConnectionField):
@classmethod
def merge_querysets(cls, default_queryset, queryset):
+ if default_queryset.query.distinct and not queryset.query.distinct:
+ queryset = queryset.distinct()
+ elif queryset.query.distinct and not default_queryset.query.distinct:
+ default_queryset = default_queryset.distinct()
return queryset & default_queryset
@classmethod
@@ -96,28 +102,37 @@ class DjangoConnectionField(ConnectionField):
return connection
@classmethod
- def connection_resolver(cls, resolver, connection, default_manager, max_limit,
- enforce_first_or_last, root, info, **args):
- first = args.get('first')
- last = args.get('last')
+ def connection_resolver(
+ cls,
+ resolver,
+ connection,
+ default_manager,
+ max_limit,
+ enforce_first_or_last,
+ root,
+ info,
+ **args
+ ):
+ first = args.get("first")
+ last = args.get("last")
if enforce_first_or_last:
assert first or last, (
- 'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
+ "You must provide a `first` or `last` value to properly paginate the `{}` connection."
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
- 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
+ "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
).format(first, info.field_name, max_limit)
- args['first'] = min(first, max_limit)
+ args["first"] = min(first, max_limit)
if last:
assert last <= max_limit, (
- 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
- ).format(first, info.field_name, max_limit)
- args['last'] = min(last, max_limit)
+ "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
+ ).format(last, info.field_name, max_limit)
+ args["last"] = min(last, max_limit)
iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
@@ -134,5 +149,5 @@ class DjangoConnectionField(ConnectionField):
self.type,
self.get_manager(),
self.max_limit,
- self.enforce_first_or_last
+ self.enforce_first_or_last,
)
diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py
index 24fae60..daafe56 100644
--- a/graphene_django/filter/__init__.py
+++ b/graphene_django/filter/__init__.py
@@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED:
warnings.warn(
"Use of django filtering requires the django-filter package "
- "be installed. You can do so using `pip install django-filter`", ImportWarning
+ "be installed. You can do so using `pip install django-filter`",
+ ImportWarning,
)
else:
from .fields import DjangoFilterConnectionField
from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter
- __all__ = ['DjangoFilterConnectionField',
- 'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter']
+ __all__ = [
+ "DjangoFilterConnectionField",
+ "GlobalIDFilter",
+ "GlobalIDMultipleChoiceFilter",
+ ]
diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py
index a80d8d7..cb42543 100644
--- a/graphene_django/filter/fields.py
+++ b/graphene_django/filter/fields.py
@@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class
class DjangoFilterConnectionField(DjangoConnectionField):
-
- def __init__(self, type, fields=None, order_by=None,
- extra_filter_meta=None, filterset_class=None,
- *args, **kwargs):
+ def __init__(
+ self,
+ type,
+ fields=None,
+ order_by=None,
+ extra_filter_meta=None,
+ filterset_class=None,
+ *args,
+ **kwargs
+ ):
self._fields = fields
self._provided_filterset_class = filterset_class
self._filterset_class = None
@@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self):
if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields
- meta = dict(model=self.model,
- fields=fields)
+ meta = dict(model=self.model, fields=fields)
if self._extra_filter_meta:
meta.update(self._extra_filter_meta)
- self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta)
+ self._filterset_class = get_filterset_class(
+ self._provided_filterset_class, **meta
+ )
return self._filterset_class
@@ -43,8 +50,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
- @staticmethod
- def merge_querysets(default_queryset, queryset):
+ @classmethod
+ def merge_querysets(cls, default_queryset, queryset):
# There could be the case where the default queryset (returned from the filterclass)
# and the resolver queryset have some limits on it.
# We only would be able to apply one of those, but not both
@@ -52,27 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField):
# See related PR: https://github.com/graphql-python/graphene-django/pull/126
- assert not (default_queryset.query.low_mark and queryset.query.low_mark), (
- 'Received two sliced querysets (low mark) in the connection, please slice only in one.'
- )
- assert not (default_queryset.query.high_mark and queryset.query.high_mark), (
- 'Received two sliced querysets (high mark) in the connection, please slice only in one.'
- )
+ assert not (
+ default_queryset.query.low_mark and queryset.query.low_mark
+ ), "Received two sliced querysets (low mark) in the connection, please slice only in one."
+ assert not (
+ default_queryset.query.high_mark and queryset.query.high_mark
+ ), "Received two sliced querysets (high mark) in the connection, please slice only in one."
low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits()
- queryset = default_queryset & queryset
+ queryset = super(DjangoFilterConnectionField, cls).merge_querysets(
+ default_queryset, queryset
+ )
queryset.query.set_limits(low, high)
return queryset
@classmethod
- def connection_resolver(cls, resolver, connection, default_manager, max_limit,
- enforce_first_or_last, filterset_class, filtering_args,
- root, info, **args):
+ def connection_resolver(
+ cls,
+ resolver,
+ connection,
+ default_manager,
+ max_limit,
+ enforce_first_or_last,
+ filterset_class,
+ filtering_args,
+ root,
+ info,
+ **args
+ ):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class(
data=filter_kwargs,
- queryset=default_manager.get_queryset()
+ queryset=default_manager.get_queryset(),
+ request=info.context,
).qs
return super(DjangoFilterConnectionField, cls).connection_resolver(
@@ -95,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.max_limit,
self.enforce_first_or_last,
self.filterset_class,
- self.filtering_args
+ self.filtering_args,
)
diff --git a/graphene_django/filter/filterset.py b/graphene_django/filter/filterset.py
index c716b05..4059083 100644
--- a/graphene_django/filter/filterset.py
+++ b/graphene_django/filter/filterset.py
@@ -1,8 +1,7 @@
import itertools
from django.db import models
-from django.utils.text import capfirst
-from django_filters import Filter, MultipleChoiceFilter
+from django_filters import Filter, MultipleChoiceFilter, VERSION
from django_filters.filterset import BaseFilterSet, FilterSet
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS
@@ -15,7 +14,10 @@ class GlobalIDFilter(Filter):
field_class = GlobalIDFormField
def filter(self, qs, value):
- _type, _id = from_global_id(value)
+ """ Convert the filter value to a primary key before filtering """
+ _id = None
+ if value is not None:
+ _, _id = from_global_id(value)
return super(GlobalIDFilter, self).filter(qs, _id)
@@ -28,71 +30,76 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter):
GRAPHENE_FILTER_SET_OVERRIDES = {
- models.AutoField: {
- 'filter_class': GlobalIDFilter,
- },
- models.OneToOneField: {
- 'filter_class': GlobalIDFilter,
- },
- models.ForeignKey: {
- 'filter_class': GlobalIDFilter,
- },
- models.ManyToManyField: {
- 'filter_class': GlobalIDMultipleChoiceFilter,
- }
+ models.AutoField: {"filter_class": GlobalIDFilter},
+ models.OneToOneField: {"filter_class": GlobalIDFilter},
+ models.ForeignKey: {"filter_class": GlobalIDFilter},
+ models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter},
+ models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter},
+ models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter},
}
class GrapheneFilterSetMixin(BaseFilterSet):
- FILTER_DEFAULTS = dict(itertools.chain(
- FILTER_FOR_DBFIELD_DEFAULTS.items(),
- GRAPHENE_FILTER_SET_OVERRIDES.items()
- ))
+ """ A django_filters.filterset.BaseFilterSet with default filter overrides
+ to handle global IDs """
- @classmethod
- def filter_for_reverse_field(cls, f, name):
- """Handles retrieving filters for reverse relationships
+ FILTER_DEFAULTS = dict(
+ itertools.chain(
+ FILTER_FOR_DBFIELD_DEFAULTS.items(),
+ GRAPHENE_FILTER_SET_OVERRIDES.items()
+ )
+ )
- We override the default implementation so that we can handle
- Global IDs (the default implementation expects database
- primary keys)
- """
- rel = f.field.rel
- default = {
- 'name': name,
- 'label': capfirst(rel.related_name)
- }
- if rel.multiple:
- # For to-many relationships
- return GlobalIDMultipleChoiceFilter(**default)
- else:
- # For to-one relationships
- return GlobalIDFilter(**default)
+
+# To support a Django 1.11 + Python 2.7 combination django-filter must be
+# < 2.x.x. To support the earlier version of django-filter, the
+# filter_for_reverse_field method must be present on GrapheneFilterSetMixin and
+# must not be present for later versions of django-filter.
+if VERSION[0] < 2:
+ from django.utils.text import capfirst
+
+ class GrapheneFilterSetMixinPython2(GrapheneFilterSetMixin):
+
+ @classmethod
+ def filter_for_reverse_field(cls, f, name):
+ """Handles retrieving filters for reverse relationships
+ We override the default implementation so that we can handle
+ Global IDs (the default implementation expects database
+ primary keys)
+ """
+ try:
+ rel = f.field.remote_field
+ except AttributeError:
+ rel = f.field.rel
+ default = {"name": name, "label": capfirst(rel.related_name)}
+ if rel.multiple:
+ # For to-many relationships
+ return GlobalIDMultipleChoiceFilter(**default)
+ else:
+ # For to-one relationships
+ return GlobalIDFilter(**default)
+
+ GrapheneFilterSetMixin = GrapheneFilterSetMixinPython2
def setup_filterset(filterset_class):
""" Wrap a provided filterset in Graphene-specific functionality
"""
return type(
- 'Graphene{}'.format(filterset_class.__name__),
+ "Graphene{}".format(filterset_class.__name__),
(filterset_class, GrapheneFilterSetMixin),
{},
)
-def custom_filterset_factory(model, filterset_base_class=FilterSet,
- **meta):
+def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta):
""" Create a filterset for the given model using the provided meta data
"""
- meta.update({
- 'model': model,
- })
- meta_class = type(str('Meta'), (object,), meta)
+ meta.update({"model": model})
+ meta_class = type(str("Meta"), (object,), meta)
filterset = type(
- str('%sFilterSet' % model._meta.object_name),
+ str("%sFilterSet" % model._meta.object_name),
(filterset_base_class, GrapheneFilterSetMixin),
- {
- 'Meta': meta_class
- }
+ {"Meta": meta_class},
)
return filterset
diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py
index 4a3fbaa..359d2ba 100644
--- a/graphene_django/filter/tests/filters.py
+++ b/graphene_django/filter/tests/filters.py
@@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter
class ArticleFilter(django_filters.FilterSet):
-
class Meta:
model = Article
fields = {
- 'headline': ['exact', 'icontains'],
- 'pub_date': ['gt', 'lt', 'exact'],
- 'reporter': ['exact'],
+ "headline": ["exact", "icontains"],
+ "pub_date": ["gt", "lt", "exact"],
+ "reporter": ["exact"],
}
- order_by = OrderingFilter(fields=('pub_date',))
+ order_by = OrderingFilter(fields=("pub_date",))
class ReporterFilter(django_filters.FilterSet):
-
class Meta:
model = Reporter
- fields = ['first_name', 'last_name', 'email', 'pets']
+ fields = ["first_name", "last_name", "email", "pets"]
- order_by = OrderingFilter(fields=('pub_date',))
+ order_by = OrderingFilter(fields=("pub_date",))
class PetFilter(django_filters.FilterSet):
-
class Meta:
model = Pet
- fields = ['name']
+ fields = ["name"]
diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py
index 9a0ba21..f9ef0ae 100644
--- a/graphene_django/filter/tests/test_fields.py
+++ b/graphene_django/filter/tests/test_fields.py
@@ -2,50 +2,60 @@ from datetime import datetime
import pytest
-from graphene import Field, ObjectType, Schema, Argument, Float
+from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node
from graphene_django import DjangoObjectType
-from graphene_django.forms import (GlobalIDFormField,
- GlobalIDMultipleChoiceField)
+from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED
+# for annotation test
+from django.db.models import TextField, Value
+from django.db.models.functions import Concat
+
pytestmark = []
if DJANGO_FILTER_INSTALLED:
import django_filters
from django_filters import FilterSet, NumberFilter
- from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField,
- GlobalIDMultipleChoiceFilter)
- from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter
+ from graphene_django.filter import (
+ GlobalIDFilter,
+ DjangoFilterConnectionField,
+ GlobalIDMultipleChoiceFilter,
+ )
+ from graphene_django.filter.tests.filters import (
+ ArticleFilter,
+ PetFilter,
+ ReporterFilter,
+ )
else:
- pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible'))
+ pytestmark.append(
+ pytest.mark.skipif(
+ True, reason="django_filters not installed or not compatible"
+ )
+ )
pytestmark.append(pytest.mark.django_db)
if DJANGO_FILTER_INSTALLED:
- class ArticleNode(DjangoObjectType):
+ class ArticleNode(DjangoObjectType):
class Meta:
model = Article
- interfaces = (Node, )
- filter_fields = ('headline', )
-
+ interfaces = (Node,)
+ filter_fields = ("headline",)
class ReporterNode(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
-
+ interfaces = (Node,)
class PetNode(DjangoObjectType):
-
class Meta:
model = Pet
- interfaces = (Node, )
+ interfaces = (Node,)
# schema = Schema()
@@ -55,58 +65,47 @@ def get_args(field):
def assert_arguments(field, *arguments):
- ignore = ('after', 'before', 'first', 'last', 'order_by')
+ ignore = ("after", "before", "first", "last", "order_by")
args = get_args(field)
- actual = [
- name
- for name in args
- if name not in ignore and not name.startswith('_')
- ]
- assert set(arguments) == set(actual), \
- 'Expected arguments ({}) did not match actual ({})'.format(
- arguments,
- actual
- )
+ actual = [name for name in args if name not in ignore and not name.startswith("_")]
+ assert set(arguments) == set(
+ actual
+ ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual)
def assert_orderable(field):
args = get_args(field)
- assert 'order_by' in args, \
- 'Field cannot be ordered'
+ assert "order_by" in args, "Field cannot be ordered"
def assert_not_orderable(field):
args = get_args(field)
- assert 'order_by' not in args, \
- 'Field can be ordered'
+ assert "order_by" not in args, "Field can be ordered"
def test_filter_explicit_filterset_arguments():
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter)
- assert_arguments(field,
- 'headline', 'headline__icontains',
- 'pub_date', 'pub_date__gt', 'pub_date__lt',
- 'reporter',
- )
+ assert_arguments(
+ field,
+ "headline",
+ "headline__icontains",
+ "pub_date",
+ "pub_date__gt",
+ "pub_date__lt",
+ "reporter",
+ )
def test_filter_shortcut_filterset_arguments_list():
- field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter'])
- assert_arguments(field,
- 'pub_date',
- 'reporter',
- )
+ field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"])
+ assert_arguments(field, "pub_date", "reporter")
def test_filter_shortcut_filterset_arguments_dict():
- field = DjangoFilterConnectionField(ArticleNode, fields={
- 'headline': ['exact', 'icontains'],
- 'reporter': ['exact'],
- })
- assert_arguments(field,
- 'headline', 'headline__icontains',
- 'reporter',
- )
+ field = DjangoFilterConnectionField(
+ ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]}
+ )
+ assert_arguments(field, "headline", "headline__icontains", "reporter")
def test_filter_explicit_filterset_orderable():
@@ -130,39 +129,91 @@ def test_filter_explicit_filterset_not_orderable():
def test_filter_shortcut_filterset_extra_meta():
- field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={
- 'exclude': ('headline', )
- })
- assert 'headline' not in field.filterset_class.get_fields()
+ field = DjangoFilterConnectionField(
+ ArticleNode, extra_filter_meta={"exclude": ("headline",)}
+ )
+ assert "headline" not in field.filterset_class.get_fields()
+
+
+def test_filter_shortcut_filterset_context():
+ class ArticleContextFilter(django_filters.FilterSet):
+ class Meta:
+ model = Article
+ exclude = set()
+
+ @property
+ def qs(self):
+ qs = super(ArticleContextFilter, self).qs
+ return qs.filter(reporter=self.request.reporter)
+
+ class Query(ObjectType):
+ context_articles = DjangoFilterConnectionField(
+ ArticleNode, filterset_class=ArticleContextFilter
+ )
+
+ r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
+ r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
+ Article.objects.create(
+ headline="a1",
+ pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
+ reporter=r1,
+ editor=r1,
+ )
+ Article.objects.create(
+ headline="a2",
+ pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
+ reporter=r2,
+ editor=r2,
+ )
+
+ class context(object):
+ reporter = r2
+
+ query = """
+ query {
+ contextArticles {
+ edges {
+ node {
+ headline
+ }
+ }
+ }
+ }
+ """
+ schema = Schema(query=Query)
+ result = schema.execute(query, context_value=context())
+ assert not result.errors
+
+ assert len(result.data["contextArticles"]["edges"]) == 1
+ assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2"
def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
- filter_fields = ['first_name', 'articles']
+ interfaces = (Node,)
+ filter_fields = ["first_name", "articles"]
field = DjangoFilterConnectionField(ReporterFilterNode)
- assert_arguments(field, 'first_name', 'articles')
+ assert_arguments(field, "first_name", "articles")
assert_not_orderable(field)
def test_filter_filterset_information_on_meta_related():
class ReporterFilterNode(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
- filter_fields = ['first_name', 'articles']
+ interfaces = (Node,)
+ filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
-
class Meta:
model = Article
- interfaces = (Node, )
- filter_fields = ['headline', 'reporter']
+ interfaces = (Node,)
+ filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@@ -171,25 +222,23 @@ def test_filter_filterset_information_on_meta_related():
article = Field(ArticleFilterNode)
schema = Schema(query=Query)
- articles_field = ReporterFilterNode._meta.fields['articles'].get_type()
- assert_arguments(articles_field, 'headline', 'reporter')
+ articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
+ assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field)
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
- filter_fields = ['first_name', 'articles']
+ interfaces = (Node,)
+ filter_fields = ["first_name", "articles"]
class ArticleFilterNode(DjangoObjectType):
-
class Meta:
- interfaces = (Node, )
+ interfaces = (Node,)
model = Article
- filter_fields = ['headline', 'reporter']
+ filter_fields = ["headline", "reporter"]
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
@@ -197,12 +246,22 @@ def test_filter_filterset_related_results():
reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode)
- r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
- r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
- Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1)
- Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2)
+ r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com")
+ r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com")
+ Article.objects.create(
+ headline="a1",
+ pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
+ reporter=r1,
+ )
+ Article.objects.create(
+ headline="a2",
+ pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
+ reporter=r2,
+ )
- query = '''
+ query = """
query {
allReporters {
edges {
@@ -218,123 +277,134 @@ def test_filter_filterset_related_results():
}
}
}
- '''
+ """
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get back a single article for each reporter
- assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1
- assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1
+ assert (
+ len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1
+ )
+ assert (
+ len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1
+ )
def test_global_id_field_implicit():
- field = DjangoFilterConnectionField(ArticleNode, fields=['id'])
+ field = DjangoFilterConnectionField(ArticleNode, fields=["id"])
filterset_class = field.filterset_class
- id_filter = filterset_class.base_filters['id']
+ id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_field_explicit():
class ArticleIdFilter(django_filters.FilterSet):
-
class Meta:
model = Article
- fields = ['id']
+ fields = ["id"]
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
filterset_class = field.filterset_class
- id_filter = filterset_class.base_filters['id']
+ id_filter = filterset_class.base_filters["id"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_filterset_descriptions():
class ArticleIdFilter(django_filters.FilterSet):
-
class Meta:
model = Article
- fields = ['id']
+ fields = ["id"]
- max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time")
+ max_time = django_filters.NumberFilter(
+ method="filter_max_time", label="The maximum time"
+ )
field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter)
- max_time = field.args['max_time']
+ max_time = field.args["max_time"]
assert isinstance(max_time, Argument)
assert max_time.type == Float
- assert max_time.description == 'The maximum time'
+ assert max_time.description == "The maximum time"
def test_global_id_field_relation():
- field = DjangoFilterConnectionField(ArticleNode, fields=['reporter'])
+ field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"])
filterset_class = field.filterset_class
- id_filter = filterset_class.base_filters['reporter']
+ id_filter = filterset_class.base_filters["reporter"]
assert isinstance(id_filter, GlobalIDFilter)
assert id_filter.field_class == GlobalIDFormField
def test_global_id_multiple_field_implicit():
- field = DjangoFilterConnectionField(ReporterNode, fields=['pets'])
+ field = DjangoFilterConnectionField(ReporterNode, fields=["pets"])
filterset_class = field.filterset_class
- multiple_filter = filterset_class.base_filters['pets']
+ multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit():
class ReporterPetsFilter(django_filters.FilterSet):
-
class Meta:
model = Reporter
- fields = ['pets']
+ fields = ["pets"]
- field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
+ field = DjangoFilterConnectionField(
+ ReporterNode, filterset_class=ReporterPetsFilter
+ )
filterset_class = field.filterset_class
- multiple_filter = filterset_class.base_filters['pets']
+ multiple_filter = filterset_class.base_filters["pets"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_implicit_reverse():
- field = DjangoFilterConnectionField(ReporterNode, fields=['articles'])
+ field = DjangoFilterConnectionField(ReporterNode, fields=["articles"])
filterset_class = field.filterset_class
- multiple_filter = filterset_class.base_filters['articles']
+ multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_global_id_multiple_field_explicit_reverse():
class ReporterPetsFilter(django_filters.FilterSet):
-
class Meta:
model = Reporter
- fields = ['articles']
+ fields = ["articles"]
- field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter)
+ field = DjangoFilterConnectionField(
+ ReporterNode, filterset_class=ReporterPetsFilter
+ )
filterset_class = field.filterset_class
- multiple_filter = filterset_class.base_filters['articles']
+ multiple_filter = filterset_class.base_filters["articles"]
assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter)
assert multiple_filter.field_class == GlobalIDMultipleChoiceField
def test_filter_filterset_related_results():
class ReporterFilterNode(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
- filter_fields = {
- 'first_name': ['icontains']
- }
+ interfaces = (Node,)
+ filter_fields = {"first_name": ["icontains"]}
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
- r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com')
- r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com')
- r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com')
+ r1 = Reporter.objects.create(
+ first_name="A test user", last_name="Last Name", email="test1@test.com"
+ )
+ r2 = Reporter.objects.create(
+ first_name="Other test user",
+ last_name="Other Last Name",
+ email="test2@test.com",
+ )
+ r3 = Reporter.objects.create(
+ first_name="Random", last_name="RandomLast", email="random@test.com"
+ )
- query = '''
+ query = """
query {
allReporters(firstName_Icontains: "test") {
edges {
@@ -344,12 +414,12 @@ def test_filter_filterset_related_results():
}
}
}
- '''
+ """
schema = Schema(query=Query)
result = schema.execute(query)
assert not result.errors
# We should only get two reporters
- assert len(result.data['allReporters']['edges']) == 2
+ assert len(result.data["allReporters"]["edges"]) == 2
def test_recursive_filter_connection():
@@ -361,77 +431,73 @@ def test_recursive_filter_connection():
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterFilterNode)
- assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode
+ assert (
+ ReporterFilterNode._meta.fields["child_reporters"].node_type
+ == ReporterFilterNode
+ )
def test_should_query_filter_node_limit():
class ReporterFilter(FilterSet):
- limit = NumberFilter(method='filter_limit')
+ limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
- fields = ['first_name', ]
+ fields = ["first_name"]
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class ArticleType(DjangoObjectType):
-
class Meta:
model = Article
- interfaces = (Node, )
- filter_fields = ('lang', )
+ interfaces = (Node,)
+ filter_fields = ("lang",)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
- ReporterType,
- filterset_class=ReporterFilter
+ ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
- return Reporter.objects.order_by('a_choice')
+ return Reporter.objects.order_by("a_choice")
Reporter.objects.create(
- first_name='Bob',
- last_name='Doe',
- email='bobdoe@example.com',
- a_choice=2
+ first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
- first_name='John',
- last_name='Doe',
- email='johndoe@example.com',
- a_choice=1
+ first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
Article.objects.create(
- headline='Article Node 1',
+ headline="Article Node 1",
pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
reporter=r,
editor=r,
- lang='es'
+ lang="es",
)
Article.objects.create(
- headline='Article Node 2',
+ headline="Article Node 2",
pub_date=datetime.now(),
+ pub_date_time=datetime.now(),
reporter=r,
editor=r,
- lang='en'
+ lang="en",
)
schema = Schema(query=Query)
- query = '''
+ query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@@ -450,24 +516,23 @@ def test_should_query_filter_node_limit():
}
}
}
- '''
+ """
expected = {
- 'allReporters': {
- 'edges': [{
- 'node': {
- 'id': 'UmVwb3J0ZXJUeXBlOjI=',
- 'firstName': 'John',
- 'articles': {
- 'edges': [{
- 'node': {
- 'id': 'QXJ0aWNsZVR5cGU6MQ==',
- 'lang': 'ES'
- }
- }]
+ "allReporters": {
+ "edges": [
+ {
+ "node": {
+ "id": "UmVwb3J0ZXJUeXBlOjI=",
+ "firstName": "John",
+ "articles": {
+ "edges": [
+ {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}}
+ ]
+ },
}
}
- }]
+ ]
}
}
@@ -478,45 +543,37 @@ def test_should_query_filter_node_limit():
def test_should_query_filter_node_double_limit_raises():
class ReporterFilter(FilterSet):
- limit = NumberFilter(method='filter_limit')
+ limit = NumberFilter(method="filter_limit")
def filter_limit(self, queryset, name, value):
return queryset[:value]
class Meta:
model = Reporter
- fields = ['first_name', ]
+ fields = ["first_name"]
class ReporterType(DjangoObjectType):
-
class Meta:
model = Reporter
- interfaces = (Node, )
+ interfaces = (Node,)
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(
- ReporterType,
- filterset_class=ReporterFilter
+ ReporterType, filterset_class=ReporterFilter
)
def resolve_all_reporters(self, info, **args):
- return Reporter.objects.order_by('a_choice')[:2]
+ return Reporter.objects.order_by("a_choice")[:2]
Reporter.objects.create(
- first_name='Bob',
- last_name='Doe',
- email='bobdoe@example.com',
- a_choice=2
+ first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2
)
r = Reporter.objects.create(
- first_name='John',
- last_name='Doe',
- email='johndoe@example.com',
- a_choice=1
+ first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
schema = Schema(query=Query)
- query = '''
+ query = """
query NodeFilteringQuery {
allReporters(limit: 1) {
edges {
@@ -527,10 +584,116 @@ def test_should_query_filter_node_double_limit_raises():
}
}
}
- '''
+ """
result = schema.execute(query)
assert len(result.errors) == 1
assert str(result.errors[0]) == (
- 'Received two sliced querysets (high mark) in the connection, please slice only in one.'
+ "Received two sliced querysets (high mark) in the connection, please slice only in one."
)
+
+
+def test_order_by_is_perserved():
+ class ReporterType(DjangoObjectType):
+ class Meta:
+ model = Reporter
+ interfaces = (Node,)
+ filter_fields = ()
+
+ class Query(ObjectType):
+ all_reporters = DjangoFilterConnectionField(
+ ReporterType, reverse_order=Boolean()
+ )
+
+ def resolve_all_reporters(self, info, reverse_order=False, **args):
+ reporters = Reporter.objects.order_by("first_name")
+
+ if reverse_order:
+ return reporters.reverse()
+
+ return reporters
+
+ Reporter.objects.create(first_name="b")
+ r = Reporter.objects.create(first_name="a")
+
+ schema = Schema(query=Query)
+ query = """
+ query NodeFilteringQuery {
+ allReporters(first: 1) {
+ edges {
+ node {
+ firstName
+ }
+ }
+ }
+ }
+ """
+ expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}}
+
+ result = schema.execute(query)
+ assert not result.errors
+ assert result.data == expected
+
+ reverse_query = """
+ query NodeFilteringQuery {
+ allReporters(first: 1, reverseOrder: true) {
+ edges {
+ node {
+ firstName
+ }
+ }
+ }
+ }
+ """
+
+ reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}}
+
+ reverse_result = schema.execute(reverse_query)
+
+ assert not reverse_result.errors
+ assert reverse_result.data == reverse_expected
+
+
+def test_annotation_is_perserved():
+ class ReporterType(DjangoObjectType):
+ full_name = String()
+
+ def resolve_full_name(instance, info, **args):
+ return instance.full_name
+
+ class Meta:
+ model = Reporter
+ interfaces = (Node,)
+ filter_fields = ()
+
+ class Query(ObjectType):
+ all_reporters = DjangoFilterConnectionField(ReporterType)
+
+ def resolve_all_reporters(self, info, **args):
+ return Reporter.objects.annotate(
+ full_name=Concat(
+ "first_name", Value(" "), "last_name", output_field=TextField()
+ )
+ )
+
+ Reporter.objects.create(first_name="John", last_name="Doe")
+
+ schema = Schema(query=Query)
+
+ query = """
+ query NodeFilteringQuery {
+ allReporters(first: 1) {
+ edges {
+ node {
+ fullName
+ }
+ }
+ }
+ }
+ """
+ expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}}
+
+ result = schema.execute(query)
+
+ assert not result.errors
+ assert result.data == expected
diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py
index 6b938ce..cfa5621 100644
--- a/graphene_django/filter/utils.py
+++ b/graphene_django/filter/utils.py
@@ -8,7 +8,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
a Graphene Field. These arguments will be available to
filter against in the GraphQL
"""
- from ..form_converter import convert_form_field
+ from ..forms.converter import convert_form_field
args = {}
for name, filter_field in six.iteritems(filterset_class.base_filters):
diff --git a/graphene_django/forms/__init__.py b/graphene_django/forms/__init__.py
new file mode 100644
index 0000000..066eec4
--- /dev/null
+++ b/graphene_django/forms/__init__.py
@@ -0,0 +1 @@
+from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa
diff --git a/graphene_django/form_converter.py b/graphene_django/forms/converter.py
similarity index 71%
rename from graphene_django/form_converter.py
rename to graphene_django/forms/converter.py
index 46a38b3..87180b2 100644
--- a/graphene_django/form_converter.py
+++ b/graphene_django/forms/converter.py
@@ -1,30 +1,24 @@
from django import forms
-from django.forms.fields import BaseTemporalField
+from django.core.exceptions import ImproperlyConfigured
-from graphene import ID, Boolean, Float, Int, List, String, UUID
+from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
-from .utils import import_single_dispatch
+from ..utils import import_single_dispatch
+
singledispatch = import_single_dispatch()
-try:
- UUIDField = forms.UUIDField
-except AttributeError:
- class UUIDField(object):
- pass
-
@singledispatch
def convert_form_field(field):
- raise Exception(
+ raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) "
- "to Graphene type" %
- (field, field.__class__)
+ "to Graphene type" % (field, field.__class__)
)
-@convert_form_field.register(BaseTemporalField)
+@convert_form_field.register(forms.fields.BaseTemporalField)
@convert_form_field.register(forms.CharField)
@convert_form_field.register(forms.EmailField)
@convert_form_field.register(forms.SlugField)
@@ -36,7 +30,7 @@ def convert_form_field_to_string(field):
return String(description=field.help_text, required=field.required)
-@convert_form_field.register(UUIDField)
+@convert_form_field.register(forms.UUIDField)
def convert_form_field_to_uuid(field):
return UUID(description=field.help_text, required=field.required)
@@ -69,6 +63,21 @@ def convert_form_field_to_list(field):
return List(ID, required=field.required)
+@convert_form_field.register(forms.DateField)
+def convert_form_field_to_date(field):
+ return Date(description=field.help_text, required=field.required)
+
+
+@convert_form_field.register(forms.DateTimeField)
+def convert_form_field_to_datetime(field):
+ return DateTime(description=field.help_text, required=field.required)
+
+
+@convert_form_field.register(forms.TimeField)
+def convert_form_field_to_time(field):
+ return Time(description=field.help_text, required=field.required)
+
+
@convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_id(field):
diff --git a/graphene_django/forms.py b/graphene_django/forms/forms.py
similarity index 72%
rename from graphene_django/forms.py
rename to graphene_django/forms/forms.py
index a54f0a5..14e68c8 100644
--- a/graphene_django/forms.py
+++ b/graphene_django/forms/forms.py
@@ -8,9 +8,7 @@ from graphql_relay import from_global_id
class GlobalIDFormField(Field):
- default_error_messages = {
- 'invalid': _('Invalid ID specified.'),
- }
+ default_error_messages = {"invalid": _("Invalid ID specified.")}
def clean(self, value):
if not value and not self.required:
@@ -19,21 +17,21 @@ class GlobalIDFormField(Field):
try:
_type, _id = from_global_id(value)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
- raise ValidationError(self.error_messages['invalid'])
+ raise ValidationError(self.error_messages["invalid"])
try:
CharField().clean(_id)
CharField().clean(_type)
except ValidationError:
- raise ValidationError(self.error_messages['invalid'])
+ raise ValidationError(self.error_messages["invalid"])
return value
class GlobalIDMultipleChoiceField(MultipleChoiceField):
default_error_messages = {
- 'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'),
- 'invalid_list': _('Enter a list of values.'),
+ "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."),
+ "invalid_list": _("Enter a list of values."),
}
def valid_value(self, value):
diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py
new file mode 100644
index 0000000..63ea089
--- /dev/null
+++ b/graphene_django/forms/mutation.py
@@ -0,0 +1,192 @@
+# from django import forms
+from collections import OrderedDict
+
+import graphene
+from graphene import Field, InputField
+from graphene.relay.mutation import ClientIDMutation
+from graphene.types.mutation import MutationOptions
+
+# from graphene.types.inputobjecttype import (
+# InputObjectTypeOptions,
+# InputObjectType,
+# )
+from graphene.types.utils import yank_fields_from_attrs
+from graphene_django.registry import get_global_registry
+
+from .converter import convert_form_field
+from .types import ErrorType
+
+
+def fields_for_form(form, only_fields, exclude_fields):
+ fields = OrderedDict()
+ for name, field in form.fields.items():
+ is_not_in_only = only_fields and name not in only_fields
+ is_excluded = (
+ name
+ in exclude_fields # or
+ # name in already_created_fields
+ )
+
+ if is_not_in_only or is_excluded:
+ continue
+
+ fields[name] = convert_form_field(field)
+ return fields
+
+
+class BaseDjangoFormMutation(ClientIDMutation):
+ class Meta:
+ abstract = True
+
+ @classmethod
+ def mutate_and_get_payload(cls, root, info, **input):
+ form = cls.get_form(root, info, **input)
+
+ if form.is_valid():
+ return cls.perform_mutate(form, info)
+ else:
+ errors = [
+ ErrorType(field=key, messages=value)
+ for key, value in form.errors.items()
+ ]
+
+ return cls(errors=errors)
+
+ @classmethod
+ def get_form(cls, root, info, **input):
+ form_kwargs = cls.get_form_kwargs(root, info, **input)
+ return cls._meta.form_class(**form_kwargs)
+
+ @classmethod
+ def get_form_kwargs(cls, root, info, **input):
+ kwargs = {"data": input}
+
+ pk = input.pop("id", None)
+ if pk:
+ instance = cls._meta.model._default_manager.get(pk=pk)
+ kwargs["instance"] = instance
+
+ return kwargs
+
+
+# class DjangoFormInputObjectTypeOptions(InputObjectTypeOptions):
+# form_class = None
+
+
+# class DjangoFormInputObjectType(InputObjectType):
+# class Meta:
+# abstract = True
+
+# @classmethod
+# def __init_subclass_with_meta__(cls, form_class=None,
+# only_fields=(), exclude_fields=(), _meta=None, **options):
+# if not _meta:
+# _meta = DjangoFormInputObjectTypeOptions(cls)
+# assert isinstance(form_class, forms.Form), (
+# 'form_class must be an instance of django.forms.Form'
+# )
+# _meta.form_class = form_class
+# form = form_class()
+# fields = fields_for_form(form, only_fields, exclude_fields)
+# super(DjangoFormInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, fields=fields, **options)
+
+
+class DjangoFormMutationOptions(MutationOptions):
+ form_class = None
+
+
+class DjangoFormMutation(BaseDjangoFormMutation):
+ class Meta:
+ abstract = True
+
+ errors = graphene.List(ErrorType)
+
+ @classmethod
+ def __init_subclass_with_meta__(
+ cls, form_class=None, only_fields=(), exclude_fields=(), **options
+ ):
+
+ if not form_class:
+ raise Exception("form_class is required for DjangoFormMutation")
+
+ form = form_class()
+ input_fields = fields_for_form(form, only_fields, exclude_fields)
+ output_fields = fields_for_form(form, only_fields, exclude_fields)
+
+ _meta = DjangoFormMutationOptions(cls)
+ _meta.form_class = form_class
+ _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
+
+ input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
+ super(DjangoFormMutation, cls).__init_subclass_with_meta__(
+ _meta=_meta, input_fields=input_fields, **options
+ )
+
+ @classmethod
+ def perform_mutate(cls, form, info):
+ form.save()
+ return cls(errors=[])
+
+
+class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions):
+ model = None
+ return_field_name = None
+
+
+class DjangoModelFormMutation(BaseDjangoFormMutation):
+ class Meta:
+ abstract = True
+
+ errors = graphene.List(ErrorType)
+
+ @classmethod
+ def __init_subclass_with_meta__(
+ cls,
+ form_class=None,
+ model=None,
+ return_field_name=None,
+ only_fields=(),
+ exclude_fields=(),
+ **options
+ ):
+
+ if not form_class:
+ raise Exception("form_class is required for DjangoModelFormMutation")
+
+ if not model:
+ model = form_class._meta.model
+
+ if not model:
+ raise Exception("model is required for DjangoModelFormMutation")
+
+ form = form_class()
+ input_fields = fields_for_form(form, only_fields, exclude_fields)
+ if "id" not in exclude_fields:
+ input_fields["id"] = graphene.ID()
+
+ registry = get_global_registry()
+ model_type = registry.get_type_for_model(model)
+ return_field_name = return_field_name
+ if not return_field_name:
+ model_name = model.__name__
+ return_field_name = model_name[:1].lower() + model_name[1:]
+
+ output_fields = OrderedDict()
+ output_fields[return_field_name] = graphene.Field(model_type)
+
+ _meta = DjangoModelDjangoFormMutationOptions(cls)
+ _meta.form_class = form_class
+ _meta.model = model
+ _meta.return_field_name = return_field_name
+ _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
+
+ input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
+ super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(
+ _meta=_meta, input_fields=input_fields, **options
+ )
+
+ @classmethod
+ def perform_mutate(cls, form, info):
+ obj = form.save()
+ kwargs = {cls._meta.return_field_name: obj}
+ return cls(errors=[], **kwargs)
diff --git a/graphene_django/forms/tests/__init__.py b/graphene_django/forms/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/graphene_django/forms/tests/test_converter.py b/graphene_django/forms/tests/test_converter.py
new file mode 100644
index 0000000..955b952
--- /dev/null
+++ b/graphene_django/forms/tests/test_converter.py
@@ -0,0 +1,114 @@
+from django import forms
+from py.test import raises
+
+import graphene
+from graphene import (
+ String,
+ Int,
+ Boolean,
+ Float,
+ ID,
+ UUID,
+ List,
+ NonNull,
+ DateTime,
+ Date,
+ Time,
+)
+
+from ..converter import convert_form_field
+
+
+def assert_conversion(django_field, graphene_field, *args):
+ field = django_field(*args, help_text="Custom Help Text")
+ graphene_type = convert_form_field(field)
+ assert isinstance(graphene_type, graphene_field)
+ field = graphene_type.Field()
+ assert field.description == "Custom Help Text"
+ return field
+
+
+def test_should_unknown_django_field_raise_exception():
+ with raises(Exception) as excinfo:
+ convert_form_field(None)
+ assert "Don't know how to convert the Django form field" in str(excinfo.value)
+
+
+def test_should_date_convert_date():
+ assert_conversion(forms.DateField, Date)
+
+
+def test_should_time_convert_time():
+ assert_conversion(forms.TimeField, Time)
+
+
+def test_should_date_time_convert_date_time():
+ assert_conversion(forms.DateTimeField, DateTime)
+
+
+def test_should_char_convert_string():
+ assert_conversion(forms.CharField, String)
+
+
+def test_should_email_convert_string():
+ assert_conversion(forms.EmailField, String)
+
+
+def test_should_slug_convert_string():
+ assert_conversion(forms.SlugField, String)
+
+
+def test_should_url_convert_string():
+ assert_conversion(forms.URLField, String)
+
+
+def test_should_choice_convert_string():
+ assert_conversion(forms.ChoiceField, String)
+
+
+def test_should_base_field_convert_string():
+ assert_conversion(forms.Field, String)
+
+
+def test_should_regex_convert_string():
+ assert_conversion(forms.RegexField, String, "[0-9]+")
+
+
+def test_should_uuid_convert_string():
+ if hasattr(forms, "UUIDField"):
+ assert_conversion(forms.UUIDField, UUID)
+
+
+def test_should_integer_convert_int():
+ assert_conversion(forms.IntegerField, Int)
+
+
+def test_should_boolean_convert_boolean():
+ field = assert_conversion(forms.BooleanField, Boolean)
+ assert isinstance(field.type, NonNull)
+
+
+def test_should_nullboolean_convert_boolean():
+ field = assert_conversion(forms.NullBooleanField, Boolean)
+ assert not isinstance(field.type, NonNull)
+
+
+def test_should_float_convert_float():
+ assert_conversion(forms.FloatField, Float)
+
+
+def test_should_decimal_convert_float():
+ assert_conversion(forms.DecimalField, Float)
+
+
+def test_should_multiple_choice_convert_connectionorlist():
+ field = forms.ModelMultipleChoiceField(queryset=None)
+ graphene_type = convert_form_field(field)
+ assert isinstance(graphene_type, List)
+ assert graphene_type.of_type == ID
+
+
+def test_should_manytoone_convert_connectionorlist():
+ field = forms.ModelChoiceField(queryset=None)
+ graphene_type = convert_form_field(field)
+ assert isinstance(graphene_type, ID)
diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py
new file mode 100644
index 0000000..df0ffd5
--- /dev/null
+++ b/graphene_django/forms/tests/test_mutation.py
@@ -0,0 +1,141 @@
+from django import forms
+from django.test import TestCase
+from py.test import raises
+
+from graphene_django.tests.models import Pet, Film, FilmDetails
+from ..mutation import DjangoFormMutation, DjangoModelFormMutation
+
+
+class MyForm(forms.Form):
+ text = forms.CharField()
+
+
+class PetForm(forms.ModelForm):
+ class Meta:
+ model = Pet
+ fields = '__all__'
+
+
+def test_needs_form_class():
+ with raises(Exception) as exc:
+
+ class MyMutation(DjangoFormMutation):
+ pass
+
+ assert exc.value.args[0] == "form_class is required for DjangoFormMutation"
+
+
+def test_has_output_fields():
+ class MyMutation(DjangoFormMutation):
+ class Meta:
+ form_class = MyForm
+
+ assert "errors" in MyMutation._meta.fields
+
+
+def test_has_input_fields():
+ class MyMutation(DjangoFormMutation):
+ class Meta:
+ form_class = MyForm
+
+ assert "text" in MyMutation.Input._meta.fields
+
+
+class ModelFormMutationTests(TestCase):
+ def test_default_meta_fields(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+ self.assertEqual(PetMutation._meta.model, Pet)
+ self.assertEqual(PetMutation._meta.return_field_name, "pet")
+ self.assertIn("pet", PetMutation._meta.fields)
+
+ def test_default_input_meta_fields(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+ self.assertEqual(PetMutation._meta.model, Pet)
+ self.assertEqual(PetMutation._meta.return_field_name, "pet")
+ self.assertIn("name", PetMutation.Input._meta.fields)
+ self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
+ self.assertIn("id", PetMutation.Input._meta.fields)
+
+ def test_exclude_fields_input_meta_fields(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+ exclude_fields = ['id']
+
+ self.assertEqual(PetMutation._meta.model, Pet)
+ self.assertEqual(PetMutation._meta.return_field_name, "pet")
+ self.assertIn("name", PetMutation.Input._meta.fields)
+ self.assertIn("age", PetMutation.Input._meta.fields)
+ self.assertIn("client_mutation_id", PetMutation.Input._meta.fields)
+ self.assertNotIn("id", PetMutation.Input._meta.fields)
+
+ def test_return_field_name_is_camelcased(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+ model = FilmDetails
+
+ self.assertEqual(PetMutation._meta.model, FilmDetails)
+ self.assertEqual(PetMutation._meta.return_field_name, "filmDetails")
+
+ def test_custom_return_field_name(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+ model = Film
+ return_field_name = "animal"
+
+ self.assertEqual(PetMutation._meta.model, Film)
+ self.assertEqual(PetMutation._meta.return_field_name, "animal")
+ self.assertIn("animal", PetMutation._meta.fields)
+
+ def test_model_form_mutation_mutate(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+ pet = Pet.objects.create(name="Axel", age=10)
+
+ result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia", age=10)
+
+ self.assertEqual(Pet.objects.count(), 1)
+ pet.refresh_from_db()
+ self.assertEqual(pet.name, "Mia")
+ self.assertEqual(result.errors, [])
+
+ def test_model_form_mutation_updates_existing_(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+ result = PetMutation.mutate_and_get_payload(None, None, name="Mia", age=10)
+
+ self.assertEqual(Pet.objects.count(), 1)
+ pet = Pet.objects.get()
+ self.assertEqual(pet.name, "Mia")
+ self.assertEqual(pet.age, 10)
+ self.assertEqual(result.errors, [])
+
+ def test_model_form_mutation_mutate_invalid_form(self):
+ class PetMutation(DjangoModelFormMutation):
+ class Meta:
+ form_class = PetForm
+
+ result = PetMutation.mutate_and_get_payload(None, None)
+
+ # A pet was not created
+ self.assertEqual(Pet.objects.count(), 0)
+
+
+ fields_w_error = [e.field for e in result.errors]
+ self.assertEqual(len(result.errors), 2)
+ self.assertIn("name", fields_w_error)
+ self.assertEqual(result.errors[0].messages, ["This field is required."])
+ self.assertIn("age", fields_w_error)
+ self.assertEqual(result.errors[1].messages, ["This field is required."])
diff --git a/graphene_django/forms/types.py b/graphene_django/forms/types.py
new file mode 100644
index 0000000..1fe33f3
--- /dev/null
+++ b/graphene_django/forms/types.py
@@ -0,0 +1,6 @@
+import graphene
+
+
+class ErrorType(graphene.ObjectType):
+ field = graphene.String()
+ messages = graphene.List(graphene.String)
diff --git a/graphene_django/management/commands/graphql_schema.py b/graphene_django/management/commands/graphql_schema.py
index 7e2dbac..4e526ec 100644
--- a/graphene_django/management/commands/graphql_schema.py
+++ b/graphene_django/management/commands/graphql_schema.py
@@ -1,79 +1,51 @@
import importlib
import json
-from distutils.version import StrictVersion
-from optparse import make_option
-from django import get_version as get_django_version
from django.core.management.base import BaseCommand, CommandError
from graphene_django.settings import graphene_settings
-LT_DJANGO_1_8 = StrictVersion(get_django_version()) < StrictVersion('1.8')
-if LT_DJANGO_1_8:
- class CommandArguments(BaseCommand):
- option_list = BaseCommand.option_list + (
- make_option(
- '--schema',
- type=str,
- dest='schema',
- default='',
- help='Django app containing schema to dump, e.g. myproject.core.schema.schema',
- ),
- make_option(
- '--out',
- type=str,
- dest='out',
- default='',
- help='Output file (default: schema.json)'
- ),
- make_option(
- '--indent',
- type=int,
- dest='indent',
- default=None,
- help='Output file indent (default: None)'
- ),
+class CommandArguments(BaseCommand):
+ def add_arguments(self, parser):
+ parser.add_argument(
+ "--schema",
+ type=str,
+ dest="schema",
+ default=graphene_settings.SCHEMA,
+ help="Django app containing schema to dump, e.g. myproject.core.schema.schema",
)
-else:
- class CommandArguments(BaseCommand):
- def add_arguments(self, parser):
- parser.add_argument(
- '--schema',
- type=str,
- dest='schema',
- default=graphene_settings.SCHEMA,
- help='Django app containing schema to dump, e.g. myproject.core.schema.schema')
+ parser.add_argument(
+ "--out",
+ type=str,
+ dest="out",
+ default=graphene_settings.SCHEMA_OUTPUT,
+ help="Output file, --out=- prints to stdout (default: schema.json)",
+ )
- parser.add_argument(
- '--out',
- type=str,
- dest='out',
- default=graphene_settings.SCHEMA_OUTPUT,
- help='Output file (default: schema.json)')
-
- parser.add_argument(
- '--indent',
- type=int,
- dest='indent',
- default=graphene_settings.SCHEMA_INDENT,
- help='Output file indent (default: None)')
+ parser.add_argument(
+ "--indent",
+ type=int,
+ dest="indent",
+ default=graphene_settings.SCHEMA_INDENT,
+ help="Output file indent (default: None)",
+ )
class Command(CommandArguments):
- help = 'Dump Graphene schema JSON to file'
+ help = "Dump Graphene schema JSON to file"
can_import_settings = True
def save_file(self, out, schema_dict, indent):
- with open(out, 'w') as outfile:
+ with open(out, "w") as outfile:
json.dump(schema_dict, outfile, indent=indent)
def handle(self, *args, **options):
- options_schema = options.get('schema')
+ options_schema = options.get("schema")
if options_schema and type(options_schema) is str:
- module_str, schema_name = options_schema.rsplit('.', 1)
+ module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name)
@@ -83,16 +55,21 @@ class Command(CommandArguments):
else:
schema = graphene_settings.SCHEMA
- out = options.get('out') or graphene_settings.SCHEMA_OUTPUT
+ out = options.get("out") or graphene_settings.SCHEMA_OUTPUT
if not schema:
- raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema')
+ raise CommandError(
+ "Specify schema on GRAPHENE.SCHEMA setting or by using --schema"
+ )
- indent = options.get('indent')
- schema_dict = {'data': schema.introspect()}
- self.save_file(out, schema_dict, indent)
+ indent = options.get("indent")
+ schema_dict = {"data": schema.introspect()}
+ if out == '-':
+ self.stdout.write(json.dumps(schema_dict, indent=indent))
+ else:
+ self.save_file(out, schema_dict, indent)
- style = getattr(self, 'style', None)
- success = getattr(style, 'SUCCESS', lambda x: x)
+ style = getattr(self, "style", None)
+ success = getattr(style, "SUCCESS", lambda x: x)
- self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out))
+ self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out))
diff --git a/graphene_django/registry.py b/graphene_django/registry.py
index 4e681cc..50a8ae5 100644
--- a/graphene_django/registry.py
+++ b/graphene_django/registry.py
@@ -1,25 +1,32 @@
-
class Registry(object):
-
def __init__(self):
self._registry = {}
- self._registry_models = {}
+ self._field_registry = {}
def register(self, cls):
from .types import DjangoObjectType
+
assert issubclass(
- cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
- cls.__name__)
- assert cls._meta.registry == self, 'Registry for a Model have to match.'
+ cls, DjangoObjectType
+ ), 'Only DjangoObjectTypes can be registered, received "{}"'.format(
+ cls.__name__
+ )
+ assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)
# )
- if not getattr(cls._meta, 'skip_registry', False):
+ if not getattr(cls._meta, "skip_registry", False):
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):
return self._registry.get(model)
+ def register_converted_field(self, field, converted):
+ self._field_registry[field] = converted
+
+ def get_converted_field(self, field):
+ return self._field_registry.get(field)
+
registry = None
diff --git a/graphene_django/rest_framework/models.py b/graphene_django/rest_framework/models.py
new file mode 100644
index 0000000..848837b
--- /dev/null
+++ b/graphene_django/rest_framework/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class MyFakeModel(models.Model):
+ cool_name = models.CharField(max_length=50)
+ created = models.DateTimeField(auto_now_add=True)
diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py
index 94d1e4b..5e343aa 100644
--- a/graphene_django/rest_framework/mutation.py
+++ b/graphene_django/rest_framework/mutation.py
@@ -1,20 +1,21 @@
from collections import OrderedDict
+from django.shortcuts import get_object_or_404
+
import graphene
from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation
-from graphene.types.objecttype import (
- yank_fields_from_attrs
-)
+from graphene.types.objecttype import yank_fields_from_attrs
-from .serializer_converter import (
- convert_serializer_field
-)
+from .serializer_converter import convert_serializer_field
from .types import ErrorType
class SerializerMutationOptions(MutationOptions):
+ lookup_field = None
+ model_class = None
+ model_operations = ["create", "update"]
serializer_class = None
@@ -23,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals
for name, field in serializer.fields.items():
is_not_in_only = only_fields and name not in only_fields
is_excluded = (
- name in exclude_fields # or
+ name
+ in exclude_fields # or
# name in already_created_fields
)
@@ -39,37 +41,86 @@ class SerializerMutation(ClientIDMutation):
abstract = True
errors = graphene.List(
- ErrorType,
- description='May contain more than one error for same field.'
+ ErrorType, description="May contain more than one error for same field."
)
@classmethod
- def __init_subclass_with_meta__(cls, serializer_class=None,
- only_fields=(), exclude_fields=(), **options):
+ def __init_subclass_with_meta__(
+ cls,
+ lookup_field=None,
+ serializer_class=None,
+ model_class=None,
+ model_operations=["create", "update"],
+ only_fields=(),
+ exclude_fields=(),
+ **options
+ ):
if not serializer_class:
- raise Exception('serializer_class is required for the SerializerMutation')
+ raise Exception("serializer_class is required for the SerializerMutation")
+
+ if "update" not in model_operations and "create" not in model_operations:
+ raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class()
- input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
- output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
+ if model_class is None:
+ serializer_meta = getattr(serializer_class, "Meta", None)
+ if serializer_meta:
+ model_class = getattr(serializer_meta, "model", None)
+
+ if lookup_field is None and model_class:
+ lookup_field = model_class._meta.pk.name
+
+ input_fields = fields_for_serializer(
+ serializer, only_fields, exclude_fields, is_input=True
+ )
+ output_fields = fields_for_serializer(
+ serializer, only_fields, exclude_fields, is_input=False
+ )
_meta = SerializerMutationOptions(cls)
+ _meta.lookup_field = lookup_field
+ _meta.model_operations = model_operations
_meta.serializer_class = serializer_class
- _meta.fields = yank_fields_from_attrs(
- output_fields,
- _as=Field,
+ _meta.model_class = model_class
+ _meta.fields = yank_fields_from_attrs(output_fields, _as=Field)
+
+ input_fields = yank_fields_from_attrs(input_fields, _as=InputField)
+ super(SerializerMutation, cls).__init_subclass_with_meta__(
+ _meta=_meta, input_fields=input_fields, **options
)
- input_fields = yank_fields_from_attrs(
- input_fields,
- _as=InputField,
- )
- super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
+ @classmethod
+ def get_serializer_kwargs(cls, root, info, **input):
+ lookup_field = cls._meta.lookup_field
+ model_class = cls._meta.model_class
+
+ if model_class:
+ if "update" in cls._meta.model_operations and lookup_field in input:
+ instance = get_object_or_404(
+ model_class, **{lookup_field: input[lookup_field]}
+ )
+ elif "create" in cls._meta.model_operations:
+ instance = None
+ else:
+ raise Exception(
+ 'Invalid update operation. Input parameter "{}" required.'.format(
+ lookup_field
+ )
+ )
+
+ return {
+ "instance": instance,
+ "data": input,
+ "context": {"request": info.context},
+ }
+
+ return {"data": input, "context": {"request": info.context}}
@classmethod
def mutate_and_get_payload(cls, root, info, **input):
- serializer = cls._meta.serializer_class(data=input)
+ kwargs = cls.get_serializer_kwargs(root, info, **input)
+ serializer = cls._meta.serializer_class(**kwargs)
if serializer.is_valid():
return cls.perform_mutate(serializer, info)
@@ -84,4 +135,9 @@ class SerializerMutation(ClientIDMutation):
@classmethod
def perform_mutate(cls, serializer, info):
obj = serializer.save()
- return cls(errors=None, **obj)
+
+ kwargs = {}
+ for f, field in serializer.fields.items():
+ kwargs[f] = field.get_attribute(obj)
+
+ return cls(errors=None, **kwargs)
diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py
index 6a57f5f..9f8e516 100644
--- a/graphene_django/rest_framework/serializer_converter.py
+++ b/graphene_django/rest_framework/serializer_converter.py
@@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True):
graphql_type = get_graphene_type_from_serializer_field(field)
args = []
- kwargs = {
- 'description': field.help_text,
- 'required': is_input and field.required,
- }
+ kwargs = {"description": field.help_text, "required": is_input and field.required}
# if it is a tuple or a list it means that we are returning
# the graphql type and the child type
if isinstance(graphql_type, (list, tuple)):
- kwargs['of_type'] = graphql_type[1]
+ kwargs["of_type"] = graphql_type[1]
graphql_type = graphql_type[0]
if isinstance(field, serializers.ModelSerializer):
@@ -46,6 +43,15 @@ def convert_serializer_field(field, is_input=True):
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
+ elif isinstance(field, serializers.ListSerializer):
+ field = field.child
+ if is_input:
+ kwargs["of_type"] = convert_serializer_to_input_type(field.__class__)
+ else:
+ del kwargs["of_type"]
+ global_registry = get_global_registry()
+ field_model = field.Meta.model
+ args = [global_registry.get_type_for_model(field_model)]
return graphql_type(*args, **kwargs)
@@ -59,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class):
}
return type(
- '{}Input'.format(serializer.__class__.__name__),
+ "{}Input".format(serializer.__class__.__name__),
(graphene.InputObjectType,),
- items
+ items,
)
@@ -75,6 +81,12 @@ def convert_serializer_to_field(field):
return graphene.Field
+@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
+def convert_list_serializer_to_field(field):
+ child_type = get_graphene_type_from_serializer_field(field.child)
+ return (graphene.List, child_type)
+
+
@get_graphene_type_from_serializer_field.register(serializers.IntegerField)
def convert_serializer_field_to_int(field):
return graphene.Int
@@ -92,9 +104,13 @@ def convert_serializer_field_to_float(field):
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
+def convert_serializer_field_to_datetime_time(field):
+ return graphene.types.datetime.DateTime
+
+
@get_graphene_type_from_serializer_field.register(serializers.DateField)
def convert_serializer_field_to_date_time(field):
- return graphene.types.datetime.DateTime
+ return graphene.types.datetime.Date
@get_graphene_type_from_serializer_field.register(serializers.TimeField)
diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py
index 623cf58..6fa4ca8 100644
--- a/graphene_django/rest_framework/tests/test_field_converter.py
+++ b/graphene_django/rest_framework/tests/test_field_converter.py
@@ -1,8 +1,10 @@
import copy
-from rest_framework import serializers
-from py.test import raises
import graphene
+from django.db import models
+from graphene import InputObjectType
+from py.test import raises
+from rest_framework import serializers
from ..serializer_converter import convert_serializer_field
from ..types import DictType
@@ -14,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
# Remove `source=` from the field declaration.
# since we are reusing the same child in when testing the required attribute
- if 'child' in kwargs:
- kwargs['child'] = copy.deepcopy(kwargs['child'])
+ if "child" in kwargs:
+ kwargs["child"] = copy.deepcopy(kwargs["child"])
field = rest_framework_field(**kwargs)
@@ -23,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs):
def assert_conversion(rest_framework_field, graphene_field, **kwargs):
- graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs)
+ graphene_type = _get_type(
+ rest_framework_field, help_text="Custom Help Text", **kwargs
+ )
assert isinstance(graphene_type, graphene_field)
graphene_type_required = _get_type(
- rest_framework_field, help_text='Custom Help Text', required=True, **kwargs
+ rest_framework_field, help_text="Custom Help Text", required=True, **kwargs
)
assert isinstance(graphene_type_required, graphene_field)
@@ -37,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs):
def test_should_unknown_rest_framework_field_raise_exception():
with raises(Exception) as excinfo:
convert_serializer_field(None)
- assert 'Don\'t know how to convert the serializer field' in str(excinfo.value)
+ assert "Don't know how to convert the serializer field" in str(excinfo.value)
def test_should_char_convert_string():
@@ -65,20 +69,19 @@ def test_should_base_field_convert_string():
def test_should_regex_convert_string():
- assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+')
+ assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+")
def test_should_uuid_convert_string():
- if hasattr(serializers, 'UUIDField'):
+ if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String)
def test_should_model_convert_field():
-
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = None
- fields = '__all__'
+ fields = "__all__"
assert_conversion(MyModelSerializer, graphene.Field, is_input=False)
@@ -87,8 +90,8 @@ def test_should_date_time_convert_datetime():
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
-def test_should_date_convert_datetime():
- assert_conversion(serializers.DateField, graphene.types.datetime.DateTime)
+def test_should_date_convert_date():
+ assert_conversion(serializers.DateField, graphene.types.datetime.Date)
def test_should_time_convert_time():
@@ -108,7 +111,9 @@ def test_should_float_convert_float():
def test_should_decimal_convert_float():
- assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2)
+ assert_conversion(
+ serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2
+ )
def test_should_list_convert_to_list():
@@ -118,7 +123,7 @@ def test_should_list_convert_to_list():
field_a = assert_conversion(
serializers.ListField,
graphene.List,
- child=serializers.IntegerField(min_value=0, max_value=100)
+ child=serializers.IntegerField(min_value=0, max_value=100),
)
assert field_a.of_type == graphene.Int
@@ -128,6 +133,34 @@ def test_should_list_convert_to_list():
assert field_b.of_type == graphene.String
+def test_should_list_serializer_convert_to_list():
+ class FooModel(models.Model):
+ pass
+
+ class ChildSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FooModel
+ fields = "__all__"
+
+ class ParentSerializer(serializers.ModelSerializer):
+ child = ChildSerializer(many=True)
+
+ class Meta:
+ model = FooModel
+ fields = "__all__"
+
+ converted_type = convert_serializer_field(
+ ParentSerializer().get_fields()["child"], is_input=True
+ )
+ assert isinstance(converted_type, graphene.List)
+
+ converted_type = convert_serializer_field(
+ ParentSerializer().get_fields()["child"], is_input=False
+ )
+ assert isinstance(converted_type, graphene.List)
+ assert converted_type.of_type is None
+
+
def test_should_dict_convert_dict():
assert_conversion(serializers.DictField, DictType)
@@ -141,7 +174,7 @@ def test_should_file_convert_string():
def test_should_filepath_convert_string():
- assert_conversion(serializers.FilePathField, graphene.String, path='/')
+ assert_conversion(serializers.FilePathField, graphene.String, path="/")
def test_should_ip_convert_string():
@@ -157,6 +190,8 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string():
- field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3])
+ field = assert_conversion(
+ serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3]
+ )
assert field.of_type == graphene.String
diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py
index 852265d..4dccc18 100644
--- a/graphene_django/rest_framework/tests/test_mutation.py
+++ b/graphene_django/rest_framework/tests/test_mutation.py
@@ -1,21 +1,40 @@
-from django.db import models
-from graphene import Field
+import datetime
+
+from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType
from py.test import raises
+from py.test import mark
from rest_framework import serializers
from ...types import DjangoObjectType
+from ..models import MyFakeModel
from ..mutation import SerializerMutation
-class MyFakeModel(models.Model):
- cool_name = models.CharField(max_length=50)
+def mock_info():
+ return ResolveInfo(
+ None,
+ None,
+ None,
+ None,
+ schema=None,
+ fragments=None,
+ root_value=None,
+ operation=None,
+ variable_values=None,
+ context=None,
+ )
class MyModelSerializer(serializers.ModelSerializer):
class Meta:
model = MyFakeModel
- fields = '__all__'
+ fields = "__all__"
+
+
+class MyModelMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer):
@@ -28,10 +47,11 @@ class MySerializer(serializers.Serializer):
def test_needs_serializer_class():
with raises(Exception) as exc:
+
class MyMutation(SerializerMutation):
pass
- assert str(exc.value) == 'serializer_class is required for the SerializerMutation'
+ assert str(exc.value) == "serializer_class is required for the SerializerMutation"
def test_has_fields():
@@ -39,9 +59,9 @@ def test_has_fields():
class Meta:
serializer_class = MySerializer
- assert 'text' in MyMutation._meta.fields
- assert 'model' in MyMutation._meta.fields
- assert 'errors' in MyMutation._meta.fields
+ assert "text" in MyMutation._meta.fields
+ assert "model" in MyMutation._meta.fields
+ assert "errors" in MyMutation._meta.fields
def test_has_input_fields():
@@ -49,12 +69,24 @@ def test_has_input_fields():
class Meta:
serializer_class = MySerializer
- assert 'text' in MyMutation.Input._meta.fields
- assert 'model' in MyMutation.Input._meta.fields
+ assert "text" in MyMutation.Input._meta.fields
+ assert "model" in MyMutation.Input._meta.fields
+
+
+def test_exclude_fields():
+ class MyMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
+ exclude_fields = ["created"]
+
+ assert "cool_name" in MyMutation._meta.fields
+ assert "created" not in MyMutation._meta.fields
+ assert "errors" in MyMutation._meta.fields
+ assert "cool_name" in MyMutation.Input._meta.fields
+ assert "created" not in MyMutation.Input._meta.fields
def test_nested_model():
-
class MyFakeModelGrapheneType(DjangoObjectType):
class Meta:
model = MyFakeModel
@@ -63,37 +95,85 @@ def test_nested_model():
class Meta:
serializer_class = MySerializer
- model_field = MyMutation._meta.fields['model']
+ model_field = MyMutation._meta.fields["model"]
assert isinstance(model_field, Field)
assert model_field.type == MyFakeModelGrapheneType
- model_input = MyMutation.Input._meta.fields['model']
+ model_input = MyMutation.Input._meta.fields["model"]
model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType)
- assert 'cool_name' in model_input_type._meta.fields
+ assert "cool_name" in model_input_type._meta.fields
+ assert "created" in model_input_type._meta.fields
def test_mutate_and_get_payload_success():
-
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
- result = MyMutation.mutate_and_get_payload(None, None, **{
- 'text': 'value',
- 'model': {
- 'cool_name': 'other_value'
- }
- })
+ result = MyMutation.mutate_and_get_payload(
+ None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}}
+ )
assert result.errors is None
-def test_mutate_and_get_payload_error():
+@mark.django_db
+def test_model_add_mutate_and_get_payload_success():
+ result = MyModelMutation.mutate_and_get_payload(
+ None, mock_info(), **{"cool_name": "Narf"}
+ )
+ assert result.errors is None
+ assert result.cool_name == "Narf"
+ assert isinstance(result.created, datetime.datetime)
+
+@mark.django_db
+def test_model_update_mutate_and_get_payload_success():
+ instance = MyFakeModel.objects.create(cool_name="Narf")
+ result = MyModelMutation.mutate_and_get_payload(
+ None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"}
+ )
+ assert result.errors is None
+ assert result.cool_name == "New Narf"
+
+
+@mark.django_db
+def test_model_invalid_update_mutate_and_get_payload_success():
+ class InvalidModelMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
+ model_operations = ["update"]
+
+ with raises(Exception) as exc:
+ result = InvalidModelMutation.mutate_and_get_payload(
+ None, mock_info(), **{"cool_name": "Narf"}
+ )
+
+ assert '"id" required' in str(exc.value)
+
+
+def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MySerializer
# missing required fields
- result = MyMutation.mutate_and_get_payload(None, None, **{})
- assert len(result.errors) > 0
\ No newline at end of file
+ result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
+ assert len(result.errors) > 0
+
+
+def test_model_mutate_and_get_payload_error():
+ # missing required fields
+ result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
+ assert len(result.errors) > 0
+
+
+def test_invalid_serializer_operations():
+ with raises(Exception) as exc:
+
+ class MyModelMutation(SerializerMutation):
+ class Meta:
+ serializer_class = MyModelSerializer
+ model_operations = ["Add"]
+
+ assert "model_operations" in str(exc.value)
diff --git a/graphene_django/rest_framework/types.py b/graphene_django/rest_framework/types.py
index 956dc43..4c84c69 100644
--- a/graphene_django/rest_framework/types.py
+++ b/graphene_django/rest_framework/types.py
@@ -3,8 +3,8 @@ from graphene.types.unmountedtype import UnmountedType
class ErrorType(graphene.ObjectType):
- field = graphene.String()
- messages = graphene.List(graphene.String)
+ field = graphene.String(required=True)
+ messages = graphene.List(graphene.NonNull(graphene.String), required=True)
class DictType(UnmountedType):
diff --git a/graphene_django/settings.py b/graphene_django/settings.py
index 46d70ee..7cd750a 100644
--- a/graphene_django/settings.py
+++ b/graphene_django/settings.py
@@ -26,27 +26,22 @@ except ImportError:
# Copied shamelessly from Django REST Framework
DEFAULTS = {
- 'SCHEMA': None,
- 'SCHEMA_OUTPUT': 'schema.json',
- 'SCHEMA_INDENT': None,
- 'MIDDLEWARE': (),
+ "SCHEMA": None,
+ "SCHEMA_OUTPUT": "schema.json",
+ "SCHEMA_INDENT": None,
+ "MIDDLEWARE": (),
# Set to True if the connection fields must have
# either the first or last argument
- 'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False,
+ "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False,
# Max items returned in ConnectionFields / FilterConnectionFields
- 'RELAY_CONNECTION_MAX_LIMIT': 100,
+ "RELAY_CONNECTION_MAX_LIMIT": 100,
}
if settings.DEBUG:
- DEFAULTS['MIDDLEWARE'] += (
- 'graphene_django.debug.DjangoDebugMiddleware',
- )
+ DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",)
# List of settings that may be in string import notation.
-IMPORT_STRINGS = (
- 'MIDDLEWARE',
- 'SCHEMA',
-)
+IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA")
def perform_import(val, setting_name):
@@ -69,12 +64,17 @@ def import_from_string(val, setting_name):
"""
try:
# Nod to tastypie's use of importlib.
- parts = val.split('.')
- module_path, class_name = '.'.join(parts[:-1]), parts[-1]
+ parts = val.split(".")
+ module_path, class_name = ".".join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
- msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
+ msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (
+ val,
+ setting_name,
+ e.__class__.__name__,
+ e,
+ )
raise ImportError(msg)
@@ -96,8 +96,8 @@ class GrapheneSettings(object):
@property
def user_settings(self):
- if not hasattr(self, '_user_settings'):
- self._user_settings = getattr(settings, 'GRAPHENE', {})
+ if not hasattr(self, "_user_settings"):
+ self._user_settings = getattr(settings, "GRAPHENE", {})
return self._user_settings
def __getattr__(self, attr):
@@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_graphene_settings(*args, **kwargs):
global graphene_settings
- setting, value = kwargs['setting'], kwargs['value']
- if setting == 'GRAPHENE':
+ setting, value = kwargs["setting"], kwargs["value"]
+ if setting == "GRAPHENE":
graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS)
diff --git a/graphene_django/templates/graphene/graphiql.html b/graphene_django/templates/graphene/graphiql.html
index 949b850..1ba0613 100644
--- a/graphene_django/templates/graphene/graphiql.html
+++ b/graphene_django/templates/graphene/graphiql.html
@@ -16,11 +16,11 @@ add "&raw" to the end of the URL within a browser.
width: 100%;
}
-
-
-
-
-
+
+
+
+
+