Merge branch 'master' into form_mutations

# Conflicts:
#	graphene_django/forms/converter.py
#	graphene_django/forms/tests/test_converter.py
This commit is contained in:
Syrus Akbary 2018-06-05 13:22:27 -07:00
commit a9e5beb9eb
48 changed files with 1183 additions and 235 deletions

View File

@ -11,6 +11,9 @@ install:
pip install -e .[test] pip install -e .[test]
pip install psycopg2 # Required for Django postgres fields testing pip install psycopg2 # Required for Django postgres fields testing
pip install django==$DJANGO_VERSION pip install django==$DJANGO_VERSION
if (($(echo "$DJANGO_VERSION <= 1.9" | bc -l))); then # DRF dropped 1.8 and 1.9 support at 3.7.0
pip install djangorestframework==3.6.4
fi
python setup.py develop python setup.py develop
elif [ "$TEST_TYPE" = lint ]; then elif [ "$TEST_TYPE" = lint ]; then
pip install flake8 pip install flake8

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016-Present Syrus Akbary
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +1,2 @@
include README.md include README.md LICENSE
recursive-include graphene_django/templates * recursive-include graphene_django/templates *

View File

@ -9,10 +9,10 @@ A [Django](https://www.djangoproject.com/) integration for [Graphene](http://gra
## Installation ## Installation
For instaling graphene, just run this command in your shell For installing graphene, just run this command in your shell
```bash ```bash
pip install "graphene-django>=2.0.dev" pip install "graphene-django>=2.0"
``` ```
### Settings ### Settings
@ -67,8 +67,7 @@ class User(DjangoObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
users = graphene.List(User) users = graphene.List(User)
@graphene.resolve_only_args def resolve_users(self, info):
def resolve_users(self):
return UserModel.objects.all() return UserModel.objects.all()
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)

View File

@ -13,11 +13,11 @@ A `Django <https://www.djangoproject.com/>`__ integration for
Installation Installation
------------ ------------
For instaling graphene, just run this command in your shell For installing graphene, just run this command in your shell
.. code:: bash .. code:: bash
pip install "graphene-django>=2.0.dev" pip install "graphene-django>=2.0"
Settings Settings
~~~~~~~~ ~~~~~~~~

View File

@ -8,6 +8,7 @@ SECRET_KEY = 1
INSTALLED_APPS = [ INSTALLED_APPS = [
'graphene_django', 'graphene_django',
'graphene_django.rest_framework',
'graphene_django.tests', 'graphene_django.tests',
'starwars', 'starwars',
] ]

View File

@ -34,7 +34,7 @@ This is easy, simply use the ``only_fields`` meta attribute.
only_fields = ('title', 'content') only_fields = ('title', 'content')
interfaces = (relay.Node, ) interfaces = (relay.Node, )
conversely you can use ``exclude_fields`` meta atrribute. conversely you can use ``exclude_fields`` meta attribute.
.. code:: python .. code:: python
@ -61,7 +61,7 @@ define a resolve method for that field and return the desired queryset.
from .models import Post from .models import Post
class Query(ObjectType): class Query(ObjectType):
all_posts = DjangoFilterConnectionField(CategoryNode) all_posts = DjangoFilterConnectionField(PostNode)
def resolve_all_posts(self, args, info): def resolve_all_posts(self, args, info):
return Post.objects.filter(published=True) return Post.objects.filter(published=True)
@ -79,14 +79,14 @@ with the context argument.
from .models import Post from .models import Post
class Query(ObjectType): class Query(ObjectType):
my_posts = DjangoFilterConnectionField(CategoryNode) my_posts = DjangoFilterConnectionField(PostNode)
def resolve_my_posts(self, args, context, info): def resolve_my_posts(self, info):
# context will reference to the Django request # context will reference to the Django request
if not context.user.is_authenticated(): if not info.context.user.is_authenticated():
return Post.objects.none() return Post.objects.none()
else: else:
return Post.objects.filter(owner=context.user) return Post.objects.filter(owner=info.context.user)
If you're using your own view, passing the request context into the If you're using your own view, passing the request context into the
schema is simple. schema is simple.

View File

@ -126,3 +126,23 @@ create your own ``Filterset`` as follows:
# We specify our custom AnimalFilter using the filterset_class param # We specify our custom AnimalFilter using the filterset_class param
all_animals = DjangoFilterConnectionField(AnimalNode, all_animals = DjangoFilterConnectionField(AnimalNode,
filterset_class=AnimalFilter) filterset_class=AnimalFilter)
The context argument is passed on as the `request argument <http://django-filter.readthedocs.io/en/latest/guide/usage.html#request-based-filtering>`__
in a ``django_filters.FilterSet`` instance. You can use this to customize your
filters to be context-dependent. We could modify the ``AnimalFilter`` above to
pre-filter animals owned by the authenticated user (set in ``context.user``).
.. code:: python
class AnimalFilter(django_filters.FilterSet):
# Do case-insensitive lookups on 'name'
name = django_filters.CharFilter(lookup_type='iexact')
class Meta:
model = Animal
fields = ['name', 'genus', 'is_domesticated']
@property
def qs(self):
# The query context can be found in self.request.
return super(AnimalFilter, self).qs.filter(owner=self.request.user)

View File

@ -1,3 +1,3 @@
sphinx sphinx
# Docs template # Docs template
https://github.com/graphql-python/graphene-python.org/archive/docs.zip http://graphene-python.org/sphinx_graphene_theme.zip

View File

@ -19,3 +19,50 @@ You can create a Mutation based on a serializer by using the
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
Create/Update Operations
---------------------
By default ModelSerializers accept create and update operations. To
customize this use the `model_operations` attribute. The update
operation looks up models by the primary key by default. You can
customize the look up with the lookup attribute.
Other default attributes:
`partial = False`: Accept updates without all the input fields.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['create', 'update']
lookup_field = 'id'
Overriding Update Queries
-------------------------
Use the method `get_serializer_kwargs` to override how
updates are applied.
.. code:: python
from graphene_django.rest_framework.mutation import SerializerMutation
class AwesomeModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
if 'id' in input:
instance = Post.objects.filter(id=input['id'], owner=info.context.user).first()
if instance:
return {'instance': instance, 'data': input, 'partial': True}
else:
raise http.Http404
return {'data': input, 'partial': True}

View File

@ -8,14 +8,14 @@ Our primary focus here is to give a good understanding of how to connect models
A good idea is to check the `graphene <http://docs.graphene-python.org/en/latest/>`__ documentation first. A good idea is to check the `graphene <http://docs.graphene-python.org/en/latest/>`__ documentation first.
Setup the Django project Set up the Django project
------------------------ -------------------------
You can find the entire project in ``examples/cookbook-plain``. You can find the entire project in ``examples/cookbook-plain``.
---- ----
We will setup the project, create the following: We will set up the project, create the following:
- A Django project called ``cookbook`` - A Django project called ``cookbook``
- An app within ``cookbook`` called ``ingredients`` - An app within ``cookbook`` called ``ingredients``
@ -68,7 +68,8 @@ Let's get started with these models:
class Ingredient(models.Model): class Ingredient(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
notes = models.TextField() notes = models.TextField()
category = models.ForeignKey(Category, related_name='ingredients') category = models.ForeignKey(Category, related_name='ingredients',
on_delete=models.CASCADE)
def __str__(self): def __str__(self):
return self.name return self.name
@ -80,7 +81,7 @@ Add ingredients as INSTALLED_APPS:
INSTALLED_APPS = [ INSTALLED_APPS = [
... ...
# Install the ingredients app # Install the ingredients app
'ingredients', 'cookbook.ingredients',
] ]
Don't forget to create & run migrations: Don't forget to create & run migrations:
@ -153,7 +154,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
model = Ingredient model = Ingredient
class Query(graphene.AbstractType): class Query(object):
all_categories = graphene.List(CategoryType) all_categories = graphene.List(CategoryType)
all_ingredients = graphene.List(IngredientType) all_ingredients = graphene.List(IngredientType)
@ -426,7 +427,7 @@ We can update our schema to support that, by adding new query for ``ingredient``
model = Ingredient model = Ingredient
class Query(graphene.AbstractType): class Query(object):
category = graphene.Field(CategoryType, category = graphene.Field(CategoryType,
id=graphene.Int(), id=graphene.Int(),
name=graphene.String()) name=graphene.String())
@ -445,8 +446,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return Ingredient.objects.all() return Ingredient.objects.all()
def resolve_category(self, info, **kwargs): def resolve_category(self, info, **kwargs):
id = kargs.get('id') id = kwargs.get('id')
name = kargs.get('name') name = kwargs.get('name')
if id is not None: if id is not None:
return Category.objects.get(pk=id) return Category.objects.get(pk=id)
@ -457,8 +458,8 @@ We can update our schema to support that, by adding new query for ``ingredient``
return None return None
def resolve_ingredient(self, info, **kwargs): def resolve_ingredient(self, info, **kwargs):
id = kargs.get('id') id = kwargs.get('id')
name = kargs.get('name') name = kwargs.get('name')
if id is not None: if id is not None:
return Ingredient.objects.get(pk=id) return Ingredient.objects.get(pk=id)

View File

@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
.. code:: python .. code:: python
# cookbook/ingredients/schema.py # cookbook/ingredients/schema.py
from graphene import relay, ObjectType, AbstractType from graphene import relay, ObjectType
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following:
interfaces = (relay.Node, ) interfaces = (relay.Node, )
class Query(AbstractType): class Query(object):
category = relay.Node.Field(CategoryNode) category = relay.Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql) [http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries! and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema) (See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries) for some example queries)

