diff --git a/.travis.yml b/.travis.yml index 1999433..a8375ee 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,9 +11,6 @@ install: pip install -e .[test] pip install psycopg2 # Required for Django postgres fields testing pip install django==$DJANGO_VERSION - if (($(echo "$DJANGO_VERSION <= 1.9" | bc -l))); then # DRF dropped 1.8 and 1.9 support at 3.7.0 - pip install djangorestframework==3.6.4 - fi python setup.py develop elif [ "$TEST_TYPE" = lint ]; then pip install flake8 @@ -38,13 +35,19 @@ env: matrix: fast_finish: true include: + - python: '3.4' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.5' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.6' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.5' + env: TEST_TYPE=build DJANGO_VERSION=2.1 + - python: '3.6' + env: TEST_TYPE=build DJANGO_VERSION=2.1 - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.8 - - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.9 - - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.10 - - python: '2.7' + env: TEST_TYPE=lint + - python: '3.6' env: TEST_TYPE=lint deploy: provider: pypi diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b4a4b70 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016-Present Syrus Akbary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in index 8fffb71..3c3d4f9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ -include README.md +include README.md LICENSE recursive-include graphene_django/templates * diff --git a/README.md b/README.md index 4e0b01d..ef3f40c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ pip install "graphene-django>=2.0" ```python INSTALLED_APPS = ( # ... + 'django.contrib.staticfiles', # Required for GraphiQL 'graphene_django', ) diff --git a/docs/authorization.rst b/docs/authorization.rst index 707dbf6..86ad66a 100644 --- a/docs/authorization.rst +++ b/docs/authorization.rst @@ -20,7 +20,7 @@ Let's use a simple example model. Limiting Field Access --------------------- -This is easy, simply use the ``only_fields`` meta attribute. +To limit fields in a GraphQL query simply use the ``only_fields`` meta attribute. .. code:: python @@ -61,10 +61,11 @@ define a resolve method for that field and return the desired queryset. from .models import Post class Query(ObjectType): - all_posts = DjangoFilterConnectionField(CategoryNode) + all_posts = DjangoFilterConnectionField(PostNode) + + def resolve_all_posts(self, info): + return Post.objects.filter(published=True) - def resolve_all_posts(self, args, info): - return Post.objects.filter(published=True) User-based Queryset Filtering ----------------------------- @@ -79,7 +80,7 @@ with the context argument. from .models import Post class Query(ObjectType): - my_posts = DjangoFilterConnectionField(CategoryNode) + my_posts = DjangoFilterConnectionField(PostNode) def resolve_my_posts(self, info): # context will reference to the Django request @@ -95,7 +96,7 @@ schema is simple. result = schema.execute(query, context_value=request) -Filtering ID-based node access +Filtering ID-based Node Access ------------------------------ In order to add authorization to id-based node access, we need to add a @@ -113,23 +114,25 @@ method to your ``DjangoObjectType``. interfaces = (relay.Node, ) @classmethod - def get_node(cls, id, context, info): + def get_node(cls, info, id): try: post = cls._meta.model.objects.get(id=id) except cls._meta.model.DoesNotExist: return None - if post.published or context.user == post.owner: + if post.published or info.context.user == post.owner: return post return None -Adding login required + +Adding Login Required --------------------- -If you want to use the standard Django LoginRequiredMixin_ you can create your own view, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``: +To restrict users from accessing the GraphQL API page the standard Django LoginRequiredMixin_ can be used to create your own standard Django Class Based View, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``.: .. code:: python - + #views.py + from django.contrib.auth.mixins import LoginRequiredMixin from graphene_django.views import GraphQLView @@ -137,7 +140,9 @@ If you want to use the standard Django LoginRequiredMixin_ you can create your o class PrivateGraphQLView(LoginRequiredMixin, GraphQLView): pass -After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``: +After this, you can use the new ``PrivateGraphQLView`` in the project's URL Configuration file ``url.py``: + +For Django 1.9 and below: .. code:: python @@ -145,5 +150,14 @@ After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``: # some other urls url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), ] + +For Django 2.0 and above: + +.. code:: python + + urlpatterns = [ + # some other urls + path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), + ] .. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin diff --git a/docs/filtering.rst b/docs/filtering.rst index b5ae158..feafd40 100644 --- a/docs/filtering.rst +++ b/docs/filtering.rst @@ -2,9 +2,9 @@ Filtering ========= Graphene integrates with -`django-filter `__ to provide -filtering of results. See the `usage -documentation `__ +`django-filter `__ (2.x for +Python 3 or 1.x for Python 2) to provide filtering of results. See the `usage +documentation `__ for details on the format for ``filter_fields``. This filtering is automatically available when implementing a ``relay.Node``. @@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows: .. code:: bash # You'll need to django-filter - pip install django-filter + pip install django-filter>=2 Note: The techniques below are demoed in the `cookbook example app `__. @@ -26,7 +26,7 @@ Filterable fields The ``filter_fields`` parameter is used to specify the fields which can be filtered upon. The value specified here is passed directly to ``django-filter``, so see the `filtering -documentation `__ +documentation `__ for full details on the range of options available. For example: @@ -127,7 +127,7 @@ create your own ``Filterset`` as follows: all_animals = DjangoFilterConnectionField(AnimalNode, filterset_class=AnimalFilter) -The context argument is passed on as the `request argument `__ +The context argument is passed on as the `request argument `__ in a ``django_filters.FilterSet`` instance. You can use this to customize your filters to be context-dependent. We could modify the ``AnimalFilter`` above to pre-filter animals owned by the authenticated user (set in ``context.user``). @@ -145,4 +145,4 @@ pre-filter animals owned by the authenticated user (set in ``context.user``). @property def qs(self): # The query context can be found in self.request. - return super(AnimalFilter, self).filter(owner=self.request.user) + return super(AnimalFilter, self).qs.filter(owner=self.request.user) diff --git a/docs/form-mutations.rst b/docs/form-mutations.rst new file mode 100644 index 0000000..e721a78 --- /dev/null +++ b/docs/form-mutations.rst @@ -0,0 +1,68 @@ +Integration with Django forms +============================= + +Graphene-Django comes with mutation classes that will convert the fields on Django forms into inputs on a mutation. +*Note: the API is experimental and will likely change in the future.* + +FormMutation +------------ + +.. code:: python + + class MyForm(forms.Form): + name = forms.CharField() + + class MyMutation(FormMutation): + class Meta: + form_class = MyForm + +``MyMutation`` will automatically receive an ``input`` argument. This argument should be a ``dict`` where the key is ``name`` and the value is a string. + +ModelFormMutation +----------------- + +``ModelFormMutation`` will pull the fields from a ``ModelForm``. + +.. code:: python + + class Pet(models.Model): + name = models.CharField() + + class PetForm(forms.ModelForm): + class Meta: + model = Pet + fields = ('name',) + + # This will get returned when the mutation completes successfully + class PetType(DjangoObjectType): + class Meta: + model = Pet + + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + +``PetMutation`` will grab the fields from ``PetForm`` and turn them into inputs. If the form is valid then the mutation +will lookup the ``DjangoObjectType`` for the ``Pet`` model and return that under the key ``pet``. Otherwise it will +return a list of errors. + +You can change the input name (default is ``input``) and the return field name (default is the model name lowercase). + +.. code:: python + + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + input_field_name = 'data' + return_field_name = 'my_pet' + +Form validation +--------------- + +Form mutations will call ``is_valid()`` on your forms. + +If the form is valid then ``form_valid(form, info)`` is called on the mutation. Override this method to change how +the form is saved or to return a different Graphene object type. + +If the form is *not* valid then a list of errors will be returned. These errors have two fields: ``field``, a string +containing the name of the invalid form field, and ``messages``, a list of strings with the validation messages. diff --git a/docs/index.rst b/docs/index.rst index 256da68..7c64ae7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,4 +12,5 @@ Contents: authorization debug rest-framework + form-mutations introspection diff --git a/docs/requirements.txt b/docs/requirements.txt index 2548604..220b7cf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,3 @@ sphinx # Docs template -https://github.com/graphql-python/graphene-python.org/archive/docs.zip +http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/docs/rest-framework.rst b/docs/rest-framework.rst index 5e5dd70..ce666de 100644 --- a/docs/rest-framework.rst +++ b/docs/rest-framework.rst @@ -19,3 +19,46 @@ You can create a Mutation based on a serializer by using the class Meta: serializer_class = MySerializer +Create/Update Operations +--------------------- + +By default ModelSerializers accept create and update operations. To +customize this use the `model_operations` attribute. The update +operation looks up models by the primary key by default. You can +customize the look up with the lookup attribute. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ['create', 'update'] + lookup_field = 'id' + +Overriding Update Queries +------------------------- + +Use the method `get_serializer_kwargs` to override how +updates are applied. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + if 'id' in input: + instance = Post.objects.filter(id=input['id'], owner=info.context.user).first() + if instance: + return {'instance': instance, 'data': input, 'partial': True} + + else: + raise http.Http404 + + return {'data': input, 'partial': True} diff --git a/docs/tutorial-plain.rst b/docs/tutorial-plain.rst index cf877eb..a87b011 100644 --- a/docs/tutorial-plain.rst +++ b/docs/tutorial-plain.rst @@ -68,8 +68,8 @@ Let's get started with these models: class Ingredient(models.Model): name = models.CharField(max_length=100) notes = models.TextField() - category = models.ForeignKey(Category, related_name='ingredients', - on_delete=models.CASCADE) + category = models.ForeignKey( + Category, related_name='ingredients', on_delete=models.CASCADE) def __str__(self): return self.name @@ -84,6 +84,7 @@ Add ingredients as INSTALLED_APPS: 'cookbook.ingredients', ] + Don't forget to create & run migrations: .. code:: bash @@ -112,6 +113,18 @@ Alternatively you can use the Django admin interface to create some data yourself. You'll need to run the development server (see below), and create a login for yourself too (``./manage.py createsuperuser``). +Register models with admin panel: + +.. code:: python + + # cookbook/ingredients/admin.py + from django.contrib import admin + from cookbook.ingredients.models import Category, Ingredient + + admin.site.register(Category) + admin.site.register(Ingredient) + + Hello GraphQL - Schema and Object Types --------------------------------------- @@ -166,9 +179,9 @@ Create ``cookbook/ingredients/schema.py`` and type the following: return Ingredient.objects.select_related('category').all() -Note that the above ``Query`` class is marked as 'abstract'. This is -because we will now create a project-level query which will combine all -our app-level queries. +Note that the above ``Query`` class is a mixin, inheriting from +``object``. This is because we will now create a project-level query +class which will combine all our app-level mixins. Create the parent project-level ``cookbook/schema.py``: diff --git a/docs/tutorial-relay.rst b/docs/tutorial-relay.rst index 3ac4cec..f2502d7 100644 --- a/docs/tutorial-relay.rst +++ b/docs/tutorial-relay.rst @@ -10,7 +10,7 @@ app `__ -* `GraphQL Relay Specification `__ +* `GraphQL Relay Specification `__ Setup the Django project ------------------------ @@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following: .. code:: python # cookbook/ingredients/schema.py - from graphene import relay, ObjectType, AbstractType + from graphene import relay, ObjectType from graphene_django import DjangoObjectType from graphene_django.filter import DjangoFilterConnectionField @@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following: interfaces = (relay.Node, ) - class Query(AbstractType): + class Query(object): category = relay.Node.Field(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode) diff --git a/examples/cookbook-plain/README.md b/examples/cookbook-plain/README.md index 4075082..0ec906b 100644 --- a/examples/cookbook-plain/README.md +++ b/examples/cookbook-plain/README.md @@ -3,7 +3,7 @@ Cookbook Example Django Project This example project demos integration between Graphene and Django. The project contains two apps, one named `ingredients` and another -named `recepies`. +named `recipes`. Getting started --------------- diff --git a/examples/cookbook-plain/cookbook/ingredients/migrations/0003_auto_20181018_1746.py b/examples/cookbook-plain/cookbook/ingredients/migrations/0003_auto_20181018_1746.py new file mode 100644 index 0000000..184e79e --- /dev/null +++ b/examples/cookbook-plain/cookbook/ingredients/migrations/0003_auto_20181018_1746.py @@ -0,0 +1,17 @@ +# Generated by Django 2.0 on 2018-10-18 17:46 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('ingredients', '0002_auto_20161104_0050'), + ] + + operations = [ + migrations.AlterModelOptions( + name='category', + options={'verbose_name_plural': 'Categories'}, + ), + ] diff --git a/examples/cookbook-plain/cookbook/ingredients/models.py b/examples/cookbook-plain/cookbook/ingredients/models.py index 2f0eba3..5836949 100644 --- a/examples/cookbook-plain/cookbook/ingredients/models.py +++ b/examples/cookbook-plain/cookbook/ingredients/models.py @@ -2,6 +2,8 @@ from django.db import models class Category(models.Model): + class Meta: + verbose_name_plural = 'Categories' name = models.CharField(max_length=100) def __str__(self): @@ -11,7 +13,7 @@ class Category(models.Model): class Ingredient(models.Model): name = models.CharField(max_length=100) notes = models.TextField(null=True, blank=True) - category = models.ForeignKey(Category, related_name='ingredients') + category = models.ForeignKey(Category, related_name='ingredients', on_delete=models.CASCADE) def __str__(self): return self.name diff --git a/examples/cookbook-plain/cookbook/ingredients/schema.py b/examples/cookbook-plain/cookbook/ingredients/schema.py index 895f216..e7ef688 100644 --- a/examples/cookbook-plain/cookbook/ingredients/schema.py +++ b/examples/cookbook-plain/cookbook/ingredients/schema.py @@ -1,7 +1,7 @@ import graphene from graphene_django.types import DjangoObjectType -from cookbook.ingredients.models import Category, Ingredient +from .models import Category, Ingredient class CategoryType(DjangoObjectType): @@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType): model = Ingredient -class Query(graphene.AbstractType): +class Query(object): category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String()) @@ -25,17 +25,14 @@ class Query(graphene.AbstractType): name=graphene.String()) all_ingredients = graphene.List(IngredientType) - def resolve_all_categories(self, args, context, info): + def resolve_all_categories(self, context): return Category.objects.all() - def resolve_all_ingredients(self, args, context, info): + def resolve_all_ingredients(self, context): # We can easily optimize query count in the resolve method return Ingredient.objects.select_related('category').all() - def resolve_category(self, args, context, info): - id = args.get('id') - name = args.get('name') - + def resolve_category(self, context, id=None, name=None): if id is not None: return Category.objects.get(pk=id) @@ -44,10 +41,7 @@ class Query(graphene.AbstractType): return None - def resolve_ingredient(self, args, context, info): - id = args.get('id') - name = args.get('name') - + def resolve_ingredient(self, context, id=None, name=None): if id is not None: return Ingredient.objects.get(pk=id) diff --git a/examples/cookbook-plain/cookbook/ingredients/tests.py b/examples/cookbook-plain/cookbook/ingredients/tests.py new file mode 100644 index 0000000..4929020 --- /dev/null +++ b/examples/cookbook-plain/cookbook/ingredients/tests.py @@ -0,0 +1,2 @@ + +# Create your tests here. diff --git a/examples/cookbook-plain/cookbook/ingredients/views.py b/examples/cookbook-plain/cookbook/ingredients/views.py new file mode 100644 index 0000000..b8e4ee0 --- /dev/null +++ b/examples/cookbook-plain/cookbook/ingredients/views.py @@ -0,0 +1,2 @@ + +# Create your views here. diff --git a/examples/cookbook-plain/cookbook/recipes/migrations/0003_auto_20181018_1728.py b/examples/cookbook-plain/cookbook/recipes/migrations/0003_auto_20181018_1728.py new file mode 100644 index 0000000..7a8df49 --- /dev/null +++ b/examples/cookbook-plain/cookbook/recipes/migrations/0003_auto_20181018_1728.py @@ -0,0 +1,18 @@ +# Generated by Django 2.0 on 2018-10-18 17:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('recipes', '0002_auto_20161104_0106'), + ] + + operations = [ + migrations.AlterField( + model_name='recipeingredient', + name='unit', + field=models.CharField(choices=[('unit', 'Units'), ('kg', 'Kilograms'), ('l', 'Litres'), ('st', 'Shots')], max_length=20), + ), + ] diff --git a/examples/cookbook-plain/cookbook/recipes/models.py b/examples/cookbook-plain/cookbook/recipes/models.py index e688044..382b88e 100644 --- a/examples/cookbook-plain/cookbook/recipes/models.py +++ b/examples/cookbook-plain/cookbook/recipes/models.py @@ -1,16 +1,18 @@ from django.db import models -from cookbook.ingredients.models import Ingredient +from ..ingredients.models import Ingredient class Recipe(models.Model): title = models.CharField(max_length=100) instructions = models.TextField() + def __str__(self): + return self.title class RecipeIngredient(models.Model): - recipe = models.ForeignKey(Recipe, related_name='amounts') - ingredient = models.ForeignKey(Ingredient, related_name='used_by') + recipe = models.ForeignKey(Recipe, related_name='amounts', on_delete=models.CASCADE) + ingredient = models.ForeignKey(Ingredient, related_name='used_by', on_delete=models.CASCADE) amount = models.FloatField() unit = models.CharField(max_length=20, choices=( ('unit', 'Units'), diff --git a/examples/cookbook-plain/cookbook/recipes/schema.py b/examples/cookbook-plain/cookbook/recipes/schema.py index 8ea1ccd..74692f8 100644 --- a/examples/cookbook-plain/cookbook/recipes/schema.py +++ b/examples/cookbook-plain/cookbook/recipes/schema.py @@ -1,7 +1,7 @@ import graphene from graphene_django.types import DjangoObjectType -from cookbook.recipes.models import Recipe, RecipeIngredient +from .models import Recipe, RecipeIngredient class RecipeType(DjangoObjectType): @@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType): model = RecipeIngredient -class Query(graphene.AbstractType): +class Query(object): recipe = graphene.Field(RecipeType, id=graphene.Int(), title=graphene.String()) @@ -24,10 +24,7 @@ class Query(graphene.AbstractType): id=graphene.Int()) all_recipeingredients = graphene.List(RecipeIngredientType) - def resolve_recipe(self, args, context, info): - id = args.get('id') - title = args.get('title') - + def resolve_recipe(self, context, id=None, title=None): if id is not None: return Recipe.objects.get(pk=id) @@ -36,17 +33,15 @@ class Query(graphene.AbstractType): return None - def resolve_recipeingredient(self, args, context, info): - id = args.get('id') - + def resolve_recipeingredient(self, context, id=None): if id is not None: return RecipeIngredient.objects.get(pk=id) return None - def resolve_all_recipes(self, args, context, info): + def resolve_all_recipes(self, context): return Recipe.objects.all() - def resolve_all_recipeingredients(self, args, context, info): + def resolve_all_recipeingredients(self, context): related = ['recipe', 'ingredient'] return RecipeIngredient.objects.select_related(*related).all() diff --git a/examples/cookbook-plain/cookbook/recipes/tests.py b/examples/cookbook-plain/cookbook/recipes/tests.py new file mode 100644 index 0000000..4929020 --- /dev/null +++ b/examples/cookbook-plain/cookbook/recipes/tests.py @@ -0,0 +1,2 @@ + +# Create your tests here. diff --git a/examples/cookbook-plain/cookbook/recipes/views.py b/examples/cookbook-plain/cookbook/recipes/views.py new file mode 100644 index 0000000..b8e4ee0 --- /dev/null +++ b/examples/cookbook-plain/cookbook/recipes/views.py @@ -0,0 +1,2 @@ + +# Create your views here. diff --git a/examples/cookbook-plain/cookbook/settings.py b/examples/cookbook-plain/cookbook/settings.py index 948292d..d846db4 100644 --- a/examples/cookbook-plain/cookbook/settings.py +++ b/examples/cookbook-plain/cookbook/settings.py @@ -44,13 +44,12 @@ INSTALLED_APPS = [ 'cookbook.recipes.apps.RecipesConfig', ] -MIDDLEWARE_CLASSES = [ +MIDDLEWARE = [ 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', ] diff --git a/examples/cookbook-plain/cookbook/urls.py b/examples/cookbook-plain/cookbook/urls.py index 9f8755b..4f87da0 100644 --- a/examples/cookbook-plain/cookbook/urls.py +++ b/examples/cookbook-plain/cookbook/urls.py @@ -1,10 +1,10 @@ -from django.conf.urls import url +from django.urls import path from django.contrib import admin from graphene_django.views import GraphQLView urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^graphql', GraphQLView.as_view(graphiql=True)), + path('admin/', admin.site.urls), + path('graphql/', GraphQLView.as_view(graphiql=True)), ] diff --git a/examples/cookbook-plain/requirements.txt b/examples/cookbook-plain/requirements.txt index a693bd1..539fd67 100644 --- a/examples/cookbook-plain/requirements.txt +++ b/examples/cookbook-plain/requirements.txt @@ -1,4 +1,4 @@ graphene graphene-django -graphql-core -django==1.9 +graphql-core>=2.1rc1 +django==2.1.2 diff --git a/examples/cookbook/cookbook/ingredients/admin.py b/examples/cookbook/cookbook/ingredients/admin.py index 2b16cdc..b57cbc3 100644 --- a/examples/cookbook/cookbook/ingredients/admin.py +++ b/examples/cookbook/cookbook/ingredients/admin.py @@ -2,9 +2,11 @@ from django.contrib import admin from cookbook.ingredients.models import Category, Ingredient + @admin.register(Ingredient) class IngredientAdmin(admin.ModelAdmin): - list_display = ("id","name","category") - list_editable = ("name","category") - + list_display = ('id', 'name', 'category') + list_editable = ('name', 'category') + + admin.site.register(Category) diff --git a/examples/cookbook/cookbook/ingredients/models.py b/examples/cookbook/cookbook/ingredients/models.py index a072bcf..2f0eba3 100644 --- a/examples/cookbook/cookbook/ingredients/models.py +++ b/examples/cookbook/cookbook/ingredients/models.py @@ -10,7 +10,7 @@ class Category(models.Model): class Ingredient(models.Model): name = models.CharField(max_length=100) - notes = models.TextField(null=True,blank=True) + notes = models.TextField(null=True, blank=True) category = models.ForeignKey(Category, related_name='ingredients') def __str__(self): diff --git a/examples/cookbook/cookbook/ingredients/schema.py b/examples/cookbook/cookbook/ingredients/schema.py index b8b3c12..5ad92e8 100644 --- a/examples/cookbook/cookbook/ingredients/schema.py +++ b/examples/cookbook/cookbook/ingredients/schema.py @@ -1,5 +1,5 @@ from cookbook.ingredients.models import Category, Ingredient -from graphene import AbstractType, Node +from graphene import Node from graphene_django.filter import DjangoFilterConnectionField from graphene_django.types import DjangoObjectType @@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType): } -class Query(AbstractType): +class Query(object): category = Node.Field(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode) diff --git a/examples/cookbook/cookbook/recipes/admin.py b/examples/cookbook/cookbook/recipes/admin.py index 57e0418..10d568f 100644 --- a/examples/cookbook/cookbook/recipes/admin.py +++ b/examples/cookbook/cookbook/recipes/admin.py @@ -2,9 +2,11 @@ from django.contrib import admin from cookbook.recipes.models import Recipe, RecipeIngredient + class RecipeIngredientInline(admin.TabularInline): - model = RecipeIngredient + model = RecipeIngredient + @admin.register(Recipe) class RecipeAdmin(admin.ModelAdmin): - inlines = [RecipeIngredientInline] + inlines = [RecipeIngredientInline] diff --git a/examples/cookbook/cookbook/recipes/models.py b/examples/cookbook/cookbook/recipes/models.py index f666fe8..ca12fac 100644 --- a/examples/cookbook/cookbook/recipes/models.py +++ b/examples/cookbook/cookbook/recipes/models.py @@ -8,6 +8,7 @@ class Recipe(models.Model): instructions = models.TextField() __unicode__ = lambda self: self.title + class RecipeIngredient(models.Model): recipe = models.ForeignKey(Recipe, related_name='amounts') ingredient = models.ForeignKey(Ingredient, related_name='used_by') diff --git a/examples/cookbook/cookbook/recipes/schema.py b/examples/cookbook/cookbook/recipes/schema.py index 35a21de..8018322 100644 --- a/examples/cookbook/cookbook/recipes/schema.py +++ b/examples/cookbook/cookbook/recipes/schema.py @@ -1,5 +1,5 @@ from cookbook.recipes.models import Recipe, RecipeIngredient -from graphene import AbstractType, Node +from graphene import Node from graphene_django.filter import DjangoFilterConnectionField from graphene_django.types import DjangoObjectType @@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType): } -class Query(AbstractType): +class Query(object): recipe = Node.Field(RecipeNode) all_recipes = DjangoFilterConnectionField(RecipeNode) diff --git a/examples/cookbook/cookbook/schema.py b/examples/cookbook/cookbook/schema.py index 910e259..f8606a7 100644 --- a/examples/cookbook/cookbook/schema.py +++ b/examples/cookbook/cookbook/schema.py @@ -5,7 +5,9 @@ import graphene from graphene_django.debug import DjangoDebug -class Query(cookbook.recipes.schema.Query, cookbook.ingredients.schema.Query, graphene.ObjectType): +class Query(cookbook.ingredients.schema.Query, + cookbook.recipes.schema.Query, + graphene.ObjectType): debug = graphene.Field(DjangoDebug, name='__debug') diff --git a/examples/cookbook/cookbook/settings.py b/examples/cookbook/cookbook/settings.py index 1916201..948292d 100644 --- a/examples/cookbook/cookbook/settings.py +++ b/examples/cookbook/cookbook/settings.py @@ -1,3 +1,4 @@ +# flake8: noqa """ Django settings for cookbook project. diff --git a/examples/cookbook/cookbook/urls.py b/examples/cookbook/cookbook/urls.py index 9410ca5..9f8755b 100644 --- a/examples/cookbook/cookbook/urls.py +++ b/examples/cookbook/cookbook/urls.py @@ -3,6 +3,7 @@ from django.contrib import admin from graphene_django.views import GraphQLView + urlpatterns = [ url(r'^admin/', admin.site.urls), url(r'^graphql', GraphQLView.as_view(graphiql=True)), diff --git a/examples/cookbook/requirements.txt b/examples/cookbook/requirements.txt index 66fa629..b2ace1f 100644 --- a/examples/cookbook/requirements.txt +++ b/examples/cookbook/requirements.txt @@ -1,5 +1,5 @@ graphene graphene-django -graphql-core +graphql-core>=2.1rc1 django==1.9 -django-filter==0.11.0 +django-filter>=2 diff --git a/examples/cookbook/setup.cfg b/examples/cookbook/setup.cfg new file mode 100644 index 0000000..8c6a6e8 --- /dev/null +++ b/examples/cookbook/setup.cfg @@ -0,0 +1,2 @@ +[flake8] +exclude=migrations,.git,__pycache__ diff --git a/examples/starwars/models.py b/examples/starwars/models.py index 2f80e27..45741da 100644 --- a/examples/starwars/models.py +++ b/examples/starwars/models.py @@ -5,7 +5,7 @@ from django.db import models class Character(models.Model): name = models.CharField(max_length=50) - ship = models.ForeignKey('Ship', blank=True, null=True, related_name='characters') + ship = models.ForeignKey('Ship', on_delete=models.CASCADE, blank=True, null=True, related_name='characters') def __str__(self): return self.name @@ -13,7 +13,7 @@ class Character(models.Model): class Faction(models.Model): name = models.CharField(max_length=50) - hero = models.ForeignKey(Character) + hero = models.ForeignKey(Character, on_delete=models.CASCADE) def __str__(self): return self.name @@ -21,7 +21,7 @@ class Faction(models.Model): class Ship(models.Model): name = models.CharField(max_length=50) - faction = models.ForeignKey(Faction, related_name='ships') + faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name='ships') def __str__(self): return self.name diff --git a/graphene_django/__init__.py b/graphene_django/__init__.py index 5ba360f..4538cb3 100644 --- a/graphene_django/__init__.py +++ b/graphene_django/__init__.py @@ -1,14 +1,6 @@ -from .types import ( - DjangoObjectType, -) -from .fields import ( - DjangoConnectionField, -) +from .types import DjangoObjectType +from .fields import DjangoConnectionField -__version__ = '2.0.0' +__version__ = "2.2.0" -__all__ = [ - '__version__', - 'DjangoObjectType', - 'DjangoConnectionField' -] +__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"] diff --git a/graphene_django/compat.py b/graphene_django/compat.py index 0269e33..4a51de8 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -5,13 +5,7 @@ class MissingType(object): try: # Postgres fields are only available in Django with psycopg2 installed # and we cannot have psycopg2 on PyPy - from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField + from django.contrib.postgres.fields import (ArrayField, HStoreField, + JSONField, RangeField) except ImportError: - ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 - - -try: - # Postgres fields are only available in Django 1.9+ - from django.contrib.postgres.fields import JSONField -except ImportError: - JSONField = MissingType + ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4 diff --git a/graphene_django/converter.py b/graphene_django/converter.py index f270e46..bf7c26b 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -2,9 +2,22 @@ from django.contrib.gis.db.models import GeometryField from django.db import models from django.utils.encoding import force_text -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - NonNull, String, UUID) -from graphene.types.datetime import DateTime, Time +from graphene import ( + ID, + Boolean, + Dynamic, + Enum, + Field, + Float, + Int, + List, + NonNull, + String, + UUID, + DateTime, + Date, + Time, +) from graphene.types.json import JSONString from graphene.utils.str_converters import to_camel_case, to_const from graphql import assert_valid_name @@ -34,7 +47,7 @@ def get_choices(choices): else: name = convert_choice_name(value) while name in converted_names: - name += '_' + str(len(converted_names)) + name += "_" + str(len(converted_names)) converted_names.append(name) description = help_text yield name, value, description @@ -45,16 +58,15 @@ def convert_django_field_with_choices(field, registry=None): converted = registry.get_converted_field(field) if converted: return converted - choices = getattr(field, 'choices', None) + choices = getattr(field, "choices", None) if choices: meta = field.model._meta - name = to_camel_case('{}_{}'.format(meta.object_name, field.name)) + name = to_camel_case("{}_{}".format(meta.object_name, field.name)) choices = list(get_choices(choices)) named_choices = [(c[0], c[1]) for c in choices] named_choices_descriptions = {c[0]: c[2] for c in choices} class EnumWithDescriptionsType(object): - @property def description(self): return named_choices_descriptions[self.name] @@ -71,8 +83,8 @@ def convert_django_field_with_choices(field, registry=None): @singledispatch def convert_django_field(field, registry=None): raise Exception( - "Don't know how to convert the Django field %s (%s)" % - (field, field.__class__)) + "Don't know how to convert the Django field %s (%s)" % (field, field.__class__) + ) @convert_django_field.register(models.CharField) @@ -82,6 +94,7 @@ def convert_django_field(field, registry=None): @convert_django_field.register(models.URLField) @convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(models.FileField) +@convert_django_field.register(models.FilePathField) @convert_django_field.register(GeometryField) def convert_field_to_string(field, registry=None): return String(description=field.help_text, required=not field.null) @@ -123,9 +136,14 @@ def convert_field_to_float(field, registry=None): return Float(description=field.help_text, required=not field.null) +@convert_django_field.register(models.DateTimeField) +def convert_datetime_to_string(field, registry=None): + return DateTime(description=field.help_text, required=not field.null) + + @convert_django_field.register(models.DateField) def convert_date_to_string(field, registry=None): - return DateTime(description=field.help_text, required=not field.null) + return Date(description=field.help_text, required=not field.null) @convert_django_field.register(models.TimeField) @@ -144,7 +162,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None): # We do this for a bug in Django 1.8, where null attr # is not available in the OneToOneRel instance - null = getattr(field, 'null', True) + null = getattr(field, "null", True) return Field(_type, required=not null) return Dynamic(dynamic_type) @@ -168,6 +186,7 @@ def convert_field_to_list_or_connection(field, registry=None): # defined filter_fields in the DjangoObjectType Meta if _type._meta.filter_fields: from .filter.fields import DjangoFilterConnectionField + return DjangoFilterConnectionField(_type) return DjangoConnectionField(_type) diff --git a/graphene_django/debug/__init__.py b/graphene_django/debug/__init__.py index cd5015e..3e078da 100644 --- a/graphene_django/debug/__init__.py +++ b/graphene_django/debug/__init__.py @@ -1,4 +1,4 @@ from .middleware import DjangoDebugMiddleware from .types import DjangoDebug -__all__ = ['DjangoDebugMiddleware', 'DjangoDebug'] +__all__ = ["DjangoDebugMiddleware", "DjangoDebug"] diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index 2b11f7e..48d471f 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -7,7 +7,6 @@ from .types import DjangoDebug class DjangoDebugContext(object): - def __init__(self): self.debug_promise = None self.promises = [] @@ -38,20 +37,21 @@ class DjangoDebugContext(object): class DjangoDebugMiddleware(object): - def resolve(self, next, root, info, **args): context = info.context - django_debug = getattr(context, 'django_debug', None) + django_debug = getattr(context, "django_debug", None) if not django_debug: if context is None: - raise Exception('DjangoDebug cannot be executed in None contexts') + raise Exception("DjangoDebug cannot be executed in None contexts") try: context.django_debug = DjangoDebugContext() except Exception: - raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format( - context.__class__.__name__ - )) - if info.schema.get_type('DjangoDebug') == info.return_type: + raise Exception( + "DjangoDebug need the context to be writable, context received: {}.".format( + context.__class__.__name__ + ) + ) + if info.schema.get_type("DjangoDebug") == info.return_type: return context.django_debug.get_debug_promise() promise = next(root, info, **args) context.django_debug.add_promise(promise) diff --git a/graphene_django/debug/sql/tracking.py b/graphene_django/debug/sql/tracking.py index 9d14e4b..f96583b 100644 --- a/graphene_django/debug/sql/tracking.py +++ b/graphene_django/debug/sql/tracking.py @@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception): class ThreadLocalState(local): - def __init__(self): self.enabled = True @@ -35,7 +34,7 @@ recording = state.recording # export function def wrap_cursor(connection, panel): - if not hasattr(connection, '_graphene_cursor'): + if not hasattr(connection, "_graphene_cursor"): connection._graphene_cursor = connection.cursor def cursor(): @@ -46,7 +45,7 @@ def wrap_cursor(connection, panel): def unwrap_cursor(connection): - if hasattr(connection, '_graphene_cursor'): + if hasattr(connection, "_graphene_cursor"): previous_cursor = connection._graphene_cursor connection.cursor = previous_cursor del connection._graphene_cursor @@ -87,15 +86,14 @@ class NormalCursorWrapper(object): if not params: return params if isinstance(params, dict): - return dict((key, self._quote_expr(value)) - for key, value in params.items()) + return dict((key, self._quote_expr(value)) for key, value in params.items()) return list(map(self._quote_expr, params)) def _decode(self, param): try: return force_text(param, strings_only=True) except UnicodeDecodeError: - return '(encoded string)' + return "(encoded string)" def _record(self, method, sql, params): start_time = time() @@ -103,45 +101,48 @@ class NormalCursorWrapper(object): return method(sql, params) finally: stop_time = time() - duration = (stop_time - start_time) - _params = '' + duration = stop_time - start_time + _params = "" try: _params = json.dumps(list(map(self._decode, params))) except Exception: pass # object not JSON serializable - alias = getattr(self.db, 'alias', 'default') + alias = getattr(self.db, "alias", "default") conn = self.db.connection - vendor = getattr(conn, 'vendor', 'unknown') + vendor = getattr(conn, "vendor", "unknown") params = { - 'vendor': vendor, - 'alias': alias, - 'sql': self.db.ops.last_executed_query( - self.cursor, sql, self._quote_params(params)), - 'duration': duration, - 'raw_sql': sql, - 'params': _params, - 'start_time': start_time, - 'stop_time': stop_time, - 'is_slow': duration > 10, - 'is_select': sql.lower().strip().startswith('select'), + "vendor": vendor, + "alias": alias, + "sql": self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params) + ), + "duration": duration, + "raw_sql": sql, + "params": _params, + "start_time": start_time, + "stop_time": stop_time, + "is_slow": duration > 10, + "is_select": sql.lower().strip().startswith("select"), } - if vendor == 'postgresql': + if vendor == "postgresql": # If an erroneous query was ran on the connection, it might # be in a state where checking isolation_level raises an # exception. try: iso_level = conn.isolation_level except conn.InternalError: - iso_level = 'unknown' - params.update({ - 'trans_id': self.logger.get_transaction_id(alias), - 'trans_status': conn.get_transaction_status(), - 'iso_level': iso_level, - 'encoding': conn.encoding, - }) + iso_level = "unknown" + params.update( + { + "trans_id": self.logger.get_transaction_id(alias), + "trans_status": conn.get_transaction_status(), + "iso_level": iso_level, + "encoding": conn.encoding, + } + ) _sql = DjangoDebugSQL(**params) # We keep `sql` to maintain backwards compatibility diff --git a/graphene_django/debug/sql/types.py b/graphene_django/debug/sql/types.py index 6ae4d31..850ced4 100644 --- a/graphene_django/debug/sql/types.py +++ b/graphene_django/debug/sql/types.py @@ -2,19 +2,53 @@ from graphene import Boolean, Float, ObjectType, String class DjangoDebugSQL(ObjectType): - vendor = String() - alias = String() - sql = String() - duration = Float() - raw_sql = String() - params = String() - start_time = Float() - stop_time = Float() - is_slow = Boolean() - is_select = Boolean() + class Meta: + description = ( + "Represents a single database query made to a Django managed DB." + ) + + vendor = String( + required=True, + description=( + "The type of database being used (e.g. postrgesql, mysql, sqlite)." + ), + ) + alias = String( + required=True, + description="The Django database alias (e.g. 'default').", + ) + sql = String(description="The actual SQL sent to this database.") + duration = Float( + required=True, + description="Duration of this database query in seconds.", + ) + raw_sql = String( + required=True, + description="The raw SQL of this query, without params.", + ) + params = String( + required=True, + description="JSON encoded database query parameters.", + ) + start_time = Float( + required=True, + description="Start time of this database query.", + ) + stop_time = Float( + required=True, + description="Stop time of this database query.", + ) + is_slow = Boolean( + required=True, + description="Whether this database query took more than 10 seconds.", + ) + is_select = Boolean( + required=True, + description="Whether this database query was a SELECT.", + ) # Postgres - trans_id = String() - trans_status = String() - iso_level = String() - encoding = String() + trans_id = String(description="Postgres transaction ID if available.") + trans_status = String(description="Postgres transaction status if available.") + iso_level = String(description="Postgres isolation level if available.") + encoding = String(description="Postgres connection encoding if available.") diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index 72747b2..f2ef096 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -12,31 +12,31 @@ from ..types import DjangoDebug class context(object): pass + # from examples.starwars_django.models import Character pytestmark = pytest.mark.django_db def test_should_query_field(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_reporter(self, info, **args): return Reporter.objects.first() - query = ''' + query = """ query ReporterQuery { reporter { lastName @@ -47,43 +47,40 @@ def test_should_query_field(): } } } - ''' + """ expected = { - 'reporter': { - 'lastName': 'ABA', + "reporter": {"lastName": "ABA"}, + "__debug": { + "sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}] }, - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) - }] - } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_list(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = graphene.List(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters { lastName @@ -94,45 +91,38 @@ def test_should_query_list(): } } } - ''' + """ expected = { - 'allReporters': [{ - 'lastName': 'ABA', - }, { - 'lastName': 'Griffin', - }], - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.all().query) - }] - } + "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}], + "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_connection(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -147,48 +137,41 @@ def test_should_query_connection(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query def test_should_query_connectionfilter(): from ...filter import DjangoFilterConnectionField - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): - all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name']) + all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"]) s = graphene.String(resolver=lambda *_: "S") - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -203,20 +186,14 @@ def test_should_query_connectionfilter(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query diff --git a/graphene_django/debug/types.py b/graphene_django/debug/types.py index 0d3701d..cda5725 100644 --- a/graphene_django/debug/types.py +++ b/graphene_django/debug/types.py @@ -4,4 +4,10 @@ from .sql.types import DjangoDebugSQL class DjangoDebug(ObjectType): - sql = List(DjangoDebugSQL) + class Meta: + description = "Debugging information for the current query." + + sql = List( + DjangoDebugSQL, + description="Executed SQL queries for this API query.", + ) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index e755b93..1ecce45 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -13,7 +13,6 @@ from .utils import maybe_queryset class DjangoListField(Field): - def __init__(self, _type, *args, **kwargs): super(DjangoListField, self).__init__(List(_type), *args, **kwargs) @@ -30,25 +29,28 @@ class DjangoListField(Field): class DjangoConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): - self.on = kwargs.pop('on', False) + self.on = kwargs.pop("on", False) self.max_limit = kwargs.pop( - 'max_limit', - graphene_settings.RELAY_CONNECTION_MAX_LIMIT + "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT ) self.enforce_first_or_last = kwargs.pop( - 'enforce_first_or_last', - graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST + "enforce_first_or_last", + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, ) super(DjangoConnectionField, self).__init__(*args, **kwargs) @property def type(self): from .types import DjangoObjectType + _type = super(ConnectionField, self).type - assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) + assert issubclass( + _type, DjangoObjectType + ), "DjangoConnectionField only accepts DjangoObjectType types" + assert _type._meta.connection, "The type {} doesn't have a connection".format( + _type.__name__ + ) return _type._meta.connection @property @@ -67,6 +69,10 @@ class DjangoConnectionField(ConnectionField): @classmethod def merge_querysets(cls, default_queryset, queryset): + if default_queryset.query.distinct and not queryset.query.distinct: + queryset = queryset.distinct() + elif queryset.query.distinct and not default_queryset.query.distinct: + default_queryset = default_queryset.distinct() return queryset & default_queryset @classmethod @@ -96,28 +102,37 @@ class DjangoConnectionField(ConnectionField): return connection @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, info, **args): - first = args.get('first') - last = args.get('last') + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **args + ): + first = args.get("first") + last = args.get("last") if enforce_first_or_last: assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + "You must provide a `first` or `last` value to properly paginate the `{}` connection." ).format(info.field_name) if max_limit: if first: assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records." ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) + args["first"] = min(first, max_limit) if last: assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records." ).format(last, info.field_name, max_limit) - args['last'] = min(last, max_limit) + args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args) @@ -134,5 +149,5 @@ class DjangoConnectionField(ConnectionField): self.type, self.get_manager(), self.max_limit, - self.enforce_first_or_last + self.enforce_first_or_last, ) diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py index 24fae60..daafe56 100644 --- a/graphene_django/filter/__init__.py +++ b/graphene_django/filter/__init__.py @@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED if not DJANGO_FILTER_INSTALLED: warnings.warn( "Use of django filtering requires the django-filter package " - "be installed. You can do so using `pip install django-filter`", ImportWarning + "be installed. You can do so using `pip install django-filter`", + ImportWarning, ) else: from .fields import DjangoFilterConnectionField from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter - __all__ = ['DjangoFilterConnectionField', - 'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter'] + __all__ = [ + "DjangoFilterConnectionField", + "GlobalIDFilter", + "GlobalIDMultipleChoiceFilter", + ] diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 06e81c2..cb42543 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class class DjangoFilterConnectionField(DjangoConnectionField): - - def __init__(self, type, fields=None, order_by=None, - extra_filter_meta=None, filterset_class=None, - *args, **kwargs): + def __init__( + self, + type, + fields=None, + order_by=None, + extra_filter_meta=None, + filterset_class=None, + *args, + **kwargs + ): self._fields = fields self._provided_filterset_class = filterset_class self._filterset_class = None @@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField): def filterset_class(self): if not self._filterset_class: fields = self._fields or self.node_type._meta.filter_fields - meta = dict(model=self.model, - fields=fields) + meta = dict(model=self.model, fields=fields) if self._extra_filter_meta: meta.update(self._extra_filter_meta) - self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta) + self._filterset_class = get_filterset_class( + self._provided_filterset_class, **meta + ) return self._filterset_class @@ -52,28 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField): # See related PR: https://github.com/graphql-python/graphene-django/pull/126 - assert not (default_queryset.query.low_mark and queryset.query.low_mark), ( - 'Received two sliced querysets (low mark) in the connection, please slice only in one.' - ) - assert not (default_queryset.query.high_mark and queryset.query.high_mark), ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' - ) + assert not ( + default_queryset.query.low_mark and queryset.query.low_mark + ), "Received two sliced querysets (low mark) in the connection, please slice only in one." + assert not ( + default_queryset.query.high_mark and queryset.query.high_mark + ), "Received two sliced querysets (high mark) in the connection, please slice only in one." low = default_queryset.query.low_mark or queryset.query.low_mark high = default_queryset.query.high_mark or queryset.query.high_mark default_queryset.query.clear_limits() - queryset = super(DjangoFilterConnectionField, cls).merge_querysets(default_queryset, queryset) + queryset = super(DjangoFilterConnectionField, cls).merge_querysets( + default_queryset, queryset + ) queryset.query.set_limits(low, high) return queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, filterset_class, filtering_args, - root, info, **args): + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + filterset_class, + filtering_args, + root, + info, + **args + ): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( data=filter_kwargs, queryset=default_manager.get_queryset(), - request=info.context + request=info.context, ).qs return super(DjangoFilterConnectionField, cls).connection_resolver( @@ -96,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.max_limit, self.enforce_first_or_last, self.filterset_class, - self.filtering_args + self.filtering_args, ) diff --git a/graphene_django/filter/filterset.py b/graphene_django/filter/filterset.py index c716b05..4059083 100644 --- a/graphene_django/filter/filterset.py +++ b/graphene_django/filter/filterset.py @@ -1,8 +1,7 @@ import itertools from django.db import models -from django.utils.text import capfirst -from django_filters import Filter, MultipleChoiceFilter +from django_filters import Filter, MultipleChoiceFilter, VERSION from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS @@ -15,7 +14,10 @@ class GlobalIDFilter(Filter): field_class = GlobalIDFormField def filter(self, qs, value): - _type, _id = from_global_id(value) + """ Convert the filter value to a primary key before filtering """ + _id = None + if value is not None: + _, _id = from_global_id(value) return super(GlobalIDFilter, self).filter(qs, _id) @@ -28,71 +30,76 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): GRAPHENE_FILTER_SET_OVERRIDES = { - models.AutoField: { - 'filter_class': GlobalIDFilter, - }, - models.OneToOneField: { - 'filter_class': GlobalIDFilter, - }, - models.ForeignKey: { - 'filter_class': GlobalIDFilter, - }, - models.ManyToManyField: { - 'filter_class': GlobalIDMultipleChoiceFilter, - } + models.AutoField: {"filter_class": GlobalIDFilter}, + models.OneToOneField: {"filter_class": GlobalIDFilter}, + models.ForeignKey: {"filter_class": GlobalIDFilter}, + models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter}, + models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter}, + models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter}, } class GrapheneFilterSetMixin(BaseFilterSet): - FILTER_DEFAULTS = dict(itertools.chain( - FILTER_FOR_DBFIELD_DEFAULTS.items(), - GRAPHENE_FILTER_SET_OVERRIDES.items() - )) + """ A django_filters.filterset.BaseFilterSet with default filter overrides + to handle global IDs """ - @classmethod - def filter_for_reverse_field(cls, f, name): - """Handles retrieving filters for reverse relationships + FILTER_DEFAULTS = dict( + itertools.chain( + FILTER_FOR_DBFIELD_DEFAULTS.items(), + GRAPHENE_FILTER_SET_OVERRIDES.items() + ) + ) - We override the default implementation so that we can handle - Global IDs (the default implementation expects database - primary keys) - """ - rel = f.field.rel - default = { - 'name': name, - 'label': capfirst(rel.related_name) - } - if rel.multiple: - # For to-many relationships - return GlobalIDMultipleChoiceFilter(**default) - else: - # For to-one relationships - return GlobalIDFilter(**default) + +# To support a Django 1.11 + Python 2.7 combination django-filter must be +# < 2.x.x. To support the earlier version of django-filter, the +# filter_for_reverse_field method must be present on GrapheneFilterSetMixin and +# must not be present for later versions of django-filter. +if VERSION[0] < 2: + from django.utils.text import capfirst + + class GrapheneFilterSetMixinPython2(GrapheneFilterSetMixin): + + @classmethod + def filter_for_reverse_field(cls, f, name): + """Handles retrieving filters for reverse relationships + We override the default implementation so that we can handle + Global IDs (the default implementation expects database + primary keys) + """ + try: + rel = f.field.remote_field + except AttributeError: + rel = f.field.rel + default = {"name": name, "label": capfirst(rel.related_name)} + if rel.multiple: + # For to-many relationships + return GlobalIDMultipleChoiceFilter(**default) + else: + # For to-one relationships + return GlobalIDFilter(**default) + + GrapheneFilterSetMixin = GrapheneFilterSetMixinPython2 def setup_filterset(filterset_class): """ Wrap a provided filterset in Graphene-specific functionality """ return type( - 'Graphene{}'.format(filterset_class.__name__), + "Graphene{}".format(filterset_class.__name__), (filterset_class, GrapheneFilterSetMixin), {}, ) -def custom_filterset_factory(model, filterset_base_class=FilterSet, - **meta): +def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta): """ Create a filterset for the given model using the provided meta data """ - meta.update({ - 'model': model, - }) - meta_class = type(str('Meta'), (object,), meta) + meta.update({"model": model}) + meta_class = type(str("Meta"), (object,), meta) filterset = type( - str('%sFilterSet' % model._meta.object_name), + str("%sFilterSet" % model._meta.object_name), (filterset_base_class, GrapheneFilterSetMixin), - { - 'Meta': meta_class - } + {"Meta": meta_class}, ) return filterset diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py index 4a3fbaa..359d2ba 100644 --- a/graphene_django/filter/tests/filters.py +++ b/graphene_django/filter/tests/filters.py @@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter class ArticleFilter(django_filters.FilterSet): - class Meta: model = Article fields = { - 'headline': ['exact', 'icontains'], - 'pub_date': ['gt', 'lt', 'exact'], - 'reporter': ['exact'], + "headline": ["exact", "icontains"], + "pub_date": ["gt", "lt", "exact"], + "reporter": ["exact"], } - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class ReporterFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['first_name', 'last_name', 'email', 'pets'] + fields = ["first_name", "last_name", "email", "pets"] - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class PetFilter(django_filters.FilterSet): - class Meta: model = Pet - fields = ['name'] + fields = ["name"] diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 258da3e..f9ef0ae 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -5,8 +5,7 @@ import pytest from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.forms import (GlobalIDFormField, - GlobalIDMultipleChoiceField) +from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED @@ -20,36 +19,43 @@ if DJANGO_FILTER_INSTALLED: import django_filters from django_filters import FilterSet, NumberFilter - from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField, - GlobalIDMultipleChoiceFilter) - from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter + from graphene_django.filter import ( + GlobalIDFilter, + DjangoFilterConnectionField, + GlobalIDMultipleChoiceFilter, + ) + from graphene_django.filter.tests.filters import ( + ArticleFilter, + PetFilter, + ReporterFilter, + ) else: - pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible')) + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) pytestmark.append(pytest.mark.django_db) if DJANGO_FILTER_INSTALLED: - class ArticleNode(DjangoObjectType): + class ArticleNode(DjangoObjectType): class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('headline', ) - + interfaces = (Node,) + filter_fields = ("headline",) class ReporterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - + interfaces = (Node,) class PetNode(DjangoObjectType): - class Meta: model = Pet - interfaces = (Node, ) + interfaces = (Node,) # schema = Schema() @@ -59,58 +65,47 @@ def get_args(field): def assert_arguments(field, *arguments): - ignore = ('after', 'before', 'first', 'last', 'order_by') + ignore = ("after", "before", "first", "last", "order_by") args = get_args(field) - actual = [ - name - for name in args - if name not in ignore and not name.startswith('_') - ] - assert set(arguments) == set(actual), \ - 'Expected arguments ({}) did not match actual ({})'.format( - arguments, - actual - ) + actual = [name for name in args if name not in ignore and not name.startswith("_")] + assert set(arguments) == set( + actual + ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual) def assert_orderable(field): args = get_args(field) - assert 'order_by' in args, \ - 'Field cannot be ordered' + assert "order_by" in args, "Field cannot be ordered" def assert_not_orderable(field): args = get_args(field) - assert 'order_by' not in args, \ - 'Field can be ordered' + assert "order_by" not in args, "Field can be ordered" def test_filter_explicit_filterset_arguments(): field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) - assert_arguments(field, - 'headline', 'headline__icontains', - 'pub_date', 'pub_date__gt', 'pub_date__lt', - 'reporter', - ) + assert_arguments( + field, + "headline", + "headline__icontains", + "pub_date", + "pub_date__gt", + "pub_date__lt", + "reporter", + ) def test_filter_shortcut_filterset_arguments_list(): - field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) - assert_arguments(field, - 'pub_date', - 'reporter', - ) + field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"]) + assert_arguments(field, "pub_date", "reporter") def test_filter_shortcut_filterset_arguments_dict(): - field = DjangoFilterConnectionField(ArticleNode, fields={ - 'headline': ['exact', 'icontains'], - 'reporter': ['exact'], - }) - assert_arguments(field, - 'headline', 'headline__icontains', - 'reporter', - ) + field = DjangoFilterConnectionField( + ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]} + ) + assert_arguments(field, "headline", "headline__icontains", "reporter") def test_filter_explicit_filterset_orderable(): @@ -134,15 +129,14 @@ def test_filter_explicit_filterset_not_orderable(): def test_filter_shortcut_filterset_extra_meta(): - field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ - 'exclude': ('headline', ) - }) - assert 'headline' not in field.filterset_class.get_fields() + field = DjangoFilterConnectionField( + ArticleNode, extra_filter_meta={"exclude": ("headline",)} + ) + assert "headline" not in field.filterset_class.get_fields() def test_filter_shortcut_filterset_context(): class ArticleContextFilter(django_filters.FilterSet): - class Meta: model = Article exclude = set() @@ -153,17 +147,31 @@ def test_filter_shortcut_filterset_context(): return qs.filter(reporter=self.request.reporter) class Query(ObjectType): - context_articles = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleContextFilter) + context_articles = DjangoFilterConnectionField( + ArticleNode, filterset_class=ArticleContextFilter + ) - r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') - r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') - Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1, editor=r1) - Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2, editor=r2) + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) class context(object): reporter = r2 - query = ''' + query = """ query { contextArticles { edges { @@ -173,42 +181,39 @@ def test_filter_shortcut_filterset_context(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query, context_value=context()) assert not result.errors - assert len(result.data['contextArticles']['edges']) == 1 - assert result.data['contextArticles']['edges'][0]['node']['headline'] == 'a2' + assert len(result.data["contextArticles"]["edges"]) == 1 + assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2" def test_filter_filterset_information_on_meta(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] field = DjangoFilterConnectionField(ReporterFilterNode) - assert_arguments(field, 'first_name', 'articles') + assert_arguments(field, "first_name", "articles") assert_not_orderable(field) def test_filter_filterset_information_on_meta_related(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ['headline', 'reporter'] + interfaces = (Node,) + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -217,25 +222,23 @@ def test_filter_filterset_information_on_meta_related(): article = Field(ArticleFilterNode) schema = Schema(query=Query) - articles_field = ReporterFilterNode._meta.fields['articles'].get_type() - assert_arguments(articles_field, 'headline', 'reporter') + articles_field = ReporterFilterNode._meta.fields["articles"].get_type() + assert_arguments(articles_field, "headline", "reporter") assert_not_orderable(articles_field) def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: - interfaces = (Node, ) + interfaces = (Node,) model = Article - filter_fields = ['headline', 'reporter'] + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -243,12 +246,22 @@ def test_filter_filterset_related_results(): reporter = Field(ReporterFilterNode) article = Field(ArticleFilterNode) - r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') - r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') - Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1) - Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2) + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + ) - query = ''' + query = """ query { allReporters { edges { @@ -264,123 +277,134 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get back a single article for each reporter - assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 - assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 + assert ( + len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1 + ) + assert ( + len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1 + ) def test_global_id_field_implicit(): - field = DjangoFilterConnectionField(ArticleNode, fields=['id']) + field = DjangoFilterConnectionField(ArticleNode, fields=["id"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_field_explicit(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_filterset_descriptions(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] - max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time") + max_time = django_filters.NumberFilter( + method="filter_max_time", label="The maximum time" + ) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) - max_time = field.args['max_time'] + max_time = field.args["max_time"] assert isinstance(max_time, Argument) assert max_time.type == Float - assert max_time.description == 'The maximum time' + assert max_time.description == "The maximum time" def test_global_id_field_relation(): - field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) + field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['reporter'] + id_filter = filterset_class.base_filters["reporter"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_multiple_field_implicit(): - field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) + field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['pets'] + fields = ["pets"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_implicit_reverse(): - field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) + field = DjangoFilterConnectionField(ReporterNode, fields=["articles"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit_reverse(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['articles'] + fields = ["articles"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = { - 'first_name': ['icontains'] - } + interfaces = (Node,) + filter_fields = {"first_name": ["icontains"]} class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') - r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') - r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') + r1 = Reporter.objects.create( + first_name="A test user", last_name="Last Name", email="test1@test.com" + ) + r2 = Reporter.objects.create( + first_name="Other test user", + last_name="Other Last Name", + email="test2@test.com", + ) + r3 = Reporter.objects.create( + first_name="Random", last_name="RandomLast", email="random@test.com" + ) - query = ''' + query = """ query { allReporters(firstName_Icontains: "test") { edges { @@ -390,12 +414,12 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get two reporters - assert len(result.data['allReporters']['edges']) == 2 + assert len(result.data["allReporters"]["edges"]) == 2 def test_recursive_filter_connection(): @@ -407,77 +431,73 @@ def test_recursive_filter_connection(): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode + assert ( + ReporterFilterNode._meta.fields["child_reporters"].node_type + == ReporterFilterNode + ) def test_should_query_filter_node_limit(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice') + return Reporter.objects.order_by("a_choice") Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.now(), + pub_date_time=datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.now(), + pub_date_time=datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -496,24 +516,23 @@ def test_should_query_filter_node_limit(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjI=', - 'firstName': 'John', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==', - 'lang': 'ES' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjI=", + "firstName": "John", + "articles": { + "edges": [ + {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}} + ] + }, } } - }] + ] } } @@ -524,45 +543,37 @@ def test_should_query_filter_node_limit(): def test_should_query_filter_node_double_limit_raises(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice')[:2] + return Reporter.objects.order_by("a_choice")[:2] Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -573,41 +584,40 @@ def test_should_query_filter_node_double_limit_raises(): } } } - ''' + """ result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' + "Received two sliced querysets (high mark) in the connection, please slice only in one." ) + def test_order_by_is_perserved(): class ReporterType(DjangoObjectType): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) filter_fields = () class Query(ObjectType): - all_reporters = DjangoFilterConnectionField(ReporterType, reverse_order=Boolean()) + all_reporters = DjangoFilterConnectionField( + ReporterType, reverse_order=Boolean() + ) def resolve_all_reporters(self, info, reverse_order=False, **args): - reporters = Reporter.objects.order_by('first_name') + reporters = Reporter.objects.order_by("first_name") if reverse_order: return reporters.reverse() - + return reporters - Reporter.objects.create( - first_name='b', - ) - r = Reporter.objects.create( - first_name='a', - ) + Reporter.objects.create(first_name="b") + r = Reporter.objects.create(first_name="a") schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 1) { edges { @@ -617,23 +627,14 @@ def test_order_by_is_perserved(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'firstName': 'a', - } - }] - } - } + """ + expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected - - reverse_query = ''' + reverse_query = """ query NodeFilteringQuery { allReporters(first: 1, reverseOrder: true) { edges { @@ -643,33 +644,26 @@ def test_order_by_is_perserved(): } } } - ''' + """ - reverse_expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'firstName': 'b', - } - }] - } - } + reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}} reverse_result = schema.execute(reverse_query) assert not reverse_result.errors assert reverse_result.data == reverse_expected + def test_annotation_is_perserved(): class ReporterType(DjangoObjectType): full_name = String() - + def resolve_full_name(instance, info, **args): return instance.full_name class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) filter_fields = () class Query(ObjectType): @@ -677,17 +671,16 @@ def test_annotation_is_perserved(): def resolve_all_reporters(self, info, **args): return Reporter.objects.annotate( - full_name=Concat('first_name', Value(' '), 'last_name', output_field=TextField()) + full_name=Concat( + "first_name", Value(" "), "last_name", output_field=TextField() + ) ) - Reporter.objects.create( - first_name='John', - last_name='Doe', - ) + Reporter.objects.create(first_name="John", last_name="Doe") schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 1) { edges { @@ -697,16 +690,8 @@ def test_annotation_is_perserved(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'fullName': 'John Doe', - } - }] - } - } + """ + expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}} result = schema.execute(query) diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 6b938ce..cfa5621 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -8,7 +8,7 @@ def get_filtering_args_from_filterset(filterset_class, type): a Graphene Field. These arguments will be available to filter against in the GraphQL """ - from ..form_converter import convert_form_field + from ..forms.converter import convert_form_field args = {} for name, filter_field in six.iteritems(filterset_class.base_filters): diff --git a/graphene_django/forms/__init__.py b/graphene_django/forms/__init__.py new file mode 100644 index 0000000..066eec4 --- /dev/null +++ b/graphene_django/forms/__init__.py @@ -0,0 +1 @@ +from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa diff --git a/graphene_django/form_converter.py b/graphene_django/forms/converter.py similarity index 72% rename from graphene_django/form_converter.py rename to graphene_django/forms/converter.py index 195c8c4..87180b2 100644 --- a/graphene_django/form_converter.py +++ b/graphene_django/forms/converter.py @@ -1,24 +1,24 @@ from django import forms -from django.forms.fields import BaseTemporalField +from django.core.exceptions import ImproperlyConfigured -from graphene import ID, Boolean, Float, Int, List, String, UUID +from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from .utils import import_single_dispatch +from ..utils import import_single_dispatch + singledispatch = import_single_dispatch() @singledispatch def convert_form_field(field): - raise Exception( + raise ImproperlyConfigured( "Don't know how to convert the Django form field %s (%s) " - "to Graphene type" % - (field, field.__class__) + "to Graphene type" % (field, field.__class__) ) -@convert_form_field.register(BaseTemporalField) +@convert_form_field.register(forms.fields.BaseTemporalField) @convert_form_field.register(forms.CharField) @convert_form_field.register(forms.EmailField) @convert_form_field.register(forms.SlugField) @@ -63,6 +63,21 @@ def convert_form_field_to_list(field): return List(ID, required=field.required) +@convert_form_field.register(forms.DateField) +def convert_form_field_to_date(field): + return Date(description=field.help_text, required=field.required) + + +@convert_form_field.register(forms.DateTimeField) +def convert_form_field_to_datetime(field): + return DateTime(description=field.help_text, required=field.required) + + +@convert_form_field.register(forms.TimeField) +def convert_form_field_to_time(field): + return Time(description=field.help_text, required=field.required) + + @convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(GlobalIDFormField) def convert_form_field_to_id(field): diff --git a/graphene_django/forms.py b/graphene_django/forms/forms.py similarity index 72% rename from graphene_django/forms.py rename to graphene_django/forms/forms.py index a54f0a5..14e68c8 100644 --- a/graphene_django/forms.py +++ b/graphene_django/forms/forms.py @@ -8,9 +8,7 @@ from graphql_relay import from_global_id class GlobalIDFormField(Field): - default_error_messages = { - 'invalid': _('Invalid ID specified.'), - } + default_error_messages = {"invalid": _("Invalid ID specified.")} def clean(self, value): if not value and not self.required: @@ -19,21 +17,21 @@ class GlobalIDFormField(Field): try: _type, _id = from_global_id(value) except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) try: CharField().clean(_id) CharField().clean(_type) except ValidationError: - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) return value class GlobalIDMultipleChoiceField(MultipleChoiceField): default_error_messages = { - 'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), - 'invalid_list': _('Enter a list of values.'), + "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."), + "invalid_list": _("Enter a list of values."), } def valid_value(self, value): diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py new file mode 100644 index 0000000..63ea089 --- /dev/null +++ b/graphene_django/forms/mutation.py @@ -0,0 +1,192 @@ +# from django import forms +from collections import OrderedDict + +import graphene +from graphene import Field, InputField +from graphene.relay.mutation import ClientIDMutation +from graphene.types.mutation import MutationOptions + +# from graphene.types.inputobjecttype import ( +# InputObjectTypeOptions, +# InputObjectType, +# ) +from graphene.types.utils import yank_fields_from_attrs +from graphene_django.registry import get_global_registry + +from .converter import convert_form_field +from .types import ErrorType + + +def fields_for_form(form, only_fields, exclude_fields): + fields = OrderedDict() + for name, field in form.fields.items(): + is_not_in_only = only_fields and name not in only_fields + is_excluded = ( + name + in exclude_fields # or + # name in already_created_fields + ) + + if is_not_in_only or is_excluded: + continue + + fields[name] = convert_form_field(field) + return fields + + +class BaseDjangoFormMutation(ClientIDMutation): + class Meta: + abstract = True + + @classmethod + def mutate_and_get_payload(cls, root, info, **input): + form = cls.get_form(root, info, **input) + + if form.is_valid(): + return cls.perform_mutate(form, info) + else: + errors = [ + ErrorType(field=key, messages=value) + for key, value in form.errors.items() + ] + + return cls(errors=errors) + + @classmethod + def get_form(cls, root, info, **input): + form_kwargs = cls.get_form_kwargs(root, info, **input) + return cls._meta.form_class(**form_kwargs) + + @classmethod + def get_form_kwargs(cls, root, info, **input): + kwargs = {"data": input} + + pk = input.pop("id", None) + if pk: + instance = cls._meta.model._default_manager.get(pk=pk) + kwargs["instance"] = instance + + return kwargs + + +# class DjangoFormInputObjectTypeOptions(InputObjectTypeOptions): +# form_class = None + + +# class DjangoFormInputObjectType(InputObjectType): +# class Meta: +# abstract = True + +# @classmethod +# def __init_subclass_with_meta__(cls, form_class=None, +# only_fields=(), exclude_fields=(), _meta=None, **options): +# if not _meta: +# _meta = DjangoFormInputObjectTypeOptions(cls) +# assert isinstance(form_class, forms.Form), ( +# 'form_class must be an instance of django.forms.Form' +# ) +# _meta.form_class = form_class +# form = form_class() +# fields = fields_for_form(form, only_fields, exclude_fields) +# super(DjangoFormInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, fields=fields, **options) + + +class DjangoFormMutationOptions(MutationOptions): + form_class = None + + +class DjangoFormMutation(BaseDjangoFormMutation): + class Meta: + abstract = True + + errors = graphene.List(ErrorType) + + @classmethod + def __init_subclass_with_meta__( + cls, form_class=None, only_fields=(), exclude_fields=(), **options + ): + + if not form_class: + raise Exception("form_class is required for DjangoFormMutation") + + form = form_class() + input_fields = fields_for_form(form, only_fields, exclude_fields) + output_fields = fields_for_form(form, only_fields, exclude_fields) + + _meta = DjangoFormMutationOptions(cls) + _meta.form_class = form_class + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options + ) + + @classmethod + def perform_mutate(cls, form, info): + form.save() + return cls(errors=[]) + + +class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions): + model = None + return_field_name = None + + +class DjangoModelFormMutation(BaseDjangoFormMutation): + class Meta: + abstract = True + + errors = graphene.List(ErrorType) + + @classmethod + def __init_subclass_with_meta__( + cls, + form_class=None, + model=None, + return_field_name=None, + only_fields=(), + exclude_fields=(), + **options + ): + + if not form_class: + raise Exception("form_class is required for DjangoModelFormMutation") + + if not model: + model = form_class._meta.model + + if not model: + raise Exception("model is required for DjangoModelFormMutation") + + form = form_class() + input_fields = fields_for_form(form, only_fields, exclude_fields) + if "id" not in exclude_fields: + input_fields["id"] = graphene.ID() + + registry = get_global_registry() + model_type = registry.get_type_for_model(model) + return_field_name = return_field_name + if not return_field_name: + model_name = model.__name__ + return_field_name = model_name[:1].lower() + model_name[1:] + + output_fields = OrderedDict() + output_fields[return_field_name] = graphene.Field(model_type) + + _meta = DjangoModelDjangoFormMutationOptions(cls) + _meta.form_class = form_class + _meta.model = model + _meta.return_field_name = return_field_name + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options + ) + + @classmethod + def perform_mutate(cls, form, info): + obj = form.save() + kwargs = {cls._meta.return_field_name: obj} + return cls(errors=[], **kwargs) diff --git a/graphene_django/forms/tests/__init__.py b/graphene_django/forms/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphene_django/forms/tests/test_converter.py b/graphene_django/forms/tests/test_converter.py new file mode 100644 index 0000000..955b952 --- /dev/null +++ b/graphene_django/forms/tests/test_converter.py @@ -0,0 +1,114 @@ +from django import forms +from py.test import raises + +import graphene +from graphene import ( + String, + Int, + Boolean, + Float, + ID, + UUID, + List, + NonNull, + DateTime, + Date, + Time, +) + +from ..converter import convert_form_field + + +def assert_conversion(django_field, graphene_field, *args): + field = django_field(*args, help_text="Custom Help Text") + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, graphene_field) + field = graphene_type.Field() + assert field.description == "Custom Help Text" + return field + + +def test_should_unknown_django_field_raise_exception(): + with raises(Exception) as excinfo: + convert_form_field(None) + assert "Don't know how to convert the Django form field" in str(excinfo.value) + + +def test_should_date_convert_date(): + assert_conversion(forms.DateField, Date) + + +def test_should_time_convert_time(): + assert_conversion(forms.TimeField, Time) + + +def test_should_date_time_convert_date_time(): + assert_conversion(forms.DateTimeField, DateTime) + + +def test_should_char_convert_string(): + assert_conversion(forms.CharField, String) + + +def test_should_email_convert_string(): + assert_conversion(forms.EmailField, String) + + +def test_should_slug_convert_string(): + assert_conversion(forms.SlugField, String) + + +def test_should_url_convert_string(): + assert_conversion(forms.URLField, String) + + +def test_should_choice_convert_string(): + assert_conversion(forms.ChoiceField, String) + + +def test_should_base_field_convert_string(): + assert_conversion(forms.Field, String) + + +def test_should_regex_convert_string(): + assert_conversion(forms.RegexField, String, "[0-9]+") + + +def test_should_uuid_convert_string(): + if hasattr(forms, "UUIDField"): + assert_conversion(forms.UUIDField, UUID) + + +def test_should_integer_convert_int(): + assert_conversion(forms.IntegerField, Int) + + +def test_should_boolean_convert_boolean(): + field = assert_conversion(forms.BooleanField, Boolean) + assert isinstance(field.type, NonNull) + + +def test_should_nullboolean_convert_boolean(): + field = assert_conversion(forms.NullBooleanField, Boolean) + assert not isinstance(field.type, NonNull) + + +def test_should_float_convert_float(): + assert_conversion(forms.FloatField, Float) + + +def test_should_decimal_convert_float(): + assert_conversion(forms.DecimalField, Float) + + +def test_should_multiple_choice_convert_connectionorlist(): + field = forms.ModelMultipleChoiceField(queryset=None) + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, List) + assert graphene_type.of_type == ID + + +def test_should_manytoone_convert_connectionorlist(): + field = forms.ModelChoiceField(queryset=None) + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, ID) diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py new file mode 100644 index 0000000..df0ffd5 --- /dev/null +++ b/graphene_django/forms/tests/test_mutation.py @@ -0,0 +1,141 @@ +from django import forms +from django.test import TestCase +from py.test import raises + +from graphene_django.tests.models import Pet, Film, FilmDetails +from ..mutation import DjangoFormMutation, DjangoModelFormMutation + + +class MyForm(forms.Form): + text = forms.CharField() + + +class PetForm(forms.ModelForm): + class Meta: + model = Pet + fields = '__all__' + + +def test_needs_form_class(): + with raises(Exception) as exc: + + class MyMutation(DjangoFormMutation): + pass + + assert exc.value.args[0] == "form_class is required for DjangoFormMutation" + + +def test_has_output_fields(): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + + assert "errors" in MyMutation._meta.fields + + +def test_has_input_fields(): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + + assert "text" in MyMutation.Input._meta.fields + + +class ModelFormMutationTests(TestCase): + def test_default_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("pet", PetMutation._meta.fields) + + def test_default_input_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("name", PetMutation.Input._meta.fields) + self.assertIn("client_mutation_id", PetMutation.Input._meta.fields) + self.assertIn("id", PetMutation.Input._meta.fields) + + def test_exclude_fields_input_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + exclude_fields = ['id'] + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("name", PetMutation.Input._meta.fields) + self.assertIn("age", PetMutation.Input._meta.fields) + self.assertIn("client_mutation_id", PetMutation.Input._meta.fields) + self.assertNotIn("id", PetMutation.Input._meta.fields) + + def test_return_field_name_is_camelcased(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + model = FilmDetails + + self.assertEqual(PetMutation._meta.model, FilmDetails) + self.assertEqual(PetMutation._meta.return_field_name, "filmDetails") + + def test_custom_return_field_name(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + model = Film + return_field_name = "animal" + + self.assertEqual(PetMutation._meta.model, Film) + self.assertEqual(PetMutation._meta.return_field_name, "animal") + self.assertIn("animal", PetMutation._meta.fields) + + def test_model_form_mutation_mutate(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + pet = Pet.objects.create(name="Axel", age=10) + + result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia", age=10) + + self.assertEqual(Pet.objects.count(), 1) + pet.refresh_from_db() + self.assertEqual(pet.name, "Mia") + self.assertEqual(result.errors, []) + + def test_model_form_mutation_updates_existing_(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + result = PetMutation.mutate_and_get_payload(None, None, name="Mia", age=10) + + self.assertEqual(Pet.objects.count(), 1) + pet = Pet.objects.get() + self.assertEqual(pet.name, "Mia") + self.assertEqual(pet.age, 10) + self.assertEqual(result.errors, []) + + def test_model_form_mutation_mutate_invalid_form(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + result = PetMutation.mutate_and_get_payload(None, None) + + # A pet was not created + self.assertEqual(Pet.objects.count(), 0) + + + fields_w_error = [e.field for e in result.errors] + self.assertEqual(len(result.errors), 2) + self.assertIn("name", fields_w_error) + self.assertEqual(result.errors[0].messages, ["This field is required."]) + self.assertIn("age", fields_w_error) + self.assertEqual(result.errors[1].messages, ["This field is required."]) diff --git a/graphene_django/forms/types.py b/graphene_django/forms/types.py new file mode 100644 index 0000000..1fe33f3 --- /dev/null +++ b/graphene_django/forms/types.py @@ -0,0 +1,6 @@ +import graphene + + +class ErrorType(graphene.ObjectType): + field = graphene.String() + messages = graphene.List(graphene.String) diff --git a/graphene_django/management/commands/graphql_schema.py b/graphene_django/management/commands/graphql_schema.py index 14ecf0c..4e526ec 100644 --- a/graphene_django/management/commands/graphql_schema.py +++ b/graphene_django/management/commands/graphql_schema.py @@ -7,43 +7,45 @@ from graphene_django.settings import graphene_settings class CommandArguments(BaseCommand): - def add_arguments(self, parser): parser.add_argument( - '--schema', + "--schema", type=str, - dest='schema', + dest="schema", default=graphene_settings.SCHEMA, - help='Django app containing schema to dump, e.g. myproject.core.schema.schema') + help="Django app containing schema to dump, e.g. myproject.core.schema.schema", + ) parser.add_argument( - '--out', + "--out", type=str, - dest='out', + dest="out", default=graphene_settings.SCHEMA_OUTPUT, - help='Output file (default: schema.json)') + help="Output file, --out=- prints to stdout (default: schema.json)", + ) parser.add_argument( - '--indent', + "--indent", type=int, - dest='indent', + dest="indent", default=graphene_settings.SCHEMA_INDENT, - help='Output file indent (default: None)') + help="Output file indent (default: None)", + ) class Command(CommandArguments): - help = 'Dump Graphene schema JSON to file' + help = "Dump Graphene schema JSON to file" can_import_settings = True def save_file(self, out, schema_dict, indent): - with open(out, 'w') as outfile: + with open(out, "w") as outfile: json.dump(schema_dict, outfile, indent=indent) def handle(self, *args, **options): - options_schema = options.get('schema') + options_schema = options.get("schema") if options_schema and type(options_schema) is str: - module_str, schema_name = options_schema.rsplit('.', 1) + module_str, schema_name = options_schema.rsplit(".", 1) mod = importlib.import_module(module_str) schema = getattr(mod, schema_name) @@ -53,16 +55,21 @@ class Command(CommandArguments): else: schema = graphene_settings.SCHEMA - out = options.get('out') or graphene_settings.SCHEMA_OUTPUT + out = options.get("out") or graphene_settings.SCHEMA_OUTPUT if not schema: - raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema') + raise CommandError( + "Specify schema on GRAPHENE.SCHEMA setting or by using --schema" + ) - indent = options.get('indent') - schema_dict = {'data': schema.introspect()} - self.save_file(out, schema_dict, indent) + indent = options.get("indent") + schema_dict = {"data": schema.introspect()} + if out == '-': + self.stdout.write(json.dumps(schema_dict, indent=indent)) + else: + self.save_file(out, schema_dict, indent) - style = getattr(self, 'style', None) - success = getattr(style, 'SUCCESS', lambda x: x) + style = getattr(self, "style", None) + success = getattr(style, "SUCCESS", lambda x: x) - self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out)) + self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out)) diff --git a/graphene_django/registry.py b/graphene_django/registry.py index b28268d..50a8ae5 100644 --- a/graphene_django/registry.py +++ b/graphene_django/registry.py @@ -1,20 +1,21 @@ - class Registry(object): - def __init__(self): self._registry = {} self._field_registry = {} def register(self, cls): from .types import DjangoObjectType + assert issubclass( - cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( - cls.__name__) - assert cls._meta.registry == self, 'Registry for a Model have to match.' + cls, DjangoObjectType + ), 'Only DjangoObjectTypes can be registered, received "{}"'.format( + cls.__name__ + ) + assert cls._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) == cls, ( # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # ) - if not getattr(cls._meta, 'skip_registry', False): + if not getattr(cls._meta, "skip_registry", False): self._registry[cls._meta.model] = cls def get_type_for_model(self, model): diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index a776eab..5e343aa 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,20 +1,21 @@ from collections import OrderedDict +from django.shortcuts import get_object_or_404 + import graphene from graphene.types import Field, InputField from graphene.types.mutation import MutationOptions from graphene.relay.mutation import ClientIDMutation -from graphene.types.objecttype import ( - yank_fields_from_attrs -) +from graphene.types.objecttype import yank_fields_from_attrs -from .serializer_converter import ( - convert_serializer_field -) +from .serializer_converter import convert_serializer_field from .types import ErrorType class SerializerMutationOptions(MutationOptions): + lookup_field = None + model_class = None + model_operations = ["create", "update"] serializer_class = None @@ -23,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals for name, field in serializer.fields.items(): is_not_in_only = only_fields and name not in only_fields is_excluded = ( - name in exclude_fields # or + name + in exclude_fields # or # name in already_created_fields ) @@ -39,37 +41,86 @@ class SerializerMutation(ClientIDMutation): abstract = True errors = graphene.List( - ErrorType, - description='May contain more than one error for same field.' + ErrorType, description="May contain more than one error for same field." ) @classmethod - def __init_subclass_with_meta__(cls, serializer_class=None, - only_fields=(), exclude_fields=(), **options): + def __init_subclass_with_meta__( + cls, + lookup_field=None, + serializer_class=None, + model_class=None, + model_operations=["create", "update"], + only_fields=(), + exclude_fields=(), + **options + ): if not serializer_class: - raise Exception('serializer_class is required for the SerializerMutation') + raise Exception("serializer_class is required for the SerializerMutation") + + if "update" not in model_operations and "create" not in model_operations: + raise Exception('model_operations must contain "create" and/or "update"') serializer = serializer_class() - input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) - output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) + if model_class is None: + serializer_meta = getattr(serializer_class, "Meta", None) + if serializer_meta: + model_class = getattr(serializer_meta, "model", None) + + if lookup_field is None and model_class: + lookup_field = model_class._meta.pk.name + + input_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=True + ) + output_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=False + ) _meta = SerializerMutationOptions(cls) + _meta.lookup_field = lookup_field + _meta.model_operations = model_operations _meta.serializer_class = serializer_class - _meta.fields = yank_fields_from_attrs( - output_fields, - _as=Field, + _meta.model_class = model_class + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(SerializerMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options ) - input_fields = yank_fields_from_attrs( - input_fields, - _as=InputField, - ) - super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + lookup_field = cls._meta.lookup_field + model_class = cls._meta.model_class + + if model_class: + if "update" in cls._meta.model_operations and lookup_field in input: + instance = get_object_or_404( + model_class, **{lookup_field: input[lookup_field]} + ) + elif "create" in cls._meta.model_operations: + instance = None + else: + raise Exception( + 'Invalid update operation. Input parameter "{}" required.'.format( + lookup_field + ) + ) + + return { + "instance": instance, + "data": input, + "context": {"request": info.context}, + } + + return {"data": input, "context": {"request": info.context}} @classmethod def mutate_and_get_payload(cls, root, info, **input): - serializer = cls._meta.serializer_class(data=input) + kwargs = cls.get_serializer_kwargs(root, info, **input) + serializer = cls._meta.serializer_class(**kwargs) if serializer.is_valid(): return cls.perform_mutate(serializer, info) diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index 6a57f5f..9f8e516 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True): graphql_type = get_graphene_type_from_serializer_field(field) args = [] - kwargs = { - 'description': field.help_text, - 'required': is_input and field.required, - } + kwargs = {"description": field.help_text, "required": is_input and field.required} # if it is a tuple or a list it means that we are returning # the graphql type and the child type if isinstance(graphql_type, (list, tuple)): - kwargs['of_type'] = graphql_type[1] + kwargs["of_type"] = graphql_type[1] graphql_type = graphql_type[0] if isinstance(field, serializers.ModelSerializer): @@ -46,6 +43,15 @@ def convert_serializer_field(field, is_input=True): global_registry = get_global_registry() field_model = field.Meta.model args = [global_registry.get_type_for_model(field_model)] + elif isinstance(field, serializers.ListSerializer): + field = field.child + if is_input: + kwargs["of_type"] = convert_serializer_to_input_type(field.__class__) + else: + del kwargs["of_type"] + global_registry = get_global_registry() + field_model = field.Meta.model + args = [global_registry.get_type_for_model(field_model)] return graphql_type(*args, **kwargs) @@ -59,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class): } return type( - '{}Input'.format(serializer.__class__.__name__), + "{}Input".format(serializer.__class__.__name__), (graphene.InputObjectType,), - items + items, ) @@ -75,6 +81,12 @@ def convert_serializer_to_field(field): return graphene.Field +@get_graphene_type_from_serializer_field.register(serializers.ListSerializer) +def convert_list_serializer_to_field(field): + child_type = get_graphene_type_from_serializer_field(field.child) + return (graphene.List, child_type) + + @get_graphene_type_from_serializer_field.register(serializers.IntegerField) def convert_serializer_field_to_int(field): return graphene.Int @@ -92,9 +104,13 @@ def convert_serializer_field_to_float(field): @get_graphene_type_from_serializer_field.register(serializers.DateTimeField) +def convert_serializer_field_to_datetime_time(field): + return graphene.types.datetime.DateTime + + @get_graphene_type_from_serializer_field.register(serializers.DateField) def convert_serializer_field_to_date_time(field): - return graphene.types.datetime.DateTime + return graphene.types.datetime.Date @get_graphene_type_from_serializer_field.register(serializers.TimeField) diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py index 623cf58..6fa4ca8 100644 --- a/graphene_django/rest_framework/tests/test_field_converter.py +++ b/graphene_django/rest_framework/tests/test_field_converter.py @@ -1,8 +1,10 @@ import copy -from rest_framework import serializers -from py.test import raises import graphene +from django.db import models +from graphene import InputObjectType +from py.test import raises +from rest_framework import serializers from ..serializer_converter import convert_serializer_field from ..types import DictType @@ -14,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): # Remove `source=` from the field declaration. # since we are reusing the same child in when testing the required attribute - if 'child' in kwargs: - kwargs['child'] = copy.deepcopy(kwargs['child']) + if "child" in kwargs: + kwargs["child"] = copy.deepcopy(kwargs["child"]) field = rest_framework_field(**kwargs) @@ -23,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): def assert_conversion(rest_framework_field, graphene_field, **kwargs): - graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) + graphene_type = _get_type( + rest_framework_field, help_text="Custom Help Text", **kwargs + ) assert isinstance(graphene_type, graphene_field) graphene_type_required = _get_type( - rest_framework_field, help_text='Custom Help Text', required=True, **kwargs + rest_framework_field, help_text="Custom Help Text", required=True, **kwargs ) assert isinstance(graphene_type_required, graphene_field) @@ -37,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs): def test_should_unknown_rest_framework_field_raise_exception(): with raises(Exception) as excinfo: convert_serializer_field(None) - assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) + assert "Don't know how to convert the serializer field" in str(excinfo.value) def test_should_char_convert_string(): @@ -65,20 +69,19 @@ def test_should_base_field_convert_string(): def test_should_regex_convert_string(): - assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') + assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+") def test_should_uuid_convert_string(): - if hasattr(serializers, 'UUIDField'): + if hasattr(serializers, "UUIDField"): assert_conversion(serializers.UUIDField, graphene.String) def test_should_model_convert_field(): - class MyModelSerializer(serializers.ModelSerializer): class Meta: model = None - fields = '__all__' + fields = "__all__" assert_conversion(MyModelSerializer, graphene.Field, is_input=False) @@ -87,8 +90,8 @@ def test_should_date_time_convert_datetime(): assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) -def test_should_date_convert_datetime(): - assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) +def test_should_date_convert_date(): + assert_conversion(serializers.DateField, graphene.types.datetime.Date) def test_should_time_convert_time(): @@ -108,7 +111,9 @@ def test_should_float_convert_float(): def test_should_decimal_convert_float(): - assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) + assert_conversion( + serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2 + ) def test_should_list_convert_to_list(): @@ -118,7 +123,7 @@ def test_should_list_convert_to_list(): field_a = assert_conversion( serializers.ListField, graphene.List, - child=serializers.IntegerField(min_value=0, max_value=100) + child=serializers.IntegerField(min_value=0, max_value=100), ) assert field_a.of_type == graphene.Int @@ -128,6 +133,34 @@ def test_should_list_convert_to_list(): assert field_b.of_type == graphene.String +def test_should_list_serializer_convert_to_list(): + class FooModel(models.Model): + pass + + class ChildSerializer(serializers.ModelSerializer): + class Meta: + model = FooModel + fields = "__all__" + + class ParentSerializer(serializers.ModelSerializer): + child = ChildSerializer(many=True) + + class Meta: + model = FooModel + fields = "__all__" + + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=True + ) + assert isinstance(converted_type, graphene.List) + + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=False + ) + assert isinstance(converted_type, graphene.List) + assert converted_type.of_type is None + + def test_should_dict_convert_dict(): assert_conversion(serializers.DictField, DictType) @@ -141,7 +174,7 @@ def test_should_file_convert_string(): def test_should_filepath_convert_string(): - assert_conversion(serializers.FilePathField, graphene.String, path='/') + assert_conversion(serializers.FilePathField, graphene.String, path="/") def test_should_ip_convert_string(): @@ -157,6 +190,8 @@ def test_should_json_convert_jsonstring(): def test_should_multiplechoicefield_convert_to_list_of_string(): - field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3]) + field = assert_conversion( + serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3] + ) assert field.of_type == graphene.String diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 491192a..4dccc18 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -1,6 +1,6 @@ import datetime -from graphene import Field +from graphene import Field, ResolveInfo from graphene.types.inputobjecttype import InputObjectType from py.test import raises from py.test import mark @@ -11,10 +11,30 @@ from ..models import MyFakeModel from ..mutation import SerializerMutation +def mock_info(): + return ResolveInfo( + None, + None, + None, + None, + schema=None, + fragments=None, + root_value=None, + operation=None, + variable_values=None, + context=None, + ) + + class MyModelSerializer(serializers.ModelSerializer): class Meta: model = MyFakeModel - fields = '__all__' + fields = "__all__" + + +class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer class MySerializer(serializers.Serializer): @@ -27,10 +47,11 @@ class MySerializer(serializers.Serializer): def test_needs_serializer_class(): with raises(Exception) as exc: + class MyMutation(SerializerMutation): pass - assert str(exc.value) == 'serializer_class is required for the SerializerMutation' + assert str(exc.value) == "serializer_class is required for the SerializerMutation" def test_has_fields(): @@ -38,9 +59,9 @@ def test_has_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation._meta.fields - assert 'model' in MyMutation._meta.fields - assert 'errors' in MyMutation._meta.fields + assert "text" in MyMutation._meta.fields + assert "model" in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields def test_has_input_fields(): @@ -48,25 +69,24 @@ def test_has_input_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation.Input._meta.fields - assert 'model' in MyMutation.Input._meta.fields + assert "text" in MyMutation.Input._meta.fields + assert "model" in MyMutation.Input._meta.fields def test_exclude_fields(): class MyMutation(SerializerMutation): class Meta: serializer_class = MyModelSerializer - exclude_fields = ['created'] + exclude_fields = ["created"] - assert 'cool_name' in MyMutation._meta.fields - assert 'created' not in MyMutation._meta.fields - assert 'errors' in MyMutation._meta.fields - assert 'cool_name' in MyMutation.Input._meta.fields - assert 'created' not in MyMutation.Input._meta.fields + assert "cool_name" in MyMutation._meta.fields + assert "created" not in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields + assert "cool_name" in MyMutation.Input._meta.fields + assert "created" not in MyMutation.Input._meta.fields def test_nested_model(): - class MyFakeModelGrapheneType(DjangoObjectType): class Meta: model = MyFakeModel @@ -75,61 +95,85 @@ def test_nested_model(): class Meta: serializer_class = MySerializer - model_field = MyMutation._meta.fields['model'] + model_field = MyMutation._meta.fields["model"] assert isinstance(model_field, Field) assert model_field.type == MyFakeModelGrapheneType - model_input = MyMutation.Input._meta.fields['model'] + model_input = MyMutation.Input._meta.fields["model"] model_input_type = model_input._type.of_type assert issubclass(model_input_type, InputObjectType) - assert 'cool_name' in model_input_type._meta.fields - assert 'created' in model_input_type._meta.fields + assert "cool_name" in model_input_type._meta.fields + assert "created" in model_input_type._meta.fields def test_mutate_and_get_payload_success(): - class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer - result = MyMutation.mutate_and_get_payload(None, None, **{ - 'text': 'value', - 'model': { - 'cool_name': 'other_value' - } - }) + result = MyMutation.mutate_and_get_payload( + None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}} + ) assert result.errors is None @mark.django_db -def test_model_mutate_and_get_payload_success(): - class MyMutation(SerializerMutation): - class Meta: - serializer_class = MyModelSerializer - - result = MyMutation.mutate_and_get_payload(None, None, **{ - 'cool_name': 'Narf', - }) +def test_model_add_mutate_and_get_payload_success(): + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) assert result.errors is None - assert result.cool_name == 'Narf' + assert result.cool_name == "Narf" assert isinstance(result.created, datetime.datetime) -def test_mutate_and_get_payload_error(): +@mark.django_db +def test_model_update_mutate_and_get_payload_success(): + instance = MyFakeModel.objects.create(cool_name="Narf") + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"} + ) + assert result.errors is None + assert result.cool_name == "New Narf" + + +@mark.django_db +def test_model_invalid_update_mutate_and_get_payload_success(): + class InvalidModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ["update"] + + with raises(Exception) as exc: + result = InvalidModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) + + assert '"id" required' in str(exc.value) + + +def test_mutate_and_get_payload_error(): class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer # missing required fields - result = MyMutation.mutate_and_get_payload(None, None, **{}) + result = MyMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 + def test_model_mutate_and_get_payload_error(): - - class MyMutation(SerializerMutation): - class Meta: - serializer_class = MyModelSerializer - # missing required fields - result = MyMutation.mutate_and_get_payload(None, None, **{}) + result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 + + +def test_invalid_serializer_operations(): + with raises(Exception) as exc: + + class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ["Add"] + + assert "model_operations" in str(exc.value) diff --git a/graphene_django/settings.py b/graphene_django/settings.py index 46d70ee..7cd750a 100644 --- a/graphene_django/settings.py +++ b/graphene_django/settings.py @@ -26,27 +26,22 @@ except ImportError: # Copied shamelessly from Django REST Framework DEFAULTS = { - 'SCHEMA': None, - 'SCHEMA_OUTPUT': 'schema.json', - 'SCHEMA_INDENT': None, - 'MIDDLEWARE': (), + "SCHEMA": None, + "SCHEMA_OUTPUT": "schema.json", + "SCHEMA_INDENT": None, + "MIDDLEWARE": (), # Set to True if the connection fields must have # either the first or last argument - 'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, + "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False, # Max items returned in ConnectionFields / FilterConnectionFields - 'RELAY_CONNECTION_MAX_LIMIT': 100, + "RELAY_CONNECTION_MAX_LIMIT": 100, } if settings.DEBUG: - DEFAULTS['MIDDLEWARE'] += ( - 'graphene_django.debug.DjangoDebugMiddleware', - ) + DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",) # List of settings that may be in string import notation. -IMPORT_STRINGS = ( - 'MIDDLEWARE', - 'SCHEMA', -) +IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA") def perform_import(val, setting_name): @@ -69,12 +64,17 @@ def import_from_string(val, setting_name): """ try: # Nod to tastypie's use of importlib. - parts = val.split('.') - module_path, class_name = '.'.join(parts[:-1]), parts[-1] + parts = val.split(".") + module_path, class_name = ".".join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) except (ImportError, AttributeError) as e: - msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -96,8 +96,8 @@ class GrapheneSettings(object): @property def user_settings(self): - if not hasattr(self, '_user_settings'): - self._user_settings = getattr(settings, 'GRAPHENE', {}) + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "GRAPHENE", {}) return self._user_settings def __getattr__(self, attr): @@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS) def reload_graphene_settings(*args, **kwargs): global graphene_settings - setting, value = kwargs['setting'], kwargs['value'] - if setting == 'GRAPHENE': + setting, value = kwargs["setting"], kwargs["value"] + if setting == "GRAPHENE": graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS) diff --git a/graphene_django/static/graphene_django/graphiql.js b/graphene_django/static/graphene_django/graphiql.js new file mode 100644 index 0000000..2be7e3c --- /dev/null +++ b/graphene_django/static/graphene_django/graphiql.js @@ -0,0 +1,99 @@ +(function() { + + // Parse the cookie value for a CSRF token + var csrftoken; + var cookies = ('; ' + document.cookie).split('; csrftoken='); + if (cookies.length == 2) + csrftoken = cookies.pop().split(';').shift(); + + // Collect the URL parameters + var parameters = {}; + window.location.hash.substr(1).split('&').forEach(function (entry) { + var eq = entry.indexOf('='); + if (eq >= 0) { + parameters[decodeURIComponent(entry.slice(0, eq))] = + decodeURIComponent(entry.slice(eq + 1)); + } + }); + // Produce a Location fragment string from a parameter object. + function locationQuery(params) { + return '#' + Object.keys(params).map(function (key) { + return encodeURIComponent(key) + '=' + + encodeURIComponent(params[key]); + }).join('&'); + } + // Derive a fetch URL from the current URL, sans the GraphQL parameters. + var graphqlParamNames = { + query: true, + variables: true, + operationName: true + }; + var otherParams = {}; + for (var k in parameters) { + if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) { + otherParams[k] = parameters[k]; + } + } + + var fetchURL = locationQuery(otherParams); + + // Defines a GraphQL fetcher using the fetch API. + function graphQLFetcher(graphQLParams) { + var headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json' + }; + if (csrftoken) { + headers['X-CSRFToken'] = csrftoken; + } + return fetch(fetchURL, { + method: 'post', + headers: headers, + body: JSON.stringify(graphQLParams), + credentials: 'include', + }).then(function (response) { + return response.text(); + }).then(function (responseBody) { + try { + return JSON.parse(responseBody); + } catch (error) { + return responseBody; + } + }); + } + // When the query and variables string is edited, update the URL bar so + // that it can be easily shared. + function onEditQuery(newQuery) { + parameters.query = newQuery; + updateURL(); + } + function onEditVariables(newVariables) { + parameters.variables = newVariables; + updateURL(); + } + function onEditOperationName(newOperationName) { + parameters.operationName = newOperationName; + updateURL(); + } + function updateURL() { + history.replaceState(null, null, locationQuery(parameters)); + } + var options = { + fetcher: graphQLFetcher, + onEditQuery: onEditQuery, + onEditVariables: onEditVariables, + onEditOperationName: onEditOperationName, + query: parameters.query, + } + if (parameters.variables) { + options.variables = parameters.variables; + } + if (parameters.operation_name) { + options.operationName = parameters.operation_name; + } + // Render into the body. + ReactDOM.render( + React.createElement(GraphiQL, options), + document.body + ); +})(); diff --git a/graphene_django/templates/graphene/graphiql.html b/graphene_django/templates/graphene/graphiql.html index 949b850..af11274 100644 --- a/graphene_django/templates/graphene/graphiql.html +++ b/graphene_django/templates/graphene/graphiql.html @@ -5,6 +5,7 @@ exploring GraphQL. If you wish to receive JSON, provide the header "Accept: application/json" or add "&raw" to the end of the URL within a browser. --> +{% load static %} @@ -16,108 +17,13 @@ add "&raw" to the end of the URL within a browser. width: 100%; } - - - - - + + + + + - + diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 406d184..4fe546d 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -3,56 +3,103 @@ from __future__ import absolute_import from django.db import models from django.utils.translation import ugettext_lazy as _ -CHOICES = ( - (1, 'this'), - (2, _('that')) -) +CHOICES = ((1, "this"), (2, _("that"))) class Pet(models.Model): name = models.CharField(max_length=30) + age = models.PositiveIntegerField() class FilmDetails(models.Model): location = models.CharField(max_length=30) - film = models.OneToOneField('Film', related_name='details') + film = models.OneToOneField( + "Film", on_delete=models.CASCADE, related_name="details" + ) class Film(models.Model): - reporters = models.ManyToManyField('Reporter', - related_name='films') + genre = models.CharField( + max_length=2, + help_text="Genre", + choices=[("do", "Documentary"), ("ot", "Other")], + default="ot", + ) + reporters = models.ManyToManyField("Reporter", related_name="films") + class DoeReporterManager(models.Manager): def get_queryset(self): return super(DoeReporterManager, self).get_queryset().filter(last_name="Doe") + class Reporter(models.Model): first_name = models.CharField(max_length=30) last_name = models.CharField(max_length=30) email = models.EmailField() - pets = models.ManyToManyField('self') + pets = models.ManyToManyField("self") a_choice = models.CharField(max_length=30, choices=CHOICES) objects = models.Manager() doe_objects = DoeReporterManager() - def __str__(self): # __unicode__ on Python 2 + reporter_type = models.IntegerField( + "Reporter Type", + null=True, + blank=True, + choices=[(1, u"Regular"), (2, u"CNN Reporter")], + ) + + def __str__(self): # __unicode__ on Python 2 return "%s %s" % (self.first_name, self.last_name) + def __init__(self, *args, **kwargs): + """ + Override the init method so that during runtime, Django + can know that this object can be a CNNReporter by casting + it to the proxy model. Otherwise, as far as Django knows, + when a CNNReporter is pulled from the database, it is still + of type Reporter. This was added to test proxy model support. + """ + super(Reporter, self).__init__(*args, **kwargs) + if self.reporter_type == 2: # quick and dirty way without enums + self.__class__ = CNNReporter + + +class CNNReporter(Reporter): + """ + This class is a proxy model for Reporter, used for testing + proxy model support + """ + + class Meta: + proxy = True + class Article(models.Model): headline = models.CharField(max_length=100) pub_date = models.DateField() - reporter = models.ForeignKey(Reporter, related_name='articles') - editor = models.ForeignKey(Reporter, related_name='edited_articles_+') - lang = models.CharField(max_length=2, help_text='Language', choices=[ - ('es', 'Spanish'), - ('en', 'English') - ], default='es') - importance = models.IntegerField('Importance', null=True, blank=True, - choices=[(1, u'Very important'), (2, u'Not as important')]) + pub_date_time = models.DateTimeField() + reporter = models.ForeignKey( + Reporter, on_delete=models.CASCADE, related_name="articles" + ) + editor = models.ForeignKey( + Reporter, on_delete=models.CASCADE, related_name="edited_articles_+" + ) + lang = models.CharField( + max_length=2, + help_text="Language", + choices=[("es", "Spanish"), ("en", "English")], + default="es", + ) + importance = models.IntegerField( + "Importance", + null=True, + blank=True, + choices=[(1, u"Very important"), (2, u"Not as important")], + ) - def __str__(self): # __unicode__ on Python 2 + def __str__(self): # __unicode__ on Python 2 return self.headline class Meta: - ordering = ('headline',) + ordering = ("headline",) diff --git a/graphene_django/tests/schema.py b/graphene_django/tests/schema.py index 3134604..d0d9e47 100644 --- a/graphene_django/tests/schema.py +++ b/graphene_django/tests/schema.py @@ -6,10 +6,9 @@ from .models import Article, Reporter class Character(DjangoObjectType): - class Meta: model = Reporter - interfaces = (relay.Node, ) + interfaces = (relay.Node,) def get_node(self, info, id): pass @@ -20,7 +19,7 @@ class Human(DjangoObjectType): class Meta: model = Article - interfaces = (relay.Node, ) + interfaces = (relay.Node,) def resolve_raises(self, info): raise Exception("This field should raise exception") diff --git a/graphene_django/tests/schema_view.py b/graphene_django/tests/schema_view.py index c750433..9b3bd1e 100644 --- a/graphene_django/tests/schema_view.py +++ b/graphene_django/tests/schema_view.py @@ -12,10 +12,10 @@ class QueryRoot(ObjectType): raise Exception("Throws!") def resolve_request(self, info): - return info.context.GET.get('q') + return info.context.GET.get("q") def resolve_test(self, info, who=None): - return 'Hello %s' % (who or 'World') + return "Hello %s" % (who or "World") class MutationRoot(ObjectType): diff --git a/graphene_django/tests/test_command.py b/graphene_django/tests/test_command.py index caf9f7a..ff6e6e1 100644 --- a/graphene_django/tests/test_command.py +++ b/graphene_django/tests/test_command.py @@ -3,8 +3,8 @@ from mock import patch from six import StringIO -@patch('graphene_django.management.commands.graphql_schema.Command.save_file') +@patch("graphene_django.management.commands.graphql_schema.Command.save_file") def test_generate_file_on_call_graphql_schema(savefile_mock, settings): out = StringIO() - management.call_command('graphql_schema', schema='', stdout=out) + management.call_command("graphql_schema", schema="", stdout=out) assert "Successfully dumped GraphQL schema to schema.json" in out.getvalue() diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index cd366b1..094f593 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -6,7 +6,7 @@ from py.test import raises import graphene from graphene.relay import ConnectionField, Node -from graphene.types.datetime import DateTime, Time +from graphene.types.datetime import DateTime, Date, Time from graphene.types.json import JSONString from ..compat import JSONField, ArrayField, HStoreField, RangeField, MissingType @@ -20,11 +20,11 @@ from .models import Article, Film, FilmDetails, Reporter def assert_conversion(django_field, graphene_field, *args, **kwargs): - field = django_field(help_text='Custom Help Text', null=True, *args, **kwargs) + field = django_field(help_text="Custom Help Text", null=True, *args, **kwargs) graphene_type = convert_django_field(field) assert isinstance(graphene_type, graphene_field) field = graphene_type.Field() - assert field.description == 'Custom Help Text' + assert field.description == "Custom Help Text" nonnull_field = django_field(null=False, *args, **kwargs) if not nonnull_field.null: nonnull_graphene_type = convert_django_field(nonnull_field) @@ -37,11 +37,15 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs): def test_should_unknown_django_field_raise_exception(): with raises(Exception) as excinfo: convert_django_field(None) - assert 'Don\'t know how to convert the Django field' in str(excinfo.value) + assert "Don't know how to convert the Django field" in str(excinfo.value) + + +def test_should_date_time_convert_string(): + assert_conversion(models.DateTimeField, DateTime) def test_should_date_convert_string(): - assert_conversion(models.DateField, DateTime) + assert_conversion(models.DateField, Date) def test_should_time_convert_string(): @@ -84,6 +88,10 @@ def test_should_image_convert_string(): assert_conversion(models.ImageField, graphene.String) +def test_should_url_convert_string(): + assert_conversion(models.FilePathField, graphene.String) + + def test_should_auto_convert_id(): assert_conversion(models.AutoField, graphene.ID, primary_key=True) @@ -126,70 +134,69 @@ def test_should_nullboolean_convert_boolean(): def test_field_with_choices_convert_enum(): - field = models.CharField(help_text='Language', choices=( - ('es', 'Spanish'), - ('en', 'English') - )) + field = models.CharField( + help_text="Language", choices=(("es", "Spanish"), ("en", "English")) + ) class TranslatedModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" graphene_type = convert_django_field_with_choices(field) assert isinstance(graphene_type, graphene.Enum) - assert graphene_type._meta.name == 'TranslatedModelLanguage' - assert graphene_type._meta.enum.__members__['ES'].value == 'es' - assert graphene_type._meta.enum.__members__['ES'].description == 'Spanish' - assert graphene_type._meta.enum.__members__['EN'].value == 'en' - assert graphene_type._meta.enum.__members__['EN'].description == 'English' + assert graphene_type._meta.name == "TranslatedModelLanguage" + assert graphene_type._meta.enum.__members__["ES"].value == "es" + assert graphene_type._meta.enum.__members__["ES"].description == "Spanish" + assert graphene_type._meta.enum.__members__["EN"].value == "en" + assert graphene_type._meta.enum.__members__["EN"].description == "English" def test_field_with_grouped_choices(): - field = models.CharField(help_text='Language', choices=( - ('Europe', ( - ('es', 'Spanish'), - ('en', 'English'), - )), - )) + field = models.CharField( + help_text="Language", + choices=(("Europe", (("es", "Spanish"), ("en", "English"))),), + ) class GroupedChoicesModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) def test_field_with_choices_gettext(): - field = models.CharField(help_text='Language', choices=( - ('es', _('Spanish')), - ('en', _('English')) - )) + field = models.CharField( + help_text="Language", choices=(("es", _("Spanish")), ("en", _("English"))) + ) class TranslatedChoicesModel(models.Model): language = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) def test_field_with_choices_collision(): - field = models.CharField(help_text='Timezone', choices=( - ('Etc/GMT+1+2', 'Fake choice to produce double collision'), - ('Etc/GMT+1', 'Greenwich Mean Time +1'), - ('Etc/GMT-1', 'Greenwich Mean Time -1'), - )) + field = models.CharField( + help_text="Timezone", + choices=( + ("Etc/GMT+1+2", "Fake choice to produce double collision"), + ("Etc/GMT+1", "Greenwich Mean Time +1"), + ("Etc/GMT-1", "Greenwich Mean Time -1"), + ), + ) class CollisionChoicesModel(models.Model): timezone = field class Meta: - app_label = 'test' + app_label = "test" convert_django_field_with_choices(field) @@ -206,11 +213,12 @@ def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist_list(): class A(DjangoObjectType): - class Meta: model = Reporter - graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) + graphene_field = convert_django_field( + Reporter._meta.local_many_to_many[0], A._meta.registry + ) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, graphene.Field) @@ -220,12 +228,13 @@ def test_should_manytomany_convert_connectionorlist_list(): def test_should_manytomany_convert_connectionorlist_connection(): class A(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) - graphene_field = convert_django_field(Reporter._meta.local_many_to_many[0], A._meta.registry) + graphene_field = convert_django_field( + Reporter._meta.local_many_to_many[0], A._meta.registry + ) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, ConnectionField) @@ -233,16 +242,12 @@ def test_should_manytomany_convert_connectionorlist_connection(): def test_should_manytoone_convert_connectionorlist(): - # Django 1.9 uses 'rel', <1.9 uses 'related - related = getattr(Reporter.articles, 'rel', None) or \ - getattr(Reporter.articles, 'related') - class A(DjangoObjectType): - class Meta: model = Article - graphene_field = convert_django_field(related, A._meta.registry) + graphene_field = convert_django_field(Reporter.articles.rel, + A._meta.registry) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, graphene.Field) @@ -251,57 +256,53 @@ def test_should_manytoone_convert_connectionorlist(): def test_should_onetoone_reverse_convert_model(): - # Django 1.9 uses 'rel', <1.9 uses 'related - related = getattr(Film.details, 'rel', None) or \ - getattr(Film.details, 'related') - class A(DjangoObjectType): - class Meta: model = FilmDetails - graphene_field = convert_django_field(related, A._meta.registry) + graphene_field = convert_django_field(Film.details.related, + A._meta.registry) assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, graphene.Field) assert dynamic_field.type == A -@pytest.mark.skipif(ArrayField is MissingType, - reason="ArrayField should exist") +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") def test_should_postgres_array_convert_list(): - field = assert_conversion(ArrayField, graphene.List, models.CharField(max_length=100)) + field = assert_conversion( + ArrayField, graphene.List, models.CharField(max_length=100) + ) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) assert field.type.of_type.of_type == graphene.String -@pytest.mark.skipif(ArrayField is MissingType, - reason="ArrayField should exist") +@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") def test_should_postgres_array_multiple_convert_list(): - field = assert_conversion(ArrayField, graphene.List, ArrayField(models.CharField(max_length=100))) + field = assert_conversion( + ArrayField, graphene.List, ArrayField(models.CharField(max_length=100)) + ) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) assert isinstance(field.type.of_type.of_type, graphene.List) assert field.type.of_type.of_type.of_type == graphene.String -@pytest.mark.skipif(HStoreField is MissingType, - reason="HStoreField should exist") +@pytest.mark.skipif(HStoreField is MissingType, reason="HStoreField should exist") def test_should_postgres_hstore_convert_string(): assert_conversion(HStoreField, JSONString) -@pytest.mark.skipif(JSONField is MissingType, - reason="JSONField should exist") +@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist") def test_should_postgres_json_convert_string(): assert_conversion(JSONField, JSONString) -@pytest.mark.skipif(RangeField is MissingType, - reason="RangeField should exist") +@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist") def test_should_postgres_range_convert_list(): from django.contrib.postgres.fields import IntegerRangeField + field = assert_conversion(IntegerRangeField, graphene.List) assert isinstance(field.type, graphene.NonNull) assert isinstance(field.type.of_type, graphene.List) diff --git a/graphene_django/tests/test_form_converter.py b/graphene_django/tests/test_form_converter.py deleted file mode 100644 index 5a13554..0000000 --- a/graphene_django/tests/test_form_converter.py +++ /dev/null @@ -1,103 +0,0 @@ -from django import forms -from py.test import raises - -import graphene -from graphene import ID, List, NonNull - -from ..form_converter import convert_form_field -from .models import Reporter - - -def assert_conversion(django_field, graphene_field, *args): - field = django_field(*args, help_text='Custom Help Text') - graphene_type = convert_form_field(field) - assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() - assert field.description == 'Custom Help Text' - return field - - -def test_should_unknown_django_field_raise_exception(): - with raises(Exception) as excinfo: - convert_form_field(None) - assert 'Don\'t know how to convert the Django form field' in str(excinfo.value) - - -def test_should_date_convert_string(): - assert_conversion(forms.DateField, graphene.String) - - -def test_should_time_convert_string(): - assert_conversion(forms.TimeField, graphene.String) - - -def test_should_date_time_convert_string(): - assert_conversion(forms.DateTimeField, graphene.String) - - -def test_should_char_convert_string(): - assert_conversion(forms.CharField, graphene.String) - - -def test_should_email_convert_string(): - assert_conversion(forms.EmailField, graphene.String) - - -def test_should_slug_convert_string(): - assert_conversion(forms.SlugField, graphene.String) - - -def test_should_url_convert_string(): - assert_conversion(forms.URLField, graphene.String) - - -def test_should_choice_convert_string(): - assert_conversion(forms.ChoiceField, graphene.String) - - -def test_should_base_field_convert_string(): - assert_conversion(forms.Field, graphene.String) - - -def test_should_regex_convert_string(): - assert_conversion(forms.RegexField, graphene.String, '[0-9]+') - - -def test_should_uuid_convert_string(): - if hasattr(forms, 'UUIDField'): - assert_conversion(forms.UUIDField, graphene.UUID) - - -def test_should_integer_convert_int(): - assert_conversion(forms.IntegerField, graphene.Int) - - -def test_should_boolean_convert_boolean(): - field = assert_conversion(forms.BooleanField, graphene.Boolean) - assert isinstance(field.type, NonNull) - - -def test_should_nullboolean_convert_boolean(): - field = assert_conversion(forms.NullBooleanField, graphene.Boolean) - assert not isinstance(field.type, NonNull) - - -def test_should_float_convert_float(): - assert_conversion(forms.FloatField, graphene.Float) - - -def test_should_decimal_convert_float(): - assert_conversion(forms.DecimalField, graphene.Float) - - -def test_should_multiple_choice_convert_connectionorlist(): - field = forms.ModelMultipleChoiceField(Reporter.objects.all()) - graphene_type = convert_form_field(field) - assert isinstance(graphene_type, List) - assert graphene_type.of_type == ID - - -def test_should_manytoone_convert_connectionorlist(): - field = forms.ModelChoiceField(Reporter.objects.all()) - graphene_type = convert_form_field(field) - assert isinstance(graphene_type, graphene.ID) diff --git a/graphene_django/tests/test_forms.py b/graphene_django/tests/test_forms.py index b15e866..fa6628d 100644 --- a/graphene_django/tests/test_forms.py +++ b/graphene_django/tests/test_forms.py @@ -1,7 +1,7 @@ from django.core.exceptions import ValidationError from py.test import raises -from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField +from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' @@ -9,24 +9,24 @@ from ..forms import GlobalIDFormField,GlobalIDMultipleChoiceField def test_global_id_valid(): field = GlobalIDFormField() - field.clean('TXlUeXBlOmFiYw==') + field.clean("TXlUeXBlOmFiYw==") def test_global_id_invalid(): field = GlobalIDFormField() with raises(ValidationError): - field.clean('badvalue') + field.clean("badvalue") def test_global_id_multiple_valid(): field = GlobalIDMultipleChoiceField() - field.clean(['TXlUeXBlOmFiYw==', 'TXlUeXBlOmFiYw==']) + field.clean(["TXlUeXBlOmFiYw==", "TXlUeXBlOmFiYw=="]) def test_global_id_multiple_invalid(): field = GlobalIDMultipleChoiceField() with raises(ValidationError): - field.clean(['badvalue', 'another bad avue']) + field.clean(["badvalue", "another bad avue"]) def test_global_id_none(): diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index c4c26f5..1716034 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -5,6 +5,8 @@ from django.db import models from django.utils.functional import SimpleLazyObject from py.test import raises +from django.db.models import Q + import graphene from graphene.relay import Node @@ -13,35 +15,34 @@ from ..compat import MissingType, JSONField from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..settings import graphene_settings -from .models import Article, Reporter +from .models import Article, CNNReporter, Reporter, Film, FilmDetails pytestmark = pytest.mark.django_db def test_should_query_only_fields(): with raises(Exception): - class ReporterType(DjangoObjectType): + class ReporterType(DjangoObjectType): class Meta: model = Reporter - only_fields = ('articles', ) + only_fields = ("articles",) schema = graphene.Schema(query=ReporterType) - query = ''' + query = """ query ReporterQuery { articles } - ''' + """ result = schema.execute(query) assert not result.errors def test_should_query_simplelazy_objects(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - only_fields = ('id', ) + only_fields = ("id",) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) @@ -50,25 +51,20 @@ def test_should_query_simplelazy_objects(): return SimpleLazyObject(lambda: Reporter(id=1)) schema = graphene.Schema(query=Query) - query = ''' + query = """ query { reporter { id } } - ''' + """ result = schema.execute(query) assert not result.errors - assert result.data == { - 'reporter': { - 'id': '1' - } - } + assert result.data == {"reporter": {"id": "1"}} def test_should_query_well(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter @@ -76,9 +72,9 @@ def test_should_query_well(): reporter = graphene.Field(ReporterType) def resolve_reporter(self, info): - return Reporter(first_name='ABA', last_name='X') + return Reporter(first_name="ABA", last_name="X") - query = ''' + query = """ query ReporterQuery { reporter { firstName, @@ -86,33 +82,30 @@ def test_should_query_well(): email } } - ''' - expected = { - 'reporter': { - 'firstName': 'ABA', - 'lastName': 'X', - 'email': '' - } - } + """ + expected = {"reporter": {"firstName": "ABA", "lastName": "X", "email": ""}} schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors assert result.data == expected -@pytest.mark.skipif(JSONField is MissingType, - reason="RangeField should exist") +@pytest.mark.skipif(JSONField is MissingType, reason="RangeField should exist") def test_should_query_postgres_fields(): - from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField + from django.contrib.postgres.fields import ( + IntegerRangeField, + ArrayField, + JSONField, + HStoreField, + ) class Event(models.Model): - ages = IntegerRangeField(help_text='The age ranges') - data = JSONField(help_text='Data') + ages = IntegerRangeField(help_text="The age ranges") + data = JSONField(help_text="Data") store = HStoreField() tags = ArrayField(models.CharField(max_length=50)) class EventType(DjangoObjectType): - class Meta: model = Event @@ -122,13 +115,13 @@ def test_should_query_postgres_fields(): def resolve_event(self, info): return Event( ages=(0, 10), - data={'angry_babies': True}, - store={'h': 'store'}, - tags=['child', 'angry', 'babies'] + data={"angry_babies": True}, + store={"h": "store"}, + tags=["child", "angry", "babies"], ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query myQuery { event { ages @@ -137,14 +130,14 @@ def test_should_query_postgres_fields(): store } } - ''' + """ expected = { - 'event': { - 'ages': [0, 10], - 'tags': ['child', 'angry', 'babies'], - 'data': '{"angry_babies": true}', - 'store': '{"h": "store"}', - }, + "event": { + "ages": [0, 10], + "tags": ["child", "angry", "babies"], + "data": '{"angry_babies": true}', + "store": '{"h": "store"}', + } } result = schema.execute(query) assert not result.errors @@ -156,27 +149,27 @@ def test_should_node(): # Node._meta.registry = get_global_registry() class ReporterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, info, id): - return Reporter(id=2, first_name='Cookie Monster') + return Reporter(id=2, first_name="Cookie Monster") def resolve_articles(self, info, **args): - return [Article(headline='Hi!')] + return [Article(headline="Hi!")] class ArticleNode(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, info, id): - return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)) + return Article( + id=1, headline="Article node", pub_date=datetime.date(2002, 3, 11) + ) class Query(graphene.ObjectType): node = Node.Field() @@ -184,9 +177,9 @@ def test_should_node(): article = graphene.Field(ArticleNode) def resolve_reporter(self, info): - return Reporter(id=1, first_name='ABA', last_name='X') + return Reporter(id=1, first_name="ABA", last_name="X") - query = ''' + query = """ query ReporterQuery { reporter { id, @@ -212,26 +205,20 @@ def test_should_node(): } } } - ''' + """ expected = { - 'reporter': { - 'id': 'UmVwb3J0ZXJOb2RlOjE=', - 'firstName': 'ABA', - 'lastName': 'X', - 'email': '', - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Hi!' - } - }] - }, + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "ABA", + "lastName": "X", + "email": "", + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "myArticle": { + "id": "QXJ0aWNsZU5vZGU6MQ==", + "headline": "Article node", + "pubDate": "2002-03-11", }, - 'myArticle': { - 'id': 'QXJ0aWNsZU5vZGU6MQ==', - 'headline': 'Article node', - 'pubDate': '2002-03-11', - } } schema = graphene.Schema(query=Query) result = schema.execute(query) @@ -241,11 +228,10 @@ def test_should_node(): def test_should_query_connectionfields(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - only_fields = ('articles', ) + interfaces = (Node,) + only_fields = ("articles",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) @@ -254,7 +240,7 @@ def test_should_query_connectionfields(): return [Reporter(id=1)] schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterConnectionQuery { allReporters { pageInfo { @@ -267,55 +253,48 @@ def test_should_query_connectionfields(): } } } - ''' + """ result = schema.execute(query) assert not result.errors assert result.data == { - 'allReporters': { - 'pageInfo': { - 'hasNextPage': False, - }, - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] + "allReporters": { + "pageInfo": {"hasNextPage": False}, + "edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}], } } def test_should_keep_annotations(): - from django.db.models import ( - Count, - Avg, - ) + from django.db.models import Count, Avg class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - only_fields = ('articles', ) + interfaces = (Node,) + only_fields = ("articles",) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) all_articles = DjangoConnectionField(ArticleType) def resolve_all_reporters(self, info, **args): - return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') + return Reporter.objects.annotate(articles_c=Count("articles")).order_by( + "articles_c" + ) def resolve_all_articles(self, info, **args): - return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') + return Article.objects.annotate(import_avg=Avg("importance")).order_by( + "import_avg" + ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterConnectionQuery { allReporters { pageInfo { @@ -338,53 +317,51 @@ def test_should_keep_annotations(): } } } - ''' + """ result = schema.execute(query) assert not result.errors -@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, - reason="django-filter should be installed") +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) def test_should_query_node_filtering(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -401,22 +378,20 @@ def test_should_query_node_filtering(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [{"node": {"id": "QXJ0aWNsZVR5cGU6MQ=="}}] + }, } } - }] + ] } } @@ -425,55 +400,98 @@ def test_should_query_node_filtering(): assert result.data == expected -@pytest.mark.skipif(not DJANGO_FILTER_INSTALLED, - reason="django-filter should be installed") +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) +def test_should_query_node_filtering_with_distinct_queryset(): + class FilmType(DjangoObjectType): + class Meta: + model = Film + interfaces = (Node,) + filter_fields = ("genre",) + + class Query(graphene.ObjectType): + films = DjangoConnectionField(FilmType) + + # def resolve_all_reporters_with_berlin_films(self, args, context, info): + # return Reporter.objects.filter(Q(films__film__location__contains="Berlin") | Q(a_choice=1)) + + def resolve_films(self, info, **args): + return Film.objects.filter( + Q(details__location__contains="Berlin") | Q(genre__in=["ot"]) + ).distinct() + + f = Film.objects.create() + fd = FilmDetails.objects.create(location="Berlin", film=f) + + schema = graphene.Schema(query=Query) + query = """ + query NodeFilteringQuery { + films { + edges { + node { + genre + } + } + } + } + """ + + expected = {"films": {"edges": [{"node": {"genre": "OT"}}]}} + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +@pytest.mark.skipif( + not DJANGO_FILTER_INSTALLED, reason="django-filter should be installed" +) def test_should_query_node_multiple_filtering(): class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', 'headline') + interfaces = (Node,) + filter_fields = ("lang", "headline") class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 3', + headline="Article Node 3", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -490,22 +508,20 @@ def test_should_query_node_multiple_filtering(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [{"node": {"id": "QXJ0aWNsZVR5cGU6MQ=="}}] + }, } } - }] + ] } } @@ -518,23 +534,19 @@ def test_should_enforce_first_or_last(): graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = True class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters { edges { @@ -544,17 +556,15 @@ def test_should_enforce_first_or_last(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'You must provide a `first` or `last` value to properly ' - 'paginate the `allReporters` connection.' + "You must provide a `first` or `last` value to properly " + "paginate the `allReporters` connection." ) assert result.data == expected @@ -563,23 +573,19 @@ def test_should_error_if_first_is_greater_than_max(): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100 class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(first: 101) { edges { @@ -589,17 +595,15 @@ def test_should_error_if_first_is_greater_than_max(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Requesting 101 records on the `allReporters` connection ' - 'exceeds the `first` limit of 100 records.' + "Requesting 101 records on the `allReporters` connection " + "exceeds the `first` limit of 100 records." ) assert result.data == expected @@ -610,23 +614,19 @@ def test_should_error_if_last_is_greater_than_max(): graphene_settings.RELAY_CONNECTION_MAX_LIMIT = 100 class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(last: 101) { edges { @@ -636,17 +636,15 @@ def test_should_error_if_last_is_greater_than_max(): } } } - ''' + """ - expected = { - 'allReporters': None - } + expected = {"allReporters": None} result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Requesting 101 records on the `allReporters` connection ' - 'exceeds the `last` limit of 100 records.' + "Requesting 101 records on the `allReporters` connection " + "exceeds the `last` limit of 100 records." ) assert result.data == expected @@ -657,19 +655,18 @@ def test_should_query_promise_connectionfields(): from promise import Promise class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) def resolve_all_reporters(self, info, **args): return Promise.resolve([Reporter(id=1)]) - + schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterPromiseConnectionQuery { allReporters(first: 1) { edges { @@ -679,45 +676,34 @@ def test_should_query_promise_connectionfields(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected + def test_should_query_connectionfields_with_last(): r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - + schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterLastQuery { allReporters(last: 1) { edges { @@ -727,52 +713,38 @@ def test_should_query_connectionfields_with_last(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors assert result.data == expected + def test_should_query_connectionfields_with_manager(): r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) r = Reporter.objects.create( - first_name='John', - last_name='NotDoe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="NotDoe", email="johndoe@example.com", a_choice=1 ) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): - all_reporters = DjangoConnectionField(ReporterType, on='doe_objects') + all_reporters = DjangoConnectionField(ReporterType, on="doe_objects") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - + schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterLastQuery { allReporters(first: 2) { edges { @@ -782,17 +754,9 @@ def test_should_query_connectionfields_with_manager(): } } } - ''' + """ - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=' - } - }] - } - } + expected = {"allReporters": {"edges": [{"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}]}} result = schema.execute(query) assert not result.errors @@ -805,24 +769,24 @@ def test_should_query_dataloader_fields(): def article_batch_load_fn(keys): queryset = Article.objects.filter(reporter_id__in=keys) - return Promise.resolve([ - [article for article in queryset if article.reporter_id == id] - for id in keys - ]) + return Promise.resolve( + [ + [article for article in queryset if article.reporter_id == id] + for id in keys + ] + ) article_loader = DataLoader(article_batch_load_fn) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) use_connection = True articles = DjangoConnectionField(ArticleType) @@ -834,28 +798,28 @@ def test_should_query_dataloader_fields(): all_reporters = DjangoConnectionField(ReporterType) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) + Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = graphene.Schema(query=Query) - query = ''' + query = """ query ReporterPromiseConnectionQuery { allReporters(first: 1) { edges { @@ -872,26 +836,23 @@ def test_should_query_dataloader_fields(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjE=', - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Article Node 1', - } - }, { - 'node': { - 'headline': 'Article Node 2' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjE=", + "articles": { + "edges": [ + {"node": {"headline": "Article Node 1"}}, + {"node": {"headline": "Article Node 2"}}, + ] + }, } } - }] + ] } } @@ -902,7 +863,7 @@ def test_should_query_dataloader_fields(): def test_should_handle_inherited_choices(): class BaseModel(models.Model): - choice_field = models.IntegerField(choices=((0, 'zero'), (1, 'one'))) + choice_field = models.IntegerField(choices=((0, "zero"), (1, "one"))) class ChildModel(BaseModel): class Meta: @@ -921,12 +882,128 @@ def test_should_handle_inherited_choices(): child = graphene.Field(ChildType) schema = graphene.Schema(query=Query) - query = ''' + query = """ query { child { choiceField } } - ''' + """ result = schema.execute(query) assert not result.errors + + +def test_proxy_model_support(): + """ + This test asserts that we can query for all Reporters, + even if some are of a proxy model type at runtime. + """ + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + + reporter_1 = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + reporter_2 = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + {"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}, + {"node": {"id": "UmVwb3J0ZXJUeXBlOjI="}}, + ] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_proxy_model_fails(): + """ + This test asserts that if you try to query for a proxy model, + that query will fail with: + GraphQLError('Expected value of type "CNNReporterType" but got: + CNNReporter.',) + + This is because a proxy model has the identical model definition + to its superclass, and defines its behavior at runtime, rather than + at the database level. Currently, filtering objects of the proxy models' + type isn't supported. It would require a field on the model that would + represent the type, and it doesn't seem like there is a clear way to + enforce this pattern across all projects + """ + + class CNNReporterType(DjangoObjectType): + class Meta: + model = CNNReporter + interfaces = (Node,) + use_connection = True + + reporter_1 = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + reporter_2 = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(CNNReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + {"node": {"id": "UmVwb3J0ZXJUeXBlOjE="}}, + {"node": {"id": "UmVwb3J0ZXJUeXBlOjI="}}, + ] + } + } + + result = schema.execute(query) + assert result.errors diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index 32db172..452449b 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -7,47 +7,47 @@ from .models import Reporter def test_should_raise_if_no_model(): with raises(Exception) as excinfo: + class Character1(DjangoObjectType): pass - assert 'valid Django Model' in str(excinfo.value) + + assert "valid Django Model" in str(excinfo.value) def test_should_raise_if_model_is_invalid(): with raises(Exception) as excinfo: - class Character2(DjangoObjectType): + class Character2(DjangoObjectType): class Meta: model = 1 - assert 'valid Django Model' in str(excinfo.value) + + assert "valid Django Model" in str(excinfo.value) def test_should_map_fields_correctly(): class ReporterType2(DjangoObjectType): - class Meta: model = Reporter registry = Registry() + fields = list(ReporterType2._meta.fields.keys()) assert fields[:-2] == [ - 'id', - 'first_name', - 'last_name', - 'email', - 'pets', - 'a_choice', + "id", + "first_name", + "last_name", + "email", + "pets", + "a_choice", + "reporter_type", ] - assert sorted(fields[-2:]) == [ - 'articles', - 'films', - ] + assert sorted(fields[-2:]) == ["articles", "films"] def test_should_map_only_few_fields(): class Reporter2(DjangoObjectType): - class Meta: model = Reporter - only_fields = ('id', 'email') + only_fields = ("id", "email") - assert list(Reporter2._meta.fields.keys()) == ['id', 'email'] + assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index a69870d..8a8643b 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -12,27 +12,30 @@ registry.reset_global_registry() class Reporter(DjangoObjectType): - '''Reporter description''' + """Reporter description""" + class Meta: model = ReporterModel class ArticleConnection(Connection): - '''Article Connection''' + """Article Connection""" + test = String() def resolve_test(): - return 'test' + return "test" class Meta: abstract = True class Article(DjangoObjectType): - '''Article description''' + """Article description""" + class Meta: model = ArticleModel - interfaces = (Node, ) + interfaces = (Node,) connection_class = ArticleConnection @@ -48,7 +51,7 @@ def test_django_interface(): assert issubclass(Node, Node) -@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) +@patch("graphene_django.tests.models.Article.objects.get", return_value=Article(id=1)) def test_django_get_node(get): article = Article.get_node(None, 1) get.assert_called_with(pk=1) @@ -58,27 +61,44 @@ def test_django_get_node(get): def test_django_objecttype_map_correct_fields(): fields = Reporter._meta.fields fields = list(fields.keys()) - assert fields[:-2] == ['id', 'first_name', 'last_name', 'email', 'pets', 'a_choice'] - assert sorted(fields[-2:]) == ['articles', 'films'] + assert fields[:-2] == [ + "id", + "first_name", + "last_name", + "email", + "pets", + "a_choice", + "reporter_type", + ] + assert sorted(fields[-2:]) == ["articles", "films"] def test_django_objecttype_with_node_have_correct_fields(): fields = Article._meta.fields - assert list(fields.keys()) == ['id', 'headline', 'pub_date', 'reporter', 'editor', 'lang', 'importance'] + assert list(fields.keys()) == [ + "id", + "headline", + "pub_date", + "pub_date_time", + "reporter", + "editor", + "lang", + "importance", + ] def test_django_objecttype_with_custom_meta(): class ArticleTypeOptions(DjangoObjectTypeOptions): - '''Article Type Options''' + """Article Type Options""" class ArticleType(DjangoObjectType): class Meta: abstract = True @classmethod - def __init_subclass_with_meta__(cls, _meta=None, **options): - _meta = ArticleTypeOptions(cls) - super(ArticleType, cls).__init_subclass_with_meta__(_meta=_meta, **options) + def __init_subclass_with_meta__(cls, **options): + options.setdefault("_meta", ArticleTypeOptions(cls)) + super(ArticleType, cls).__init_subclass_with_meta__(**options) class Article(ArticleType): class Meta: @@ -96,7 +116,8 @@ schema { type Article implements Node { id: ID! headline: String! - pubDate: DateTime! + pubDate: Date! + pubDateTime: DateTime! reporter: Reporter! editor: Reporter! lang: ArticleLang! @@ -124,6 +145,8 @@ enum ArticleLang { EN } +scalar Date + scalar DateTime interface Node { @@ -144,6 +167,7 @@ type Reporter { email: String! pets: [Reporter] aChoice: ReporterAChoice! + reporterType: ReporterReporterType articles(before: String, after: String, first: Int, last: Int): ArticleConnection } @@ -152,6 +176,11 @@ enum ReporterAChoice { A_2 } +enum ReporterReporterType { + A_1 + A_2 +} + type RootQuery { node(id: ID!): Node } @@ -171,6 +200,7 @@ def with_local_registry(func): else: registry.registry = old return retval + return inner @@ -179,11 +209,10 @@ def test_django_objecttype_only_fields(): class Reporter(DjangoObjectType): class Meta: model = ReporterModel - only_fields = ('id', 'email', 'films') - + only_fields = ("id", "email", "films") fields = list(Reporter._meta.fields.keys()) - assert fields == ['id', 'email', 'films'] + assert fields == ["id", "email", "films"] @with_local_registry @@ -191,8 +220,7 @@ def test_django_objecttype_exclude_fields(): class Reporter(DjangoObjectType): class Meta: model = ReporterModel - exclude_fields = ('email') - + exclude_fields = "email" fields = list(Reporter._meta.fields.keys()) - assert 'email' not in fields + assert "email" not in fields diff --git a/graphene_django/tests/test_views.py b/graphene_django/tests/test_views.py index c31db8d..db6cc4e 100644 --- a/graphene_django/tests/test_views.py +++ b/graphene_django/tests/test_views.py @@ -8,15 +8,15 @@ except ImportError: from urllib.parse import urlencode -def url_string(string='/graphql', **url_params): +def url_string(string="/graphql", **url_params): if url_params: - string += '?' + urlencode(url_params) + string += "?" + urlencode(url_params) return string def batch_url_string(**url_params): - return url_string('/graphql/batch', **url_params) + return url_string("/graphql/batch", **url_params) def response_json(response): @@ -28,405 +28,446 @@ jl = lambda **kwargs: json.dumps([kwargs]) def test_graphiql_is_enabled(client): - response = client.get(url_string(), HTTP_ACCEPT='text/html') + response = client.get(url_string(), HTTP_ACCEPT="text/html") assert response.status_code == 200 + assert response["Content-Type"].split(";")[0] == "text/html" + + +def test_qfactor_graphiql(client): + response = client.get( + url_string(query="{test}"), + HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9", + ) + assert response.status_code == 200 + assert response["Content-Type"].split(";")[0] == "text/html" + + +def test_qfactor_json(client): + response = client.get( + url_string(query="{test}"), + HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9", + ) + assert response.status_code == 200 + assert response["Content-Type"].split(";")[0] == "application/json" + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_get_with_query_param(client): - response = client.get(url_string(query='{test}')) + response = client.get(url_string(query="{test}")) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_get_with_variable_values(client): - response = client.get(url_string( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - )) + response = client.get( + url_string( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_allows_get_with_operation_name(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - operationName='helloWorld' - )) + """, + operationName="helloWorld", + ) + ) assert response.status_code == 200 assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } def test_reports_validation_errors(client): - response = client.get(url_string( - query='{ test, unknownOne, unknownTwo }' - )) + response = client.get(url_string(query="{ test, unknownOne, unknownTwo }")) assert response.status_code == 400 assert response_json(response) == { - 'errors': [ + "errors": [ { - 'message': 'Cannot query field "unknownOne" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 9}] + "message": 'Cannot query field "unknownOne" on type "QueryRoot".', + "locations": [{"line": 1, "column": 9}], }, { - 'message': 'Cannot query field "unknownTwo" on type "QueryRoot".', - 'locations': [{'line': 1, 'column': 21}] - } + "message": 'Cannot query field "unknownTwo" on type "QueryRoot".', + "locations": [{"line": 1, "column": 21}], + }, ] } def test_errors_when_missing_operation_name(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''' - )) + """ + ) + ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [ + "errors": [ { - 'message': 'Must provide operation name if query contains multiple operations.' + "message": "Must provide operation name if query contains multiple operations." } ] } def test_errors_when_sending_a_mutation_via_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" mutation TestMutation { writeTest { test } } - ''' - )) + """ + ) + ) assert response.status_code == 405 assert response_json(response) == { - 'errors': [ - { - 'message': 'Can only perform a mutation operation from a POST request.' - } + "errors": [ + {"message": "Can only perform a mutation operation from a POST request."} ] } def test_errors_when_selecting_a_mutation_within_a_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''', - operationName='TestMutation' - )) + """, + operationName="TestMutation", + ) + ) assert response.status_code == 405 assert response_json(response) == { - 'errors': [ - { - 'message': 'Can only perform a mutation operation from a POST request.' - } + "errors": [ + {"message": "Can only perform a mutation operation from a POST request."} ] } def test_allows_mutation_to_exist_within_a_get(client): - response = client.get(url_string( - query=''' + response = client.get( + url_string( + query=""" query TestQuery { test } mutation TestMutation { writeTest { test } } - ''', - operationName='TestQuery' - )) + """, + operationName="TestQuery", + ) + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_allows_post_with_json_encoding(client): - response = client.post(url_string(), j(query='{test}'), 'application/json') + response = client.post(url_string(), j(query="{test}"), "application/json") assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_batch_allows_post_with_json_encoding(client): - response = client.post(batch_url_string(), jl(id=1, query='{test}'), 'application/json') + response = client.post( + batch_url_string(), jl(id=1, query="{test}"), "application/json" + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello World"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello World"}, "status": 200} + ] def test_batch_fails_if_is_empty(client): - response = client.post(batch_url_string(), '[]', 'application/json') + response = client.post(batch_url_string(), "[]", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Received an empty list in the batch request.'}] + "errors": [{"message": "Received an empty list in the batch request."}] } def test_allows_sending_a_mutation_via_post(client): - response = client.post(url_string(), j(query='mutation TestMutation { writeTest { test } }'), 'application/json') + response = client.post( + url_string(), + j(query="mutation TestMutation { writeTest { test } }"), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'writeTest': {'test': 'Hello World'}} - } + assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}} def test_allows_post_with_url_encoding(client): - response = client.post(url_string(), urlencode(dict(query='{test}')), 'application/x-www-form-urlencoded') + response = client.post( + url_string(), + urlencode(dict(query="{test}")), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello World"} - } + assert response_json(response) == {"data": {"test": "Hello World"}} def test_supports_post_json_query_with_string_variables(client): - response = client.post(url_string(), j( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - ), 'application/json') + response = client.post( + url_string(), + j( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } - + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_batch_supports_post_json_query_with_string_variables(client): - response = client.post(batch_url_string(), jl( - id=1, - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - ), 'application/json') + response = client.post( + batch_url_string(), + jl( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello Dolly"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} + ] def test_supports_post_json_query_with_json_variables(client): - response = client.post(url_string(), j( - query='query helloWho($who: String){ test(who: $who) }', - variables={'who': "Dolly"} - ), 'application/json') + response = client.post( + url_string(), + j( + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_batch_supports_post_json_query_with_json_variables(client): - response = client.post(batch_url_string(), jl( - id=1, - query='query helloWho($who: String){ test(who: $who) }', - variables={'who': "Dolly"} - ), 'application/json') + response = client.post( + batch_url_string(), + jl( + id=1, + query="query helloWho($who: String){ test(who: $who) }", + variables={"who": "Dolly"}, + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': {'test': "Hello Dolly"}, - 'status': 200, - }] + assert response_json(response) == [ + {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} + ] def test_supports_post_url_encoded_query_with_string_variables(client): - response = client.post(url_string(), urlencode(dict( - query='query helloWho($who: String){ test(who: $who) }', - variables=json.dumps({'who': "Dolly"}) - )), 'application/x-www-form-urlencoded') + response = client.post( + url_string(), + urlencode( + dict( + query="query helloWho($who: String){ test(who: $who) }", + variables=json.dumps({"who": "Dolly"}), + ) + ), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_supports_post_json_quey_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), j( - query='query helloWho($who: String){ test(who: $who) }', - ), 'application/json') + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + j(query="query helloWho($who: String){ test(who: $who) }"), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_post_url_encoded_query_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), urlencode(dict( - query='query helloWho($who: String){ test(who: $who) }', - )), 'application/x-www-form-urlencoded') + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")), + "application/x-www-form-urlencoded", + ) assert response.status_code == 200 - assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } + assert response_json(response) == {"data": {"test": "Hello Dolly"}} def test_supports_post_raw_text_query_with_get_variable_values(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), - 'query helloWho($who: String){ test(who: $who) }', - 'application/graphql' + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + "query helloWho($who: String){ test(who: $who) }", + "application/graphql", + ) + + assert response.status_code == 200 + assert response_json(response) == {"data": {"test": "Hello Dolly"}} + + +def test_allows_post_with_operation_name(client): + response = client.post( + url_string(), + j( + query=""" + query helloYou { test(who: "You"), ...shared } + query helloWorld { test(who: "World"), ...shared } + query helloDolly { test(who: "Dolly"), ...shared } + fragment shared on QueryRoot { + shared: test(who: "Everyone") + } + """, + operationName="helloWorld", + ), + "application/json", ) assert response.status_code == 200 assert response_json(response) == { - 'data': {'test': "Hello Dolly"} - } - - -def test_allows_post_with_operation_name(client): - response = client.post(url_string(), j( - query=''' - query helloYou { test(who: "You"), ...shared } - query helloWorld { test(who: "World"), ...shared } - query helloDolly { test(who: "Dolly"), ...shared } - fragment shared on QueryRoot { - shared: test(who: "Everyone") - } - ''', - operationName='helloWorld' - ), 'application/json') - - assert response.status_code == 200 - assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } def test_batch_allows_post_with_operation_name(client): - response = client.post(batch_url_string(), jl( - id=1, - query=''' + response = client.post( + batch_url_string(), + jl( + id=1, + query=""" query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - operationName='helloWorld' - ), 'application/json') + """, + operationName="helloWorld", + ), + "application/json", + ) assert response.status_code == 200 - assert response_json(response) == [{ - 'id': 1, - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - }, - 'status': 200, - }] + assert response_json(response) == [ + { + "id": 1, + "data": {"test": "Hello World", "shared": "Hello Everyone"}, + "status": 200, + } + ] def test_allows_post_with_get_operation_name(client): - response = client.post(url_string( - operationName='helloWorld' - ), ''' + response = client.post( + url_string(operationName="helloWorld"), + """ query helloYou { test(who: "You"), ...shared } query helloWorld { test(who: "World"), ...shared } query helloDolly { test(who: "Dolly"), ...shared } fragment shared on QueryRoot { shared: test(who: "Everyone") } - ''', - 'application/graphql') + """, + "application/graphql", + ) assert response.status_code == 200 assert response_json(response) == { - 'data': { - 'test': 'Hello World', - 'shared': 'Hello Everyone' - } + "data": {"test": "Hello World", "shared": "Hello Everyone"} } -@pytest.mark.urls('graphene_django.tests.urls_pretty') +@pytest.mark.urls("graphene_django.tests.urls_inherited") +def test_inherited_class_with_attributes_works(client): + inherited_url = "/graphql/inherited/" + # Check schema and pretty attributes work + response = client.post(url_string(inherited_url, query="{test}")) + assert response.content.decode() == ( + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" + ) + + # Check graphiql works + response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html") + assert response.status_code == 200 + + +@pytest.mark.urls("graphene_django.tests.urls_pretty") def test_supports_pretty_printing(client): - response = client.get(url_string(query='{test}')) + response = client.get(url_string(query="{test}")) assert response.content.decode() == ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" ) def test_supports_pretty_printing_by_request(client): - response = client.get(url_string(query='{test}', pretty='1')) + response = client.get(url_string(query="{test}", pretty="1")) assert response.content.decode() == ( - '{\n' - ' "data": {\n' - ' "test": "Hello World"\n' - ' }\n' - '}' + "{\n" ' "data": {\n' ' "test": "Hello World"\n' " }\n" "}" ) def test_handles_field_errors_caught_by_graphql(client): - response = client.get(url_string(query='{thrower}')) + response = client.get(url_string(query="{thrower}")) assert response.status_code == 200 assert response_json(response) == { - 'data': None, - 'errors': [{'locations': [{'column': 2, 'line': 1}], 'message': 'Throws!'}] + "data": None, + "errors": [ + { + "locations": [{"column": 2, "line": 1}], + "path": ["thrower"], + "message": "Throws!", + } + ], } def test_handles_syntax_errors_caught_by_graphql(client): - response = client.get(url_string(query='syntaxerror')) + response = client.get(url_string(query="syntaxerror")) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'locations': [{'column': 1, 'line': 1}], - 'message': 'Syntax Error GraphQL request (1:1) ' - 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n'}] + "errors": [ + { + "locations": [{"column": 1, "line": 1}], + "message": "Syntax Error GraphQL (1:1) " + 'Unexpected Name "syntaxerror"\n\n1: syntaxerror\n ^\n', + } + ] } @@ -435,25 +476,25 @@ def test_handles_errors_caused_by_a_lack_of_query(client): assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Must provide query string.'}] + "errors": [{"message": "Must provide query string."}] } def test_handles_not_expected_json_bodies(client): - response = client.post(url_string(), '[]', 'application/json') + response = client.post(url_string(), "[]", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'The received data is not a valid JSON query.'}] + "errors": [{"message": "The received data is not a valid JSON query."}] } def test_handles_invalid_json_bodies(client): - response = client.post(url_string(), '[oh}', 'application/json') + response = client.post(url_string(), "[oh}", "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'POST body sent invalid JSON.'}] + "errors": [{"message": "POST body sent invalid JSON."}] } @@ -463,63 +504,57 @@ def test_handles_django_request_error(client, monkeypatch): monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) - valid_json = json.dumps(dict(foo='bar')) - response = client.post(url_string(), valid_json, 'application/json') + valid_json = json.dumps(dict(foo="bar")) + response = client.post(url_string(), valid_json, "application/json") assert response.status_code == 400 - assert response_json(response) == { - 'errors': [{'message': 'foo-bar'}] - } + assert response_json(response) == {"errors": [{"message": "foo-bar"}]} def test_handles_incomplete_json_bodies(client): - response = client.post(url_string(), '{"query":', 'application/json') + response = client.post(url_string(), '{"query":', "application/json") assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'POST body sent invalid JSON.'}] + "errors": [{"message": "POST body sent invalid JSON."}] } def test_handles_plain_post_text(client): - response = client.post(url_string( - variables=json.dumps({'who': "Dolly"}) - ), - 'query helloWho($who: String){ test(who: $who) }', - 'text/plain' + response = client.post( + url_string(variables=json.dumps({"who": "Dolly"})), + "query helloWho($who: String){ test(who: $who) }", + "text/plain", ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Must provide query string.'}] + "errors": [{"message": "Must provide query string."}] } def test_handles_poorly_formed_variables(client): - response = client.get(url_string( - query='query helloWho($who: String){ test(who: $who) }', - variables='who:You' - )) + response = client.get( + url_string( + query="query helloWho($who: String){ test(who: $who) }", variables="who:You" + ) + ) assert response.status_code == 400 assert response_json(response) == { - 'errors': [{'message': 'Variables are invalid JSON.'}] + "errors": [{"message": "Variables are invalid JSON."}] } def test_handles_unsupported_http_methods(client): - response = client.put(url_string(query='{test}')) + response = client.put(url_string(query="{test}")) assert response.status_code == 405 - assert response['Allow'] == 'GET, POST' + assert response["Allow"] == "GET, POST" assert response_json(response) == { - 'errors': [{'message': 'GraphQL only supports GET and POST requests.'}] + "errors": [{"message": "GraphQL only supports GET and POST requests."}] } def test_passes_request_into_context_request(client): - response = client.get(url_string(query='{request}', q='testing')) + response = client.get(url_string(query="{request}", q="testing")) assert response.status_code == 200 - assert response_json(response) == { - 'data': { - 'request': 'testing' - } - } + assert response_json(response) == {"data": {"request": "testing"}} diff --git a/graphene_django/tests/urls.py b/graphene_django/tests/urls.py index 8597baa..66b3fc4 100644 --- a/graphene_django/tests/urls.py +++ b/graphene_django/tests/urls.py @@ -3,6 +3,6 @@ from django.conf.urls import url from ..views import GraphQLView urlpatterns = [ - url(r'^graphql/batch', GraphQLView.as_view(batch=True)), - url(r'^graphql', GraphQLView.as_view(graphiql=True)), + url(r"^graphql/batch", GraphQLView.as_view(batch=True)), + url(r"^graphql", GraphQLView.as_view(graphiql=True)), ] diff --git a/graphene_django/tests/urls_inherited.py b/graphene_django/tests/urls_inherited.py new file mode 100644 index 0000000..6fa8019 --- /dev/null +++ b/graphene_django/tests/urls_inherited.py @@ -0,0 +1,13 @@ +from django.conf.urls import url + +from ..views import GraphQLView +from .schema_view import schema + + +class CustomGraphQLView(GraphQLView): + schema = schema + graphiql = True + pretty = True + + +urlpatterns = [url(r"^graphql/inherited/$", CustomGraphQLView.as_view())] diff --git a/graphene_django/tests/urls_pretty.py b/graphene_django/tests/urls_pretty.py index dfe4e5b..1133c87 100644 --- a/graphene_django/tests/urls_pretty.py +++ b/graphene_django/tests/urls_pretty.py @@ -3,6 +3,4 @@ from django.conf.urls import url from ..views import GraphQLView from .schema_view import schema -urlpatterns = [ - url(r'^graphql', GraphQLView.as_view(schema=schema, pretty=True)), -] +urlpatterns = [url(r"^graphql", GraphQLView.as_view(schema=schema, pretty=True))] diff --git a/graphene_django/types.py b/graphene_django/types.py index 54ed87b..4441a9a 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -1,5 +1,7 @@ +import six from collections import OrderedDict +from django.db.models import Model from django.utils.functional import SimpleLazyObject from graphene import Field from graphene.relay import Connection, Node @@ -8,8 +10,11 @@ from graphene.types.utils import yank_fields_from_attrs from .converter import convert_django_field_with_choices from .registry import Registry, get_global_registry -from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields, - is_valid_django_model) +from .utils import DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model + + +if six.PY3: + from typing import Type def construct_fields(model, registry, only_fields, exclude_fields): @@ -21,7 +26,7 @@ def construct_fields(model, registry, only_fields, exclude_fields): # is_already_created = name in options.fields is_excluded = name in exclude_fields # or is_already_created # https://docs.djangoproject.com/en/1.10/ref/models/fields/#django.db.models.ForeignKey.related_query_name - is_no_backref = str(name).endswith('+') + is_no_backref = str(name).endswith("+") if is_not_in_only or is_excluded or is_no_backref: # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields. @@ -43,9 +48,21 @@ class DjangoObjectTypeOptions(ObjectTypeOptions): class DjangoObjectType(ObjectType): @classmethod - def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, - only_fields=(), exclude_fields=(), filter_fields=None, connection=None, - connection_class=None, use_connection=None, interfaces=(), _meta=None, **options): + def __init_subclass_with_meta__( + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + filter_fields=None, + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + _meta=None, + **options + ): assert is_valid_django_model(model), ( 'You need to pass a valid Django Model in {}.Meta, received "{}".' ).format(cls.__name__, model) @@ -54,7 +71,7 @@ class DjangoObjectType(ObjectType): registry = get_global_registry() assert isinstance(registry, Registry), ( - 'The attribute registry in {} needs to be an instance of ' + "The attribute registry in {} needs to be an instance of " 'Registry, received "{}".' ).format(cls.__name__, registry) @@ -62,12 +79,13 @@ class DjangoObjectType(ObjectType): raise Exception("Can only set filter_fields if Django-Filter is installed") django_fields = yank_fields_from_attrs( - construct_fields(model, registry, only_fields, exclude_fields), - _as=Field, + construct_fields(model, registry, only_fields, exclude_fields), _as=Field ) if use_connection is None and interfaces: - use_connection = any((issubclass(interface, Node) for interface in interfaces)) + use_connection = any( + (issubclass(interface, Node) for interface in interfaces) + ) if use_connection and not connection: # We create the connection automatically @@ -75,7 +93,8 @@ class DjangoObjectType(ObjectType): connection_class = Connection connection = connection_class.create_type( - '{}Connection'.format(cls.__name__), node=cls) + "{}Connection".format(cls.__name__), node=cls + ) if connection is not None: assert issubclass(connection, Connection), ( @@ -91,7 +110,9 @@ class DjangoObjectType(ObjectType): _meta.fields = django_fields _meta.connection = connection - super(DjangoObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options) + super(DjangoObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, interfaces=interfaces, **options + ) if not skip_registry: registry.register(cls) @@ -107,10 +128,9 @@ class DjangoObjectType(ObjectType): if isinstance(root, cls): return True if not is_valid_django_model(type(root)): - raise Exception(( - 'Received incompatible instance "{}".' - ).format(root)) - model = root._meta.model + raise Exception(('Received incompatible instance "{}".').format(root)) + + model = root._meta.model._meta.concrete_model return model == cls._meta.model @classmethod diff --git a/graphene_django/utils.py b/graphene_django/utils.py index f8d83bf..560f604 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -13,6 +13,7 @@ class LazyList(object): try: import django_filters # noqa + DJANGO_FILTER_INSTALLED = True except ImportError: DJANGO_FILTER_INSTALLED = False @@ -25,8 +26,7 @@ def get_reverse_fields(model, local_field_names): continue # Django =>1.9 uses 'rel', django <1.9 uses 'related' - related = getattr(attr, 'rel', None) or \ - getattr(attr, 'related', None) + related = getattr(attr, "rel", None) or getattr(attr, "related", None) if isinstance(related, models.ManyToOneRel): yield (name, related) elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: @@ -42,9 +42,9 @@ def maybe_queryset(value): def get_model_fields(model): local_fields = [ (field.name, field) - for field - in sorted(list(model._meta.fields) + - list(model._meta.local_many_to_many)) + for field in sorted( + list(model._meta.fields) + list(model._meta.local_many_to_many) + ) ] # Make sure we don't duplicate local fields with "reverse" version diff --git a/graphene_django/views.py b/graphene_django/views.py index cc9e8bb..9a530de 100644 --- a/graphene_django/views.py +++ b/graphene_django/views.py @@ -10,18 +10,16 @@ from django.utils.decorators import method_decorator from django.views.generic import View from django.views.decorators.csrf import ensure_csrf_cookie -from graphql import Source, execute, parse, validate +from graphql import get_default_backend from graphql.error import format_error as format_graphql_error from graphql.error import GraphQLError from graphql.execution import ExecutionResult from graphql.type.schema import GraphQLSchema -from graphql.utils.get_operation_ast import get_operation_ast from .settings import graphene_settings class HttpError(Exception): - def __init__(self, response, message=None, *args, **kwargs): self.response = response self.message = message = message or response.content.decode() @@ -30,18 +28,18 @@ class HttpError(Exception): def get_accepted_content_types(request): def qualify(x): - parts = x.split(';', 1) + parts = x.split(";", 1) if len(parts) == 2: - match = re.match(r'(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)', - parts[1]) + match = re.match(r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) if match: - return parts[0], float(match.group(2)) - return parts[0], 1 + return parts[0].strip(), float(match.group(2)) + return parts[0].strip(), 1 - raw_content_types = request.META.get('HTTP_ACCEPT', '*/*').split(',') + raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",") qualified_content_types = map(qualify, raw_content_types) - return list(x[0] for x in sorted(qualified_content_types, - key=lambda x: x[1], reverse=True)) + return list( + x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) + ) def instantiate_middleware(middlewares): @@ -53,38 +51,52 @@ def instantiate_middleware(middlewares): class GraphQLView(View): - graphiql_version = '0.10.2' - graphiql_template = 'graphene/graphiql.html' + graphiql_version = "0.11.10" + graphiql_template = "graphene/graphiql.html" schema = None graphiql = False executor = None + backend = None middleware = None root_value = None pretty = False batch = False - def __init__(self, schema=None, executor=None, middleware=None, root_value=None, graphiql=False, pretty=False, - batch=False): + def __init__( + self, + schema=None, + executor=None, + middleware=None, + root_value=None, + graphiql=False, + pretty=False, + batch=False, + backend=None, + ): if not schema: schema = graphene_settings.SCHEMA + if backend is None: + backend = get_default_backend() + if middleware is None: middleware = graphene_settings.MIDDLEWARE - self.schema = schema + self.schema = self.schema or schema if middleware is not None: self.middleware = list(instantiate_middleware(middleware)) self.executor = executor self.root_value = root_value - self.pretty = pretty - self.graphiql = graphiql - self.batch = batch + self.pretty = self.pretty or pretty + self.graphiql = self.graphiql or graphiql + self.batch = self.batch or batch + self.backend = backend assert isinstance( - self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.' - assert not all((graphiql, batch) - ), 'Use either graphiql or batch processing' + self.schema, GraphQLSchema + ), "A Schema is required to be provided to GraphQLView." + assert not all((graphiql, batch)), "Use either graphiql or batch processing" # noinspection PyUnusedLocal def get_root_value(self, request): @@ -96,62 +108,58 @@ class GraphQLView(View): def get_context(self, request): return request + def get_backend(self, request): + return self.backend + @method_decorator(ensure_csrf_cookie) def dispatch(self, request, *args, **kwargs): try: - if request.method.lower() not in ('get', 'post'): - raise HttpError(HttpResponseNotAllowed( - ['GET', 'POST'], 'GraphQL only supports GET and POST requests.')) + if request.method.lower() not in ("get", "post"): + raise HttpError( + HttpResponseNotAllowed( + ["GET", "POST"], "GraphQL only supports GET and POST requests." + ) + ) data = self.parse_body(request) - show_graphiql = self.graphiql and self.can_display_graphiql( - request, data) - - if self.batch: - responses = [self.get_response(request, entry) for entry in data] - result = '[{}]'.format(','.join([response[0] for response in responses])) - status_code = responses and max(responses, key=lambda response: response[1])[1] or 200 - else: - result, status_code = self.get_response( - request, data, show_graphiql) + show_graphiql = self.graphiql and self.can_display_graphiql(request, data) if show_graphiql: - query, variables, operation_name, id = self.get_graphql_params( - request, data) return self.render_graphiql( request, graphiql_version=self.graphiql_version, - query=query or '', - variables=json.dumps(variables) or '', - operation_name=operation_name or '', - result=result or '' ) + if self.batch: + responses = [self.get_response(request, entry) for entry in data] + result = "[{}]".format( + ",".join([response[0] for response in responses]) + ) + status_code = ( + responses + and max(responses, key=lambda response: response[1])[1] + or 200 + ) + else: + result, status_code = self.get_response(request, data, show_graphiql) + return HttpResponse( - status=status_code, - content=result, - content_type='application/json' + status=status_code, content=result, content_type="application/json" ) except HttpError as e: response = e.response - response['Content-Type'] = 'application/json' - response.content = self.json_encode(request, { - 'errors': [self.format_error(e)] - }) + response["Content-Type"] = "application/json" + response.content = self.json_encode( + request, {"errors": [self.format_error(e)]} + ) return response def get_response(self, request, data, show_graphiql=False): - query, variables, operation_name, id = self.get_graphql_params( - request, data) + query, variables, operation_name, id = self.get_graphql_params(request, data) execution_result = self.execute_graphql_request( - request, - data, - query, - variables, - operation_name, - show_graphiql + request, data, query, variables, operation_name, show_graphiql ) status_code = 200 @@ -159,17 +167,18 @@ class GraphQLView(View): response = {} if execution_result.errors: - response['errors'] = [self.format_error( - e) for e in execution_result.errors] + response["errors"] = [ + self.format_error(e) for e in execution_result.errors + ] if execution_result.invalid: status_code = 400 else: - response['data'] = execution_result.data + response["data"] = execution_result.data if self.batch: - response['id'] = id - response['status'] = status_code + response["id"] = id + response["status"] = status_code result = self.json_encode(request, response, pretty=show_graphiql) else: @@ -181,22 +190,21 @@ class GraphQLView(View): return render(request, self.graphiql_template, data) def json_encode(self, request, d, pretty=False): - if not (self.pretty or pretty) and not request.GET.get('pretty'): - return json.dumps(d, separators=(',', ':')) + if not (self.pretty or pretty) and not request.GET.get("pretty"): + return json.dumps(d, separators=(",", ":")) - return json.dumps(d, sort_keys=True, - indent=2, separators=(',', ': ')) + return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) def parse_body(self, request): content_type = self.get_content_type(request) - if content_type == 'application/graphql': - return {'query': request.body.decode()} + if content_type == "application/graphql": + return {"query": request.body.decode()} - elif content_type == 'application/json': + elif content_type == "application/json": # noinspection PyBroadException try: - body = request.body.decode('utf-8') + body = request.body.decode("utf-8") except Exception as e: raise HttpError(HttpResponseBadRequest(str(e))) @@ -204,102 +212,113 @@ class GraphQLView(View): request_json = json.loads(body) if self.batch: assert isinstance(request_json, list), ( - 'Batch requests should receive a list, but received {}.' + "Batch requests should receive a list, but received {}." ).format(repr(request_json)) - assert len(request_json) > 0, ( - 'Received an empty list in the batch request.' - ) + assert ( + len(request_json) > 0 + ), "Received an empty list in the batch request." else: - assert isinstance(request_json, dict), ( - 'The received data is not a valid JSON query.' - ) + assert isinstance( + request_json, dict + ), "The received data is not a valid JSON query." return request_json except AssertionError as e: raise HttpError(HttpResponseBadRequest(str(e))) except (TypeError, ValueError): - raise HttpError(HttpResponseBadRequest( - 'POST body sent invalid JSON.')) + raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) - elif content_type in ['application/x-www-form-urlencoded', 'multipart/form-data']: + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: return request.POST return {} - def execute(self, *args, **kwargs): - return execute(self.schema, *args, **kwargs) - - def execute_graphql_request(self, request, data, query, variables, operation_name, show_graphiql=False): + def execute_graphql_request( + self, request, data, query, variables, operation_name, show_graphiql=False + ): if not query: if show_graphiql: return None - raise HttpError(HttpResponseBadRequest( - 'Must provide query string.')) - - source = Source(query, name='GraphQL request') + raise HttpError(HttpResponseBadRequest("Must provide query string.")) try: - document_ast = parse(source) - validation_errors = validate(self.schema, document_ast) - if validation_errors: - return ExecutionResult( - errors=validation_errors, - invalid=True, - ) + backend = self.get_backend(request) + document = backend.document_from_string(self.schema, query) except Exception as e: return ExecutionResult(errors=[e], invalid=True) - if request.method.lower() == 'get': - operation_ast = get_operation_ast(document_ast, operation_name) - if operation_ast and operation_ast.operation != 'query': + if request.method.lower() == "get": + operation_type = document.get_operation_type(operation_name) + if operation_type and operation_type != "query": if show_graphiql: return None - raise HttpError(HttpResponseNotAllowed( - ['POST'], 'Can only perform a {} operation from a POST request.'.format( - operation_ast.operation) - )) + raise HttpError( + HttpResponseNotAllowed( + ["POST"], + "Can only perform a {} operation from a POST request.".format( + operation_type + ), + ) + ) try: - return self.execute( - document_ast, - root_value=self.get_root_value(request), - variable_values=variables, + extra_options = {} + if self.executor: + # We only include it optionally since + # executor is not a valid argument in all backends + extra_options["executor"] = self.executor + + return document.execute( + root=self.get_root_value(request), + variables=variables, operation_name=operation_name, - context_value=self.get_context(request), + context=self.get_context(request), middleware=self.get_middleware(request), - executor=self.executor, + **extra_options ) except Exception as e: return ExecutionResult(errors=[e], invalid=True) @classmethod def can_display_graphiql(cls, request, data): - raw = 'raw' in request.GET or 'raw' in data + raw = "raw" in request.GET or "raw" in data return not raw and cls.request_wants_html(request) @classmethod def request_wants_html(cls, request): accepted = get_accepted_content_types(request) - html_index = accepted.count('text/html') - json_index = accepted.count('application/json') + accepted_length = len(accepted) + # the list will be ordered in preferred first - so we have to make + # sure the most preferred gets the highest number + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 + ) - return html_index > json_index + return html_priority > json_priority @staticmethod def get_graphql_params(request, data): - query = request.GET.get('query') or data.get('query') - variables = request.GET.get('variables') or data.get('variables') - id = request.GET.get('id') or data.get('id') + query = request.GET.get("query") or data.get("query") + variables = request.GET.get("variables") or data.get("variables") + id = request.GET.get("id") or data.get("id") if variables and isinstance(variables, six.text_type): try: variables = json.loads(variables) except Exception: - raise HttpError(HttpResponseBadRequest( - 'Variables are invalid JSON.')) + raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) - operation_name = request.GET.get( - 'operationName') or data.get('operationName') + operation_name = request.GET.get("operationName") or data.get("operationName") if operation_name == "null": operation_name = None @@ -310,11 +329,10 @@ class GraphQLView(View): if isinstance(error, GraphQLError): return format_graphql_error(error) - return {'message': six.text_type(error)} + return {"message": six.text_type(error)} @staticmethod def get_content_type(request): meta = request.META - content_type = meta.get( - 'CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) - return content_type.split(';', 1)[0].lower() + content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) + return content_type.split(";", 1)[0].lower() diff --git a/setup.py b/setup.py index 2e835d1..3431cd5 100644 --- a/setup.py +++ b/setup.py @@ -3,77 +3,63 @@ import sys import ast import re -_version_re = re.compile(r'__version__\s+=\s+(.*)') +_version_re = re.compile(r"__version__\s+=\s+(.*)") -with open('graphene_django/__init__.py', 'rb') as f: - version = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) +with open("graphene_django/__init__.py", "rb") as f: + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) -rest_framework_require = [ - 'djangorestframework>=3.6.3', -] +rest_framework_require = ["djangorestframework>=3.6.3"] tests_require = [ - 'pytest>=2.7.2', - 'pytest-cov', - 'coveralls', - 'mock', - 'pytz', - 'django-filter', - 'pytest-django==2.9.1', + "pytest>=3.6.3", + "pytest-cov", + "coveralls", + "mock", + "pytz", + "django-filter<2;python_version<'3'", + "django-filter>=2;python_version>='3'", + "pytest-django>=3.3.2", ] + rest_framework_require -django_version = 'Django>=1.8.0,<2' if sys.version_info[0] < 3 else 'Django>=1.8.0' setup( - name='graphene-django', + name="graphene-django", version=version, - - description='Graphene Django integration', - long_description=open('README.rst').read(), - - url='https://github.com/graphql-python/graphene-django', - - author='Syrus Akbary', - author_email='me@syrusakbary.com', - - license='MIT', - + description="Graphene Django integration", + long_description=open("README.rst").read(), + url="https://github.com/graphql-python/graphene-django", + author="Syrus Akbary", + author_email="me@syrusakbary.com", + license="MIT", classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: Implementation :: PyPy", ], - - keywords='api graphql protocol rest relay graphene', - - packages=find_packages(exclude=['tests']), - + keywords="api graphql protocol rest relay graphene", + packages=find_packages(exclude=["tests"]), install_requires=[ - 'six>=1.10.0', - 'graphene>=2.0,<3', - django_version, - 'iso8601', - 'singledispatch>=3.4.0.3', - 'promise>=2.1', - ], - setup_requires=[ - 'pytest-runner', + "six>=1.10.0", + "graphene>=2.1.3,<3", + "graphql-core>=2.1.0,<3", + "Django>=1.11", + "singledispatch>=3.4.0.3", + "promise>=2.1", ], + setup_requires=["pytest-runner"], tests_require=tests_require, rest_framework_require=rest_framework_require, - extras_require={ - 'test': tests_require, - 'rest_framework': rest_framework_require, - }, + extras_require={"test": tests_require, "rest_framework": rest_framework_require}, include_package_data=True, zip_safe=False, - platforms='any', + platforms="any", )