View File

@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType):
model = Ingredient model = Ingredient
class Query(graphene.AbstractType): class Query(object):
category = graphene.Field(CategoryType, category = graphene.Field(CategoryType,
id=graphene.Int(), id=graphene.Int(),
name=graphene.String()) name=graphene.String())

View File

@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType):
model = RecipeIngredient model = RecipeIngredient
class Query(graphene.AbstractType): class Query(object):
recipe = graphene.Field(RecipeType, recipe = graphene.Field(RecipeType,
id=graphene.Int(), id=graphene.Int(),
title=graphene.String()) title=graphene.String())

View File

@ -1,4 +1,4 @@
graphene graphene
graphene-django graphene-django
graphql-core graphql-core>=2.1rc1
django==1.9 django==1.9

View File

@ -60,5 +60,5 @@ Now you should be ready to start the server:
Now head on over to Now head on over to
[http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql) [http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql)
and run some queries! and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema) (See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries) for some example queries)

View File

@ -1,5 +1,5 @@
from cookbook.ingredients.models import Category, Ingredient from cookbook.ingredients.models import Category, Ingredient
from graphene import AbstractType, Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType):
} }
class Query(AbstractType): class Query(object):
category = Node.Field(CategoryNode) category = Node.Field(CategoryNode)
all_categories = DjangoFilterConnectionField(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode)

View File

@ -1,5 +1,5 @@
from cookbook.recipes.models import Recipe, RecipeIngredient from cookbook.recipes.models import Recipe, RecipeIngredient
from graphene import AbstractType, Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType):
} }
class Query(AbstractType): class Query(object):
recipe = Node.Field(RecipeNode) recipe = Node.Field(RecipeNode)
all_recipes = DjangoFilterConnectionField(RecipeNode) all_recipes = DjangoFilterConnectionField(RecipeNode)

View File

@ -1,5 +1,5 @@
graphene graphene
graphene-django graphene-django
graphql-core graphql-core>=2.1rc1
django==1.9 django==1.9
django-filter==0.11.0 django-filter==0.11.0

View File

@ -5,7 +5,7 @@ from .fields import (
DjangoConnectionField, DjangoConnectionField,
) )
__version__ = '2.0.dev2017083101' __version__ = '2.0.1'
__all__ = [ __all__ = [
'__version__', '__version__',

View File

@ -3,7 +3,7 @@ from django.utils.encoding import force_text
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
NonNull, String, UUID) NonNull, String, UUID)
from graphene.types.datetime import DateTime, Time from graphene.types.datetime import DateTime, Date, Time
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.utils.str_converters import to_camel_case, to_const from graphene.utils.str_converters import to_camel_case, to_const
from graphql import assert_valid_name from graphql import assert_valid_name
@ -40,6 +40,10 @@ def get_choices(choices):
def convert_django_field_with_choices(field, registry=None): def convert_django_field_with_choices(field, registry=None):
if registry is not None:
converted = registry.get_converted_field(field)
if converted:
return converted
choices = getattr(field, 'choices', None) choices = getattr(field, 'choices', None)
if choices: if choices:
meta = field.model._meta meta = field.model._meta
@ -55,8 +59,12 @@ def convert_django_field_with_choices(field, registry=None):
return named_choices_descriptions[self.name] return named_choices_descriptions[self.name]
enum = Enum(name, list(named_choices), type=EnumWithDescriptionsType) enum = Enum(name, list(named_choices), type=EnumWithDescriptionsType)
return enum(description=field.help_text, required=not field.null) converted = enum(description=field.help_text, required=not field.null)
return convert_django_field(field, registry) else:
converted = convert_django_field(field, registry)
if registry is not None:
registry.register_converted_field(field, converted)
return converted
@singledispatch @singledispatch
@ -113,9 +121,14 @@ def convert_field_to_float(field, registry=None):
return Float(description=field.help_text, required=not field.null) return Float(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateTimeField)
def convert_datetime_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null)
@convert_django_field.register(models.DateField) @convert_django_field.register(models.DateField)
def convert_date_to_string(field, registry=None): def convert_date_to_string(field, registry=None):
return DateTime(description=field.help_text, required=not field.null) return Date(description=field.help_text, required=not field.null)
@convert_django_field.register(models.TimeField) @convert_django_field.register(models.TimeField)

View File

@ -116,7 +116,7 @@ class DjangoConnectionField(ConnectionField):
if last: if last:
assert last <= max_limit, ( assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
).format(first, info.field_name, max_limit) ).format(last, info.field_name, max_limit)
args['last'] = min(last, max_limit) args['last'] = min(last, max_limit)
iterable = resolver(root, info, **args) iterable = resolver(root, info, **args)

View File

@ -43,8 +43,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filtering_args(self): def filtering_args(self):
return get_filtering_args_from_filterset(self.filterset_class, self.node_type) return get_filtering_args_from_filterset(self.filterset_class, self.node_type)
@staticmethod @classmethod
def merge_querysets(default_queryset, queryset): def merge_querysets(cls, default_queryset, queryset):
# There could be the case where the default queryset (returned from the filterclass) # There could be the case where the default queryset (returned from the filterclass)
# and the resolver queryset have some limits on it. # and the resolver queryset have some limits on it.
# We only would be able to apply one of those, but not both # We only would be able to apply one of those, but not both
@ -61,7 +61,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
low = default_queryset.query.low_mark or queryset.query.low_mark low = default_queryset.query.low_mark or queryset.query.low_mark
high = default_queryset.query.high_mark or queryset.query.high_mark high = default_queryset.query.high_mark or queryset.query.high_mark
default_queryset.query.clear_limits() default_queryset.query.clear_limits()
queryset = default_queryset & queryset queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset)
queryset.query.set_limits(low, high) queryset.query.set_limits(low, high)
return queryset return queryset
@ -72,7 +72,8 @@ class DjangoFilterConnectionField(DjangoConnectionField):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( qs = filterset_class(
data=filter_kwargs, data=filter_kwargs,
queryset=default_manager.get_queryset() queryset=default_manager.get_queryset(),
request=info.context
).qs ).qs
return super(DjangoFilterConnectionField, cls).connection_resolver( return super(DjangoFilterConnectionField, cls).connection_resolver(

View File

@ -57,7 +57,7 @@ class GrapheneFilterSetMixin(BaseFilterSet):
Global IDs (the default implementation expects database Global IDs (the default implementation expects database
primary keys) primary keys)
""" """
rel = f.field.rel rel = f.field.remote_field if hasattr(f.field, 'remote_field') else f.field.rel
default = { default = {
'name': name, 'name': name,
'label': capfirst(rel.related_name) 'label': capfirst(rel.related_name)

View File

@ -2,7 +2,7 @@ from datetime import datetime
import pytest import pytest
from graphene import Field, ObjectType, Schema, Argument, Float from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.forms import (GlobalIDFormField, from graphene_django.forms import (GlobalIDFormField,
@ -10,6 +10,10 @@ from graphene_django.forms import (GlobalIDFormField,
from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.tests.models import Article, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
# for annotation test
from django.db.models import TextField, Value
from django.db.models.functions import Concat
pytestmark = [] pytestmark = []
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
@ -136,6 +140,48 @@ def test_filter_shortcut_filterset_extra_meta():
assert 'headline' not in field.filterset_class.get_fields() assert 'headline' not in field.filterset_class.get_fields()
def test_filter_shortcut_filterset_context():
class ArticleContextFilter(django_filters.FilterSet):
class Meta:
model = Article
exclude = set()
@property
def qs(self):
qs = super(ArticleContextFilter, self).qs
return qs.filter(reporter=self.request.reporter)
class Query(ObjectType):
context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter)
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1, editor=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2, editor=r2)
class context(object):
reporter = r2
query = '''
query {
contextArticles {
edges {
node {
headline
}
}
}
}
'''
schema = Schema(query=Query)
result = schema.execute(query, context_value=context())
assert not result.errors
assert len(result.data['contextArticles']['edges']) == 1
assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2'
def test_filter_filterset_information_on_meta(): def test_filter_filterset_information_on_meta():
class ReporterFilterNode(DjangoObjectType): class ReporterFilterNode(DjangoObjectType):
@ -199,8 +245,8 @@ def test_filter_filterset_related_results():
r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com')
r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com')
Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1) Article.objects.create(headline='a1', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r1)
Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2) Article.objects.create(headline='a2', pub_date=datetime.now(), pub_date_time=datetime.now(), reporter=r2)
query = ''' query = '''
query { query {
@ -418,6 +464,7 @@ def test_should_query_filter_node_limit():
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline='Article Node 1',
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang='es'
@ -425,6 +472,7 @@ def test_should_query_filter_node_limit():
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline='Article Node 2',
pub_date=datetime.now(), pub_date=datetime.now(),
pub_date_time=datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang='en'
@ -534,3 +582,135 @@ def test_should_query_filter_node_double_limit_raises():
assert str(result.errors[0]) == ( assert str(result.errors[0]) == (
'Received two sliced querysets (high mark) in the connection, please slice only in one.' 'Received two sliced querysets (high mark) in the connection, please slice only in one.'
) )
def test_order_by_is_perserved():
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ()
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean())
def resolve_all_reporters(self, info, reverse_order=False, **args):
reporters = Reporter.objects.order_by('first_name')
if reverse_order:
return reporters.reverse()
return reporters
Reporter.objects.create(
first_name='b',
)
r = Reporter.objects.create(
first_name='a',
)
schema = Schema(query=Query)
query = '''
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
node {
firstName
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'firstName': 'a',
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
reverse_query = '''
query NodeFilteringQuery {
allReporters(first: 1, reverseOrder: true) {
edges {
node {
firstName
}
}
}
}
'''
reverse_expected = {
'allReporters': {
'edges': [{
'node': {
'firstName': 'b',
}
}]
}
}
reverse_result = schema.execute(reverse_query)
assert not reverse_result.errors
assert reverse_result.data == reverse_expected
def test_annotation_is_perserved():
class ReporterType(DjangoObjectType):
full_name = String()
def resolve_full_name(instance, info, **args):
return instance.full_name
class Meta:
model = Reporter
interfaces = (Node, )
filter_fields = ()
class Query(ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(
full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField())
)
Reporter.objects.create(
first_name='John',
last_name='Doe',
)
schema = Schema(query=Query)
query = '''
query NodeFilteringQuery {
allReporters(first: 1) {
edges {
node {
fullName
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'fullName': 'John Doe',
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -1,17 +1,11 @@
from django import forms from django import forms
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
import graphene from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time
from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from ..utils import import_single_dispatch from ..utils import import_single_dispatch
try:
UUIDField = forms.UUIDField
except AttributeError:
class UUIDField(object):
pass
singledispatch = import_single_dispatch() singledispatch = import_single_dispatch()
@ -34,54 +28,69 @@ def convert_form_field(field):
@convert_form_field.register(forms.RegexField) @convert_form_field.register(forms.RegexField)
@convert_form_field.register(forms.Field) @convert_form_field.register(forms.Field)
def convert_form_field_to_string(field): def convert_form_field_to_string(field):
return graphene.String(description=field.help_text, required=field.required) return String(description=field.help_text, required=field.required)
@convert_form_field.register(UUIDField) @convert_form_field.register(forms.UUIDField)
def convert_form_field_to_uuid(field): def convert_form_field_to_uuid(field):
return graphene.UUID(description=field.help_text, required=field.required) return UUID(description=field.help_text, required=field.required)
@convert_form_field.register(forms.IntegerField) @convert_form_field.register(forms.IntegerField)
@convert_form_field.register(forms.NumberInput) @convert_form_field.register(forms.NumberInput)
def convert_form_field_to_int(field): def convert_form_field_to_int(field):
return graphene.Int(description=field.help_text, required=field.required) return Int(description=field.help_text, required=field.required)
@convert_form_field.register(forms.BooleanField) @convert_form_field.register(forms.BooleanField)
def convert_form_field_to_boolean(field): def convert_form_field_to_boolean(field):
return graphene.Boolean(description=field.help_text, required=True) return Boolean(description=field.help_text, required=True)
@convert_form_field.register(forms.NullBooleanField) @convert_form_field.register(forms.NullBooleanField)
def convert_form_field_to_nullboolean(field): def convert_form_field_to_nullboolean(field):
return graphene.Boolean(description=field.help_text) return Boolean(description=field.help_text)
@convert_form_field.register(forms.DecimalField) @convert_form_field.register(forms.DecimalField)
@convert_form_field.register(forms.FloatField) @convert_form_field.register(forms.FloatField)
def convert_form_field_to_float(field): def convert_form_field_to_float(field):
return graphene.Float(description=field.help_text, required=field.required) return Float(description=field.help_text, required=field.required)
@convert_form_field.register(forms.ModelMultipleChoiceField) @convert_form_field.register(forms.ModelMultipleChoiceField)
@convert_form_field.register(GlobalIDMultipleChoiceField) @convert_form_field.register(GlobalIDMultipleChoiceField)
def convert_form_field_to_list(field): def convert_form_field_to_list(field):
return graphene.List(graphene.ID, required=field.required) return List(ID, required=field.required)
@convert_form_field.register(forms.DateField)
def convert_form_field_to_date(field):
return Date(description=field.help_text, required=field.required)
@convert_form_field.register(forms.DateTimeField)
def convert_form_field_to_datetime(field):
return DateTime(description=field.help_text, required=field.required)
@convert_form_field.register(forms.TimeField)
def convert_form_field_to_time(field):
return Time(description=field.help_text, required=field.required)
@convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(forms.ModelChoiceField)
@convert_form_field.register(GlobalIDFormField) @convert_form_field.register(GlobalIDFormField)
def convert_form_field_to_id(field): def convert_form_field_to_id(field):
return graphene.ID(required=field.required) return ID(required=field.required)
@convert_form_field.register(forms.DateField) @convert_form_field.register(forms.DateField)
@convert_form_field.register(forms.DateTimeField) @convert_form_field.register(forms.DateTimeField)
def convert_form_field_to_datetime(field): def convert_form_field_to_datetime(field):
return graphene.types.datetime.DateTime(description=field.help_text, required=field.required) return DateTime(description=field.help_text, required=field.required)
@convert_form_field.register(forms.TimeField) @convert_form_field.register(forms.TimeField)
def convert_form_field_to_time(field): def convert_form_field_to_time(field):
return graphene.types.datetime.Time(description=field.help_text, required=field.required) return Time(description=field.help_text, required=field.required)

View File

@ -2,7 +2,7 @@ from django import forms
from py.test import raises from py.test import raises
import graphene import graphene
from graphene import ID, List, NonNull from graphene import String, Int, Boolean, Float, ID, UUID, List, NonNull, DateTime, Date, Time
from ..converter import convert_form_field from ..converter import convert_form_field
@ -22,71 +22,71 @@ def test_should_unknown_django_field_raise_exception():
assert 'Don\'t know how to convert the Django form field' in str(excinfo.value) assert 'Don\'t know how to convert the Django form field' in str(excinfo.value)
def test_should_date_convert_string(): def test_should_date_convert_date():
assert_conversion(forms.DateField, graphene.types.datetime.DateTime) assert_conversion(forms.DateField, Date)
def test_should_time_convert_string(): def test_should_time_convert_time():
assert_conversion(forms.TimeField, graphene.types.datetime.Time) assert_conversion(forms.TimeField, Time)
def test_should_date_time_convert_string(): def test_should_date_time_convert_date_time():
assert_conversion(forms.DateTimeField, graphene.types.datetime.DateTime) assert_conversion(forms.DateTimeField, DateTime)
def test_should_char_convert_string(): def test_should_char_convert_string():
assert_conversion(forms.CharField, graphene.String) assert_conversion(forms.CharField, String)
def test_should_email_convert_string(): def test_should_email_convert_string():
assert_conversion(forms.EmailField, graphene.String) assert_conversion(forms.EmailField, String)
def test_should_slug_convert_string(): def test_should_slug_convert_string():
assert_conversion(forms.SlugField, graphene.String) assert_conversion(forms.SlugField, String)
def test_should_url_convert_string(): def test_should_url_convert_string():
assert_conversion(forms.URLField, graphene.String) assert_conversion(forms.URLField, String)
def test_should_choice_convert_string(): def test_should_choice_convert_string():
assert_conversion(forms.ChoiceField, graphene.String) assert_conversion(forms.ChoiceField, String)
def test_should_base_field_convert_string(): def test_should_base_field_convert_string():
assert_conversion(forms.Field, graphene.String) assert_conversion(forms.Field, String)
def test_should_regex_convert_string(): def test_should_regex_convert_string():
assert_conversion(forms.RegexField, graphene.String, '[0-9]+') assert_conversion(forms.RegexField, String, '[0-9]+')
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(forms, 'UUIDField'): if hasattr(forms, 'UUIDField'):
assert_conversion(forms.UUIDField, graphene.UUID) assert_conversion(forms.UUIDField, UUID)
def test_should_integer_convert_int(): def test_should_integer_convert_int():
assert_conversion(forms.IntegerField, graphene.Int) assert_conversion(forms.IntegerField, Int)
def test_should_boolean_convert_boolean(): def test_should_boolean_convert_boolean():
field = assert_conversion(forms.BooleanField, graphene.Boolean) field = assert_conversion(forms.BooleanField, Boolean)
assert isinstance(field.type, NonNull) assert isinstance(field.type, NonNull)
def test_should_nullboolean_convert_boolean(): def test_should_nullboolean_convert_boolean():
field = assert_conversion(forms.NullBooleanField, graphene.Boolean) field = assert_conversion(forms.NullBooleanField, Boolean)
assert not isinstance(field.type, NonNull) assert not isinstance(field.type, NonNull)
def test_should_float_convert_float(): def test_should_float_convert_float():
assert_conversion(forms.FloatField, graphene.Float) assert_conversion(forms.FloatField, Float)
def test_should_decimal_convert_float(): def test_should_decimal_convert_float():
assert_conversion(forms.DecimalField, graphene.Float) assert_conversion(forms.DecimalField, Float)
def test_should_multiple_choice_convert_connectionorlist(): def test_should_multiple_choice_convert_connectionorlist():
@ -99,4 +99,4 @@ def test_should_multiple_choice_convert_connectionorlist():
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
field = forms.ModelChoiceField(queryset=None) field = forms.ModelChoiceField(queryset=None)
graphene_type = convert_form_field(field) graphene_type = convert_form_field(field)
assert isinstance(graphene_type, graphene.ID) assert isinstance(graphene_type, ID)

View File

@ -1,64 +1,34 @@
import importlib import importlib
import json import json
from distutils.version import StrictVersion
from optparse import make_option
from django import get_version as get_django_version
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from graphene_django.settings import graphene_settings from graphene_django.settings import graphene_settings
LT_DJANGO_1_8 = StrictVersion(get_django_version()) < StrictVersion('1.8')
if LT_DJANGO_1_8: class CommandArguments(BaseCommand):
class CommandArguments(BaseCommand):
option_list = BaseCommand.option_list + (
make_option(
'--schema',
type=str,
dest='schema',
default='',
help='Django app containing schema to dump, e.g. myproject.core.schema.schema',
),
make_option(
'--out',
type=str,
dest='out',
default='',
help='Output file (default: schema.json)'
),
make_option(
'--indent',
type=int,
dest='indent',
default=None,
help='Output file indent (default: None)'
),
)
else:
class CommandArguments(BaseCommand):
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument(
'--schema', '--schema',
type=str, type=str,
dest='schema', dest='schema',
default=graphene_settings.SCHEMA, default=graphene_settings.SCHEMA,
help='Django app containing schema to dump, e.g. myproject.core.schema.schema') help='Django app containing schema to dump, e.g. myproject.core.schema.schema')
parser.add_argument( parser.add_argument(
'--out', '--out',
type=str, type=str,
dest='out', dest='out',
default=graphene_settings.SCHEMA_OUTPUT, default=graphene_settings.SCHEMA_OUTPUT,
help='Output file (default: schema.json)') help='Output file (default: schema.json)')
parser.add_argument( parser.add_argument(
'--indent', '--indent',
type=int, type=int,
dest='indent', dest='indent',
default=graphene_settings.SCHEMA_INDENT, default=graphene_settings.SCHEMA_INDENT,
help='Output file indent (default: None)') help='Output file indent (default: None)')
class Command(CommandArguments): class Command(CommandArguments):

View File

@ -3,7 +3,7 @@ class Registry(object):
def __init__(self): def __init__(self):
self._registry = {} self._registry = {}
self._registry_models = {} self._field_registry = {}
def register(self, cls): def register(self, cls):
from .types import DjangoObjectType from .types import DjangoObjectType
@ -20,6 +20,12 @@ class Registry(object):
def get_type_for_model(self, model): def get_type_for_model(self, model):
return self._registry.get(model) return self._registry.get(model)
def register_converted_field(self, field, converted):
self._field_registry[field] = converted
def get_converted_field(self, field):
return self._field_registry.get(field)
registry = None registry = None

View File

@ -0,0 +1,6 @@
from django.db import models
class MyFakeModel(models.Model):
cool_name = models.CharField(max_length=50)
created = models.DateTimeField(auto_now_add=True)

View File

@ -1,5 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from django.shortcuts import get_object_or_404
import graphene import graphene
from graphene.types import Field, InputField from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions from graphene.types.mutation import MutationOptions
@ -15,6 +17,9 @@ from .types import ErrorType
class SerializerMutationOptions(MutationOptions): class SerializerMutationOptions(MutationOptions):
lookup_field = None
model_class = None
model_operations = ['create', 'update']
serializer_class = None serializer_class = None
@ -44,18 +49,34 @@ class SerializerMutation(ClientIDMutation):
) )
@classmethod @classmethod
def __init_subclass_with_meta__(cls, serializer_class=None, def __init_subclass_with_meta__(cls, lookup_field=None,
serializer_class=None, model_class=None,
model_operations=['create', 'update'],
only_fields=(), exclude_fields=(), **options): only_fields=(), exclude_fields=(), **options):
if not serializer_class: if not serializer_class:
raise Exception('serializer_class is required for the SerializerMutation') raise Exception('serializer_class is required for the SerializerMutation')
if 'update' not in model_operations and 'create' not in model_operations:
raise Exception('model_operations must contain "create" and/or "update"')
serializer = serializer_class() serializer = serializer_class()
if model_class is None:
serializer_meta = getattr(serializer_class, 'Meta', None)
if serializer_meta:
model_class = getattr(serializer_meta, 'model', None)
if lookup_field is None and model_class:
lookup_field = model_class._meta.pk.name
input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True)
output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False)
_meta = SerializerMutationOptions(cls) _meta = SerializerMutationOptions(cls)
_meta.lookup_field = lookup_field
_meta.model_operations = model_operations
_meta.serializer_class = serializer_class _meta.serializer_class = serializer_class
_meta.model_class = model_class
_meta.fields = yank_fields_from_attrs( _meta.fields = yank_fields_from_attrs(
output_fields, output_fields,
_as=Field, _as=Field,
@ -67,9 +88,35 @@ class SerializerMutation(ClientIDMutation):
) )
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod
def get_serializer_kwargs(cls, root, info, **input):
lookup_field = cls._meta.lookup_field
model_class = cls._meta.model_class
if model_class:
if 'update' in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404(model_class, **{
lookup_field: input[lookup_field]})
elif 'create' in cls._meta.model_operations:
instance = None
else:
raise Exception(
'Invalid update operation. Input parameter "{}" required.'.format(
lookup_field
))
return {
'instance': instance,
'data': input,
'context': {'request': info.context}
}
return {'data': input, 'context': {'request': info.context}}
@classmethod @classmethod
def mutate_and_get_payload(cls, root, info, **input): def mutate_and_get_payload(cls, root, info, **input):
serializer = cls._meta.serializer_class(data=input) kwargs = cls.get_serializer_kwargs(root, info, **input)
serializer = cls._meta.serializer_class(**kwargs)
if serializer.is_valid(): if serializer.is_valid():
return cls.perform_mutate(serializer, info) return cls.perform_mutate(serializer, info)
@ -84,4 +131,9 @@ class SerializerMutation(ClientIDMutation):
@classmethod @classmethod
def perform_mutate(cls, serializer, info): def perform_mutate(cls, serializer, info):
obj = serializer.save() obj = serializer.save()
return cls(errors=None, **obj)
kwargs = {}
for f, field in serializer.fields.items():
kwargs[f] = field.get_attribute(obj)
return cls(errors=None, **kwargs)

View File

@ -46,6 +46,15 @@ def convert_serializer_field(field, is_input=True):
global_registry = get_global_registry() global_registry = get_global_registry()
field_model = field.Meta.model field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)] args = [global_registry.get_type_for_model(field_model)]
elif isinstance(field, serializers.ListSerializer):
field = field.child
if is_input:
kwargs['of_type'] = convert_serializer_to_input_type(field.__class__)
else:
del kwargs['of_type']
global_registry = get_global_registry()
field_model = field.Meta.model
args = [global_registry.get_type_for_model(field_model)]
return graphql_type(*args, **kwargs) return graphql_type(*args, **kwargs)
@ -75,6 +84,12 @@ def convert_serializer_to_field(field):
return graphene.Field return graphene.Field
@get_graphene_type_from_serializer_field.register(serializers.ListSerializer)
def convert_list_serializer_to_field(field):
child_type = get_graphene_type_from_serializer_field(field.child)
return (graphene.List, child_type)
@get_graphene_type_from_serializer_field.register(serializers.IntegerField) @get_graphene_type_from_serializer_field.register(serializers.IntegerField)
def convert_serializer_field_to_int(field): def convert_serializer_field_to_int(field):
return graphene.Int return graphene.Int
@ -92,9 +107,13 @@ def convert_serializer_field_to_float(field):
@get_graphene_type_from_serializer_field.register(serializers.DateTimeField) @get_graphene_type_from_serializer_field.register(serializers.DateTimeField)
def convert_serializer_field_to_datetime_time(field):
return graphene.types.datetime.DateTime
@get_graphene_type_from_serializer_field.register(serializers.DateField) @get_graphene_type_from_serializer_field.register(serializers.DateField)
def convert_serializer_field_to_date_time(field): def convert_serializer_field_to_date_time(field):
return graphene.types.datetime.DateTime return graphene.types.datetime.Date
@get_graphene_type_from_serializer_field.register(serializers.TimeField) @get_graphene_type_from_serializer_field.register(serializers.TimeField)

View File

@ -1,8 +1,10 @@
import copy import copy
from rest_framework import serializers
from py.test import raises
import graphene import graphene
from django.db import models
from graphene import InputObjectType
from py.test import raises
from rest_framework import serializers
from ..serializer_converter import convert_serializer_field from ..serializer_converter import convert_serializer_field
from ..types import DictType from ..types import DictType
@ -74,7 +76,6 @@ def test_should_uuid_convert_string():
def test_should_model_convert_field(): def test_should_model_convert_field():
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = None model = None
@ -87,8 +88,8 @@ def test_should_date_time_convert_datetime():
assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime)
def test_should_date_convert_datetime(): def test_should_date_convert_date():
assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) assert_conversion(serializers.DateField, graphene.types.datetime.Date)
def test_should_time_convert_time(): def test_should_time_convert_time():
@ -128,6 +129,30 @@ def test_should_list_convert_to_list():
assert field_b.of_type == graphene.String assert field_b.of_type == graphene.String
def test_should_list_serializer_convert_to_list():
class FooModel(models.Model):
pass
class ChildSerializer(serializers.ModelSerializer):
class Meta:
model = FooModel
fields = '__all__'
class ParentSerializer(serializers.ModelSerializer):
child = ChildSerializer(many=True)
class Meta:
model = FooModel
fields = '__all__'
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=True)
assert isinstance(converted_type, graphene.List)
converted_type = convert_serializer_field(ParentSerializer().get_fields()['child'], is_input=False)
assert isinstance(converted_type, graphene.List)
assert converted_type.of_type is None
def test_should_dict_convert_dict(): def test_should_dict_convert_dict():
assert_conversion(serializers.DictField, DictType) assert_conversion(serializers.DictField, DictType)
@ -157,6 +182,6 @@ def test_should_json_convert_jsonstring():
def test_should_multiplechoicefield_convert_to_list_of_string(): def test_should_multiplechoicefield_convert_to_list_of_string():
field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3]) field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3])
assert field.of_type == graphene.String assert field.of_type == graphene.String

View File

@ -1,15 +1,28 @@
from django.db import models import datetime
from graphene import Field
from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from py.test import raises from py.test import raises
from py.test import mark
from rest_framework import serializers from rest_framework import serializers
from ...types import DjangoObjectType from ...types import DjangoObjectType
from ..models import MyFakeModel
from ..mutation import SerializerMutation from ..mutation import SerializerMutation
def mock_info():
class MyFakeModel(models.Model): return ResolveInfo(
cool_name = models.CharField(max_length=50) None,
None,
None,
None,
schema=None,
fragments=None,
root_value=None,
operation=None,
variable_values=None,
context=None
)
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):
@ -17,6 +30,9 @@ class MyModelSerializer(serializers.ModelSerializer):
model = MyFakeModel model = MyFakeModel
fields = '__all__' fields = '__all__'
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
class MySerializer(serializers.Serializer): class MySerializer(serializers.Serializer):
text = serializers.CharField() text = serializers.CharField()
@ -53,6 +69,19 @@ def test_has_input_fields():
assert 'model' in MyMutation.Input._meta.fields assert 'model' in MyMutation.Input._meta.fields
def test_exclude_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
exclude_fields = ['created']
assert 'cool_name' in MyMutation._meta.fields
assert 'created' not in MyMutation._meta.fields
assert 'errors' in MyMutation._meta.fields
assert 'cool_name' in MyMutation.Input._meta.fields
assert 'created' not in MyMutation.Input._meta.fields
def test_nested_model(): def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType): class MyFakeModelGrapheneType(DjangoObjectType):
@ -71,6 +100,7 @@ def test_nested_model():
model_input_type = model_input._type.of_type model_input_type = model_input._type.of_type
assert issubclass(model_input_type, InputObjectType) assert issubclass(model_input_type, InputObjectType)
assert 'cool_name' in model_input_type._meta.fields assert 'cool_name' in model_input_type._meta.fields
assert 'created' in model_input_type._meta.fields
def test_mutate_and_get_payload_success(): def test_mutate_and_get_payload_success():
@ -79,7 +109,7 @@ def test_mutate_and_get_payload_success():
class Meta: class Meta:
serializer_class = MySerializer serializer_class = MySerializer
result = MyMutation.mutate_and_get_payload(None, None, **{ result = MyMutation.mutate_and_get_payload(None, mock_info(), **{
'text': 'value', 'text': 'value',
'model': { 'model': {
'cool_name': 'other_value' 'cool_name': 'other_value'
@ -88,6 +118,39 @@ def test_mutate_and_get_payload_success():
assert result.errors is None assert result.errors is None
@mark.django_db
def test_model_add_mutate_and_get_payload_success():
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
'cool_name': 'Narf',
})
assert result.errors is None
assert result.cool_name == 'Narf'
assert isinstance(result.created, datetime.datetime)
@mark.django_db
def test_model_update_mutate_and_get_payload_success():
instance = MyFakeModel.objects.create(cool_name="Narf")
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{
'id': instance.id,
'cool_name': 'New Narf',
})
assert result.errors is None
assert result.cool_name == 'New Narf'
@mark.django_db
def test_model_invalid_update_mutate_and_get_payload_success():
class InvalidModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['update']
with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{
'cool_name': 'Narf',
})
assert '"id" required' in str(exc.value)
def test_mutate_and_get_payload_error(): def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
@ -95,5 +158,19 @@ def test_mutate_and_get_payload_error():
serializer_class = MySerializer serializer_class = MySerializer
# missing required fields # missing required fields
result = MyMutation.mutate_and_get_payload(None, None, **{}) result = MyMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0 assert len(result.errors) > 0
def test_model_mutate_and_get_payload_error():
# missing required fields
result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{})
assert len(result.errors) > 0
def test_invalid_serializer_operations():
with raises(Exception) as exc:
class MyModelMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
model_operations = ['Add']
assert 'model_operations' in str(exc.value)

View File

@ -3,8 +3,8 @@ from graphene.types.unmountedtype import UnmountedType
class ErrorType(graphene.ObjectType): class ErrorType(graphene.ObjectType):
field = graphene.String() field = graphene.String(required=True)
messages = graphene.List(graphene.String) messages = graphene.List(graphene.NonNull(graphene.String), required=True)
class DictType(UnmountedType): class DictType(UnmountedType):

View File

@ -16,11 +16,11 @@ add "&raw" to the end of the URL within a browser.
width: 100%; width: 100%;
} }
</style> </style>
<link href="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.css" rel="stylesheet" /> <link href="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script> <script src="//cdn.jsdelivr.net/npm/whatwg-fetch@2.0.3/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react.min.js"></script> <script src="//cdn.jsdelivr.net/npm/react@16.2.0/umd/react.production.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.1/react-dom.min.js"></script> <script src="//cdn.jsdelivr.net/npm/react-dom@16.2.0/umd/react-dom.production.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/{{graphiql_version}}/graphiql.min.js"></script> <script src="//cdn.jsdelivr.net/npm/graphiql@{{graphiql_version}}/graphiql.min.js"></script>
</head> </head>
<body> <body>
<script> <script>

View File

@ -22,6 +22,9 @@ class Film(models.Model):
reporters = models.ManyToManyField('Reporter', reporters = models.ManyToManyField('Reporter',
related_name='films') related_name='films')
class DoeReporterManager(models.Manager):
def get_queryset(self):
return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe")
class Reporter(models.Model): class Reporter(models.Model):
first_name = models.CharField(max_length=30) first_name = models.CharField(max_length=30)
@ -29,14 +32,44 @@ class Reporter(models.Model):
email = models.EmailField() email = models.EmailField()
pets = models.ManyToManyField('self') pets = models.ManyToManyField('self')
a_choice = models.CharField(max_length=30, choices=CHOICES) a_choice = models.CharField(max_length=30, choices=CHOICES)
objects = models.Manager()
doe_objects = DoeReporterManager()
reporter_type = models.IntegerField(
'Reporter Type',
null=True,
blank=True,
choices=[(1, u'Regular'), (2, u'CNN Reporter')]
)
def __str__(self): # __unicode__ on Python 2 def __str__(self): # __unicode__ on Python 2
return "%s %s" % (self.first_name, self.last_name) return "%s %s" % (self.first_name, self.last_name)
def __init__(self, *args, **kwargs):
"""
Override the init method so that during runtime, Django
can know that this object can be a CNNReporter by casting
it to the proxy model. Otherwise, as far as Django knows,
when a CNNReporter is pulled from the database, it is still
of type Reporter. This was added to test proxy model support.
"""
super(Reporter, self).__init__(*args, **kwargs)
if self.reporter_type == 2: # quick and dirty way without enums
self.__class__ = CNNReporter
class CNNReporter(Reporter):
"""
This class is a proxy model for Reporter, used for testing
proxy model support
"""
class Meta:
proxy = True
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField() pub_date = models.DateField()
pub_date_time = models.DateTimeField()
reporter = models.ForeignKey(Reporter, related_name='articles') reporter = models.ForeignKey(Reporter, related_name='articles')
editor = models.ForeignKey(Reporter, related_name='edited_articles_+') editor = models.ForeignKey(Reporter, related_name='edited_articles_+')
lang = models.CharField(max_length=2, help_text='Language', choices=[ lang = models.CharField(max_length=2, help_text='Language', choices=[

View File

@ -5,7 +5,7 @@ from py.test import raises
import graphene import graphene
from graphene.relay import ConnectionField, Node from graphene.relay import ConnectionField, Node
from graphene.types.datetime import DateTime, Time from graphene.types.datetime import DateTime, Date, Time
from graphene.types.json import JSONString from graphene.types.json import JSONString
from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType
@ -38,9 +38,12 @@ def test_should_unknown_django_field_raise_exception():
convert_django_field(None) convert_django_field(None)
assert 'Don\'t know how to convert the Django field' in str(excinfo.value) assert 'Don\'t know how to convert the Django field' in str(excinfo.value)
def test_should_date_time_convert_string():
assert_conversion(models.DateTimeField, DateTime)
def test_should_date_convert_string(): def test_should_date_convert_string():
assert_conversion(models.DateField, DateTime) assert_conversion(models.DateField, Date)
def test_should_time_convert_string(): def test_should_time_convert_string():

View File

@ -1,7 +1,7 @@
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from py.test import raises from py.test import raises
from ..forms import GlobalIDFormField from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'
@ -18,6 +18,17 @@ def test_global_id_invalid():
field.clean('badvalue') field.clean('badvalue')
def test_global_id_multiple_valid():
field = GlobalIDMultipleChoiceField()
field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw=='])
def test_global_id_multiple_invalid():
field = GlobalIDMultipleChoiceField()
with raises(ValidationError):
field.clean(['badvalue', 'another bad avue'])
def test_global_id_none(): def test_global_id_none():
field = GlobalIDFormField() field = GlobalIDFormField()
with raises(ValidationError): with raises(ValidationError):

View File

@ -13,7 +13,11 @@ from ..compat import MissingType, JSONField
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from ..settings import graphene_settings from ..settings import graphene_settings
from .models import Article, Reporter from .models import (
Article,
CNNReporter,
Reporter,
)
pytestmark = pytest.mark.django_db pytestmark = pytest.mark.django_db
@ -371,6 +375,7 @@ def test_should_query_node_filtering():
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline='Article Node 1',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang='es'
@ -378,6 +383,7 @@ def test_should_query_node_filtering():
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline='Article Node 2',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang='en'
@ -453,6 +459,7 @@ def test_should_query_node_multiple_filtering():
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline='Article Node 1',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang='es'
@ -460,6 +467,7 @@ def test_should_query_node_multiple_filtering():
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline='Article Node 2',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang='es'
@ -467,6 +475,7 @@ def test_should_query_node_multiple_filtering():
Article.objects.create( Article.objects.create(
headline='Article Node 3', headline='Article Node 3',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang='en'
@ -606,6 +615,53 @@ def test_should_error_if_first_is_greater_than_max():
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
def test_should_error_if_last_is_greater_than_max():
graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
schema = graphene.Schema(query=Query)
query = '''
query NodeFilteringQuery {
allReporters(last: 101) {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': None
}
result = schema.execute(query)
assert len(result.errors) == 1
assert str(result.errors[0]) == (
'Requesting 101 records on the `allReporters` connection '
'exceeds the `last` limit of 100 records.'
)
assert result.data == expected
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
def test_should_query_promise_connectionfields(): def test_should_query_promise_connectionfields():
from promise import Promise from promise import Promise
@ -620,7 +676,7 @@ def test_should_query_promise_connectionfields():
def resolve_all_reporters(self, info, **args): def resolve_all_reporters(self, info, **args):
return Promise.resolve([Reporter(id=1)]) return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
query = ''' query = '''
query ReporterPromiseConnectionQuery { query ReporterPromiseConnectionQuery {
@ -648,6 +704,109 @@ def test_should_query_promise_connectionfields():
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_query_connectionfields_with_last():
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
schema = graphene.Schema(query=Query)
query = '''
query ReporterLastQuery {
allReporters(last: 1) {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE='
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_connectionfields_with_manager():
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
r = Reporter.objects.create(
first_name='John',
last_name='NotDoe',
email='johndoe@example.com',
a_choice=1
)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType, on='doe_objects')
def resolve_all_reporters(self, info, **args):
return Reporter.objects.all()
schema = graphene.Schema(query=Query)
query = '''
query ReporterLastQuery {
allReporters(first: 2) {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE='
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_dataloader_fields(): def test_should_query_dataloader_fields():
from promise import Promise from promise import Promise
@ -689,9 +848,11 @@ def test_should_query_dataloader_fields():
email='johndoe@example.com', email='johndoe@example.com',
a_choice=1 a_choice=1
) )
Article.objects.create( Article.objects.create(
headline='Article Node 1', headline='Article Node 1',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='es' lang='es'
@ -699,6 +860,7 @@ def test_should_query_dataloader_fields():
Article.objects.create( Article.objects.create(
headline='Article Node 2', headline='Article Node 2',
pub_date=datetime.date.today(), pub_date=datetime.date.today(),
pub_date_time=datetime.datetime.now(),
reporter=r, reporter=r,
editor=r, editor=r,
lang='en' lang='en'
@ -748,3 +910,171 @@ def test_should_query_dataloader_fields():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_should_handle_inherited_choices():
class BaseModel(models.Model):
choice_field = models.IntegerField(choices=((0, 'zero'), (1, 'one')))
class ChildModel(BaseModel):
class Meta:
proxy = True
class BaseType(DjangoObjectType):
class Meta:
model = BaseModel
class ChildType(DjangoObjectType):
class Meta:
model = ChildModel
class Query(graphene.ObjectType):
base = graphene.Field(BaseType)
child = graphene.Field(ChildType)
schema = graphene.Schema(query=Query)
query = '''
query {
child {
choiceField
}
}
'''
result = schema.execute(query)
assert not result.errors
def test_proxy_model_support():
"""
This test asserts that we can query for all Reporters,
even if some are of a proxy model type at runtime.
"""
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
use_connection = True
reporter_1 = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
reporter_2 = CNNReporter.objects.create(
first_name='Some',
last_name='Guy',
email='someguy@cnn.com',
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
schema = graphene.Schema(query=Query)
query = '''
query ProxyModelQuery {
allReporters {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE=',
},
},
{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjI=',
},
}
]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_proxy_model_fails():
"""
This test asserts that if you try to query for a proxy model,
that query will fail with:
GraphQLError('Expected value of type "CNNReporterType" but got:
CNNReporter.',)
This is because a proxy model has the identical model definition
to its superclass, and defines its behavior at runtime, rather than
at the database level. Currently, filtering objects of the proxy models'
type isn't supported. It would require a field on the model that would
represent the type, and it doesn't seem like there is a clear way to
enforce this pattern across all projects
"""
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node, )
use_connection = True
reporter_1 = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
reporter_2 = CNNReporter.objects.create(
first_name='Some',
last_name='Guy',
email='someguy@cnn.com',
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(CNNReporterType)
schema = graphene.Schema(query=Query)
query = '''
query ProxyModelQuery {
allReporters {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE=',
},
},
{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjI=',
},
}
]
}
}
result = schema.execute(query)
assert result.errors

View File

@ -35,6 +35,7 @@ def test_should_map_fields_correctly():
'email', 'email',
'pets', 'pets',
'a_choice', 'a_choice',
'reporter_type'
] ]
assert sorted(fields[-2:]) == [ assert sorted(fields[-2:]) == [

View File

@ -1,10 +1,10 @@
from mock import patch from mock import patch
from graphene import Interface, ObjectType, Schema from graphene import Interface, ObjectType, Schema, Connection, String
from graphene.relay import Node from graphene.relay import Node
from .. import registry from .. import registry
from ..types import DjangoObjectType from ..types import DjangoObjectType, DjangoObjectTypeOptions
from .models import Article as ArticleModel from .models import Article as ArticleModel
from .models import Reporter as ReporterModel from .models import Reporter as ReporterModel
@ -17,11 +17,23 @@ class Reporter(DjangoObjectType):
model = ReporterModel model = ReporterModel
class ArticleConnection(Connection):
'''Article Connection'''
test = String()
def resolve_test():
return 'test'
class Meta:
abstract = True
class Article(DjangoObjectType): class Article(DjangoObjectType):
'''Article description''' '''Article description'''
class Meta: class Meta:
model = ArticleModel model = ArticleModel
interfaces = (Node, ) interfaces = (Node, )
connection_class = ArticleConnection
class RootQuery(ObjectType): class RootQuery(ObjectType):
@ -46,13 +58,33 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields(): def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields fields = Reporter._meta.fields
fields = list(fields.keys()) fields = list(fields.keys())
assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice'] assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice', 'reporter_type']
assert sorted(fields[-2:]) == ['articles', 'films'] assert sorted(fields[-2:]) == ['articles', 'films']
def test_django_objecttype_with_node_have_correct_fields(): def test_django_objecttype_with_node_have_correct_fields():
fields = Article._meta.fields fields = Article._meta.fields
assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'reporter', 'editor', 'lang', 'importance'] assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'pub_date_time', 'reporter', 'editor', 'lang', 'importance']
def test_django_objecttype_with_custom_meta():
class ArticleTypeOptions(DjangoObjectTypeOptions):
'''Article Type Options'''
class ArticleType(DjangoObjectType):
class Meta:
abstract = True
@classmethod
def __init_subclass_with_meta__(cls, **options):
options.setdefault('_meta', ArticleTypeOptions(cls))
super(ArticleType, cls).__init_subclass_with_meta__(**options)
class Article(ArticleType):
class Meta:
model = ArticleModel
assert isinstance(Article._meta, ArticleTypeOptions)
def test_schema_representation(): def test_schema_representation():
@ -64,7 +96,8 @@ schema {
type Article implements Node { type Article implements Node {
id: ID! id: ID!
headline: String! headline: String!
pubDate: DateTime! pubDate: Date!
pubDateTime: DateTime!
reporter: Reporter! reporter: Reporter!
editor: Reporter! editor: Reporter!
lang: ArticleLang! lang: ArticleLang!
@ -74,6 +107,7 @@ type Article implements Node {
type ArticleConnection { type ArticleConnection {
pageInfo: PageInfo! pageInfo: PageInfo!
edges: [ArticleEdge]! edges: [ArticleEdge]!
test: String
} }
type ArticleEdge { type ArticleEdge {
@ -91,6 +125,8 @@ enum ArticleLang {
EN EN
} }
scalar Date
scalar DateTime scalar DateTime
interface Node { interface Node {
@ -111,6 +147,7 @@ type Reporter {
email: String! email: String!
pets: [Reporter] pets: [Reporter]
aChoice: ReporterAChoice! aChoice: ReporterAChoice!
reporterType: ReporterReporterType
articles(before: String, after: String, first: Int, last: Int): ArticleConnection articles(before: String, after: String, first: Int, last: Int): ArticleConnection
} }
@ -119,6 +156,11 @@ enum ReporterAChoice {
A_2 A_2
} }
enum ReporterReporterType {
A_1
A_2
}
type RootQuery { type RootQuery {
node(id: ID!): Node node(id: ID!): Node
} }

View File

@ -30,6 +30,20 @@ jl = lambda **kwargs: json.dumps([kwargs])
def test_graphiql_is_enabled(client): def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT='text/html') response = client.get(url_string(), HTTP_ACCEPT='text/html')
assert response.status_code == 200 assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html'
def test_qfactor_graphiql(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='application/json;q=0.8, text/html;q=0.9')
assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'text/html'
def test_qfactor_json(client):
response = client.get(url_string(query='{test}'), HTTP_ACCEPT='text/html;q=0.8, application/json;q=0.9')
assert response.status_code == 200
assert response['Content-Type'].split(';')[0] == 'application/json'
assert response_json(response) == {
'data': {'test': "Hello World"}
}
def test_allows_get_with_query_param(client): def test_allows_get_with_query_param(client):
@ -386,6 +400,24 @@ def test_allows_post_with_get_operation_name(client):
} }
@pytest.mark.urls('graphene_django.tests.urls_inherited')
def test_inherited_class_with_attributes_works(client):
inherited_url = '/graphql/inherited/'
# Check schema and pretty attributes work
response = client.post(url_string(inherited_url, query='{test}'))
assert response.content.decode() == (
'{\n'
' "data": {\n'
' "test": "Hello World"\n'
' }\n'
'}'
)
# Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT='text/html')
assert response.status_code == 200
@pytest.mark.urls('graphene_django.tests.urls_pretty') @pytest.mark.urls('graphene_django.tests.urls_pretty')
def test_supports_pretty_printing(client): def test_supports_pretty_printing(client):
response = client.get(url_string(query='{test}')) response = client.get(url_string(query='{test}'))
@ -416,7 +448,11 @@ def test_handles_field_errors_caught_by_graphql(client):
assert response.status_code == 200 assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
'data': None, 'data': None,
'errors': [{'locations': [{'column': 2, 'line': 1}], 'message': 'Throws!'}] 'errors': [{
'locations': [{'column': 2, 'line': 1}],
'path': ['thrower'],
'message': 'Throws!',
}]
} }
@ -425,7 +461,7 @@ def test_handles_syntax_errors_caught_by_graphql(client):
assert response.status_code == 400 assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
'errors': [{'locations': [{'column': 1, 'line': 1}], 'errors': [{'locations': [{'column': 1, 'line': 1}],
'message': 'Syntax Error GraphQL request (1:1) ' 'message': 'Syntax Error GraphQL (1:1) '
'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}]
} }

View File

@ -0,0 +1,14 @@
from django.conf.urls import url
from ..views import GraphQLView
from .schema_view import schema
class CustomGraphQLView(GraphQLView):
schema = schema
graphiql = True
pretty = True
urlpatterns = [
url(r'^graphql/inherited/$', CustomGraphQLView.as_view()),
]

View File

@ -45,7 +45,7 @@ class DjangoObjectType(ObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
only_fields=(), exclude_fields=(), filter_fields=None, connection=None, only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
use_connection=None, interfaces=(), **options): connection_class=None, use_connection=None, interfaces=(), _meta=None, **options):
assert is_valid_django_model(model), ( assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".' 'You need to pass a valid Django Model in {}.Meta, received "{}".'
).format(cls.__name__, model) ).format(cls.__name__, model)
@ -71,14 +71,20 @@ class DjangoObjectType(ObjectType):
if use_connection and not connection: if use_connection and not connection:
# We create the connection automatically # We create the connection automatically
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) if not connection_class:
connection_class = Connection
connection = connection_class.create_type(
'{}Connection'.format(cls.__name__), node=cls)
if connection is not None: if connection is not None:
assert issubclass(connection, Connection), ( assert issubclass(connection, Connection), (
"The connection must be a Connection. Received {}" "The connection must be a Connection. Received {}"
).format(connection.__name__) ).format(connection.__name__)
_meta = DjangoObjectTypeOptions(cls) if not _meta:
_meta = DjangoObjectTypeOptions(cls)
_meta.model = model _meta.model = model
_meta.registry = registry _meta.registry = registry
_meta.filter_fields = filter_fields _meta.filter_fields = filter_fields
@ -104,7 +110,8 @@ class DjangoObjectType(ObjectType):
raise Exception(( raise Exception((
'Received incompatible instance "{}".' 'Received incompatible instance "{}".'
).format(root)) ).format(root))
model = root._meta.model
model = root._meta.model._meta.concrete_model
return model == cls._meta.model return model == cls._meta.model
@classmethod @classmethod

View File

@ -10,12 +10,11 @@ from django.utils.decorators import method_decorator
from django.views.generic import View from django.views.generic import View
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from graphql import Source, execute, parse, validate from graphql import get_default_backend
from graphql.error import format_error as format_graphql_error from graphql.error import format_error as format_graphql_error
from graphql.error import GraphQLError from graphql.error import GraphQLError
from graphql.execution import ExecutionResult from graphql.execution import ExecutionResult
from graphql.type.schema import GraphQLSchema from graphql.type.schema import GraphQLSchema
from graphql.utils.get_operation_ast import get_operation_ast
from .settings import graphene_settings from .settings import graphene_settings
@ -35,8 +34,8 @@ def get_accepted_content_types(request):
match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)', match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)',
parts[1]) parts[1])
if match: if match:
return parts[0], float(match.group(2)) return parts[0].strip(), float(match.group(2))
return parts[0], 1 return parts[0].strip(), 1
raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',') raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',')
qualified_content_types = map(qualify, raw_content_types) qualified_content_types = map(qualify, raw_content_types)
@ -53,36 +52,43 @@ def instantiate_middleware(middlewares):
class GraphQLView(View): class GraphQLView(View):
graphiql_version = '0.10.2' graphiql_version = '0.11.10'
graphiql_template = 'graphene/graphiql.html' graphiql_template = 'graphene/graphiql.html'
schema = None schema = None
graphiql = False graphiql = False
executor = None executor = None
backend = None
middleware = None middleware = None
root_value = None root_value = None
pretty = False pretty = False
batch = False batch = False
def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False,
batch=False): batch=False, backend=None):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
if backend is None:
backend = get_default_backend()
if middleware is None: if middleware is None:
middleware = graphene_settings.MIDDLEWARE middleware = graphene_settings.MIDDLEWARE
self.schema = schema self.schema = self.schema or schema
if middleware is not None: if middleware is not None:
self.middleware = list(instantiate_middleware(middleware)) self.middleware = list(instantiate_middleware(middleware))
self.executor = executor self.executor = executor
self.root_value = root_value self.root_value = root_value
self.pretty = pretty self.pretty = self.pretty or pretty
self.graphiql = graphiql self.graphiql = self.graphiql or graphiql
self.batch = batch self.batch = self.batch or batch
self.backend = backend
assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' assert isinstance(
assert not all((graphiql, batch)), 'Use either graphiql or batch processing' self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
assert not all((graphiql, batch)
), 'Use either graphiql or batch processing'
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def get_root_value(self, request): def get_root_value(self, request):
@ -94,24 +100,31 @@ class GraphQLView(View):
def get_context(self, request): def get_context(self, request):
return request return request
def get_backend(self, request):
return self.backend
@method_decorator(ensure_csrf_cookie) @method_decorator(ensure_csrf_cookie)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
try: try:
if request.method.lower() not in ('get', 'post'): if request.method.lower() not in ('get', 'post'):
raise HttpError(HttpResponseNotAllowed(['GET', 'POST'], 'GraphQL only supports GET and POST requests.')) raise HttpError(HttpResponseNotAllowed(
['GET', 'POST'], 'GraphQL only supports GET and POST requests.'))
data = self.parse_body(request) data = self.parse_body(request)
show_graphiql = self.graphiql and self.can_display_graphiql(request, data) show_graphiql = self.graphiql and self.can_display_graphiql(
request, data)
if self.batch: if self.batch:
responses = [self.get_response(request, entry) for entry in data] responses = [self.get_response(request, entry) for entry in data]
result = '[{}]'.format(','.join([response[0] for response in responses])) result = '[{}]'.format(','.join([response[0] for response in responses]))
status_code = max(responses, key=lambda response: response[1])[1] status_code = responses and max(responses, key=lambda response: response[1])[1] or 200
else: else:
result, status_code = self.get_response(request, data, show_graphiql) result, status_code = self.get_response(
request, data, show_graphiql)
if show_graphiql: if show_graphiql:
query, variables, operation_name, id = self.get_graphql_params(request, data) query, variables, operation_name, id = self.get_graphql_params(
request, data)
return self.render_graphiql( return self.render_graphiql(
request, request,
graphiql_version=self.graphiql_version, graphiql_version=self.graphiql_version,
@ -136,7 +149,8 @@ class GraphQLView(View):
return response return response
def get_response(self, request, data, show_graphiql=False): def get_response(self, request, data, show_graphiql=False):
query, variables, operation_name, id = self.get_graphql_params(request, data) query, variables, operation_name, id = self.get_graphql_params(
request, data)
execution_result = self.execute_graphql_request( execution_result = self.execute_graphql_request(
request, request,
@ -152,7 +166,8 @@ class GraphQLView(View):
response = {} response = {}
if execution_result.errors: if execution_result.errors:
response['errors'] = [self.format_error(e) for e in execution_result.errors] response['errors'] = [self.format_error(
e) for e in execution_result.errors]
if execution_result.invalid: if execution_result.invalid:
status_code = 400 status_code = 400
@ -209,54 +224,52 @@ class GraphQLView(View):
except AssertionError as e: except AssertionError as e:
raise HttpError(HttpResponseBadRequest(str(e))) raise HttpError(HttpResponseBadRequest(str(e)))
except (TypeError, ValueError): except (TypeError, ValueError):
raise HttpError(HttpResponseBadRequest('POST body sent invalid JSON.')) raise HttpError(HttpResponseBadRequest(
'POST body sent invalid JSON.'))
elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']: elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']:
return request.POST return request.POST
return {} return {}
def execute(self, *args, **kwargs):
return execute(self.schema, *args, **kwargs)
def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False): def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False):
if not query: if not query:
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseBadRequest('Must provide query string.')) raise HttpError(HttpResponseBadRequest(
'Must provide query string.'))
source = Source(query, name='GraphQL request')
try: try:
document_ast = parse(source) backend = self.get_backend(request)
validation_errors = validate(self.schema, document_ast) document = backend.document_from_string(self.schema, query)
if validation_errors:
return ExecutionResult(
errors=validation_errors,
invalid=True,
)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)
if request.method.lower() == 'get': if request.method.lower() == 'get':
operation_ast = get_operation_ast(document_ast, operation_name) operation_type = document.get_operation_type(operation_name)
if operation_ast and operation_ast.operation != 'query': if operation_type and operation_type != 'query':
if show_graphiql: if show_graphiql:
return None return None
raise HttpError(HttpResponseNotAllowed( raise HttpError(HttpResponseNotAllowed(
['POST'], 'Can only perform a {} operation from a POST request.'.format(operation_ast.operation) ['POST'], 'Can only perform a {} operation from a POST request.'.format(
operation_type)
)) ))
try: try:
return self.execute( extra_options = {}
document_ast, if self.executor:
root_value=self.get_root_value(request), # We only include it optionally since
variable_values=variables, # executor is not a valid argument in all backends
extra_options['executor'] = self.executor
return document.execute(
root=self.get_root_value(request),
variables=variables,
operation_name=operation_name, operation_name=operation_name,
context_value=self.get_context(request), context=self.get_context(request),
middleware=self.get_middleware(request), middleware=self.get_middleware(request),
executor=self.executor, **extra_options
) )
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e], invalid=True) return ExecutionResult(errors=[e], invalid=True)
@ -269,10 +282,13 @@ class GraphQLView(View):
@classmethod @classmethod
def request_wants_html(cls, request): def request_wants_html(cls, request):
accepted = get_accepted_content_types(request) accepted = get_accepted_content_types(request)
html_index = accepted.count('text/html') accepted_length = len(accepted)
json_index = accepted.count('application/json') # the list will be ordered in preferred first - so we have to make
# sure the most preferred gets the highest number
html_priority = accepted_length - accepted.index('text/html') if 'text/html' in accepted else 0
json_priority = accepted_length - accepted.index('application/json') if 'application/json' in accepted else 0
return html_index > json_index return html_priority > json_priority
@staticmethod @staticmethod
def get_graphql_params(request, data): def get_graphql_params(request, data):
@ -283,10 +299,12 @@ class GraphQLView(View):
if variables and isinstance(variables, six.text_type): if variables and isinstance(variables, six.text_type):
try: try:
variables = json.loads(variables) variables = json.loads(variables)
except: except Exception:
raise HttpError(HttpResponseBadRequest('Variables are invalid JSON.')) raise HttpError(HttpResponseBadRequest(
'Variables are invalid JSON.'))
operation_name = request.GET.get('operationName') or data.get('operationName') operation_name = request.GET.get(
'operationName') or data.get('operationName')
if operation_name == "null": if operation_name == "null":
operation_name = None operation_name = None
@ -302,5 +320,6 @@ class GraphQLView(View):
@staticmethod @staticmethod
def get_content_type(request): def get_content_type(request):
meta = request.META meta = request.META
content_type = meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) content_type = meta.get(
'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
return content_type.split(';', 1)[0].lower() return content_type.split(';', 1)[0].lower()

View File

@ -21,9 +21,10 @@ tests_require = [
'mock', 'mock',
'pytz', 'pytz',
'django-filter', 'django-filter',
'pytest-django==2.9.1', 'pytest-django>=3.2.1',
] + rest_framework_require ] + rest_framework_require
django_version = 'Django>=1.8.0,<2' if sys.version_info[0] < 3 else 'Django>=1.8.0'
setup( setup(
name='graphene-django', name='graphene-django',
version=version, version=version,
@ -57,11 +58,12 @@ setup(
install_requires=[ install_requires=[
'six>=1.10.0', 'six>=1.10.0',
'graphene>=2.0.dev', 'graphene>=2.0.1,<3',
'Django>=1.8.0', 'graphql-core>=2.1rc1',
django_version,
'iso8601', 'iso8601',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
'promise>=2.1.dev', 'promise>=2.1',
], ],
setup_requires=[ setup_requires=[
'pytest-runner', 'pytest-runner',