diff --git a/.travis.yml b/.travis.yml index ef8a3d6..a8375ee 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,9 +11,6 @@ install: pip install -e .[test] pip install psycopg2 # Required for Django postgres fields testing pip install django==$DJANGO_VERSION - if [ $DJANGO_VERSION = 1.8 ]; then # DRF dropped 1.8 support at 3.7.0 - pip install djangorestframework==3.6.4 - fi python setup.py develop elif [ "$TEST_TYPE" = lint ]; then pip install flake8 @@ -38,13 +35,19 @@ env: matrix: fast_finish: true include: + - python: '3.4' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.5' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.6' + env: TEST_TYPE=build DJANGO_VERSION=2.0 + - python: '3.5' + env: TEST_TYPE=build DJANGO_VERSION=2.1 + - python: '3.6' + env: TEST_TYPE=build DJANGO_VERSION=2.1 - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.8 - - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.9 - - python: '2.7' - env: TEST_TYPE=build DJANGO_VERSION=1.10 - - python: '2.7' + env: TEST_TYPE=lint + - python: '3.6' env: TEST_TYPE=lint deploy: provider: pypi diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b4a4b70 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016-Present Syrus Akbary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in index 8fffb71..3c3d4f9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ -include README.md +include README.md LICENSE recursive-include graphene_django/templates * diff --git a/README.md b/README.md index 62a36f0..4e0b01d 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,10 @@ A [Django](https://www.djangoproject.com/) integration for [Graphene](http://gra ## Installation -For instaling graphene, just run this command in your shell +For installing graphene, just run this command in your shell ```bash -pip install "graphene-django>=2.0.dev" +pip install "graphene-django>=2.0" ``` ### Settings @@ -67,8 +67,7 @@ class User(DjangoObjectType): class Query(graphene.ObjectType): users = graphene.List(User) - @graphene.resolve_only_args - def resolve_users(self): + def resolve_users(self, info): return UserModel.objects.all() schema = graphene.Schema(query=Query) diff --git a/README.rst b/README.rst index 27cbdc0..a96e60f 100644 --- a/README.rst +++ b/README.rst @@ -13,11 +13,11 @@ A `Django `__ integration for Installation ------------ -For instaling graphene, just run this command in your shell +For installing graphene, just run this command in your shell .. code:: bash - pip install "graphene-django>=2.0.dev" + pip install "graphene-django>=2.0" Settings ~~~~~~~~ diff --git a/django_test_settings.py b/django_test_settings.py index 2e08272..9279a73 100644 --- a/django_test_settings.py +++ b/django_test_settings.py @@ -8,6 +8,7 @@ SECRET_KEY = 1 INSTALLED_APPS = [ 'graphene_django', + 'graphene_django.rest_framework', 'graphene_django.tests', 'starwars', ] diff --git a/docs/authorization.rst b/docs/authorization.rst index 214cbc7..9d1b2c6 100644 --- a/docs/authorization.rst +++ b/docs/authorization.rst @@ -20,7 +20,7 @@ Let's use a simple example model. Limiting Field Access --------------------- -This is easy, simply use the ``only_fields`` meta attribute. +To limit fields in a GraphQL query simply use the ``only_fields`` meta attribute. .. code:: python @@ -34,7 +34,7 @@ This is easy, simply use the ``only_fields`` meta attribute. only_fields = ('title', 'content') interfaces = (relay.Node, ) -conversely you can use ``exclude_fields`` meta atrribute. +conversely you can use ``exclude_fields`` meta attribute. .. code:: python @@ -61,10 +61,11 @@ define a resolve method for that field and return the desired queryset. from .models import Post class Query(ObjectType): - all_posts = DjangoFilterConnectionField(CategoryNode) + all_posts = DjangoFilterConnectionField(PostNode) + + def resolve_all_posts(self, info): + return Post.objects.filter(published=True) - def resolve_all_posts(self, args, info): - return Post.objects.filter(published=True) User-based Queryset Filtering ----------------------------- @@ -79,7 +80,7 @@ with the context argument. from .models import Post class Query(ObjectType): - my_posts = DjangoFilterConnectionField(CategoryNode) + my_posts = DjangoFilterConnectionField(PostNode) def resolve_my_posts(self, info): # context will reference to the Django request @@ -95,7 +96,7 @@ schema is simple. result = schema.execute(query, context_value=request) -Filtering ID-based node access +Filtering ID-based Node Access ------------------------------ In order to add authorization to id-based node access, we need to add a @@ -113,13 +114,13 @@ method to your ``DjangoObjectType``. interfaces = (relay.Node, ) @classmethod - def get_node(cls, id, context, info): + def get_node(cls, id, info): try: post = cls._meta.model.objects.get(id=id) except cls._meta.model.DoesNotExist: return None - if post.published or context.user == post.owner: + if post.published or info.context.user == post.owner: return post return None @@ -206,14 +207,14 @@ Connection example: all_reporters = MyAuthDjangoConnectionField(ReporterType) - -Adding login required +Adding Login Required --------------------- -If you want to use the standard Django LoginRequiredMixin_ you can create your own view, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``: +To restrict users from accessing the GraphQL API page the standard Django LoginRequiredMixin_ can be used to create your own standard Django Class Based View, which includes the ``LoginRequiredMixin`` and subclasses the ``GraphQLView``.: .. code:: python - + #views.py + from django.contrib.auth.mixins import LoginRequiredMixin from graphene_django.views import GraphQLView @@ -221,7 +222,9 @@ If you want to use the standard Django LoginRequiredMixin_ you can create your o class PrivateGraphQLView(LoginRequiredMixin, GraphQLView): pass -After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``: +After this, you can use the new ``PrivateGraphQLView`` in the project's URL Configuration file ``url.py``: + +For Django 1.9 and below: .. code:: python @@ -229,5 +232,14 @@ After this, you can use the new ``PrivateGraphQLView`` in ``urls.py``: # some other urls url(r'^graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), ] + +For Django 2.0 and above: + +.. code:: python + + urlpatterns = [ + # some other urls + path('graphql', PrivateGraphQLView.as_view(graphiql=True, schema=schema)), + ] .. _LoginRequiredMixin: https://docs.djangoproject.com/en/1.10/topics/auth/default/#the-loginrequired-mixin diff --git a/docs/filtering.rst b/docs/filtering.rst index edcb7e5..feafd40 100644 --- a/docs/filtering.rst +++ b/docs/filtering.rst @@ -2,9 +2,9 @@ Filtering ========= Graphene integrates with -`django-filter `__ to provide -filtering of results. See the `usage -documentation `__ +`django-filter `__ (2.x for +Python 3 or 1.x for Python 2) to provide filtering of results. See the `usage +documentation `__ for details on the format for ``filter_fields``. This filtering is automatically available when implementing a ``relay.Node``. @@ -15,7 +15,7 @@ You will need to install it manually, which can be done as follows: .. code:: bash # You'll need to django-filter - pip install django-filter + pip install django-filter>=2 Note: The techniques below are demoed in the `cookbook example app `__. @@ -26,7 +26,7 @@ Filterable fields The ``filter_fields`` parameter is used to specify the fields which can be filtered upon. The value specified here is passed directly to ``django-filter``, so see the `filtering -documentation `__ +documentation `__ for full details on the range of options available. For example: @@ -126,3 +126,23 @@ create your own ``Filterset`` as follows: # We specify our custom AnimalFilter using the filterset_class param all_animals = DjangoFilterConnectionField(AnimalNode, filterset_class=AnimalFilter) + +The context argument is passed on as the `request argument `__ +in a ``django_filters.FilterSet`` instance. You can use this to customize your +filters to be context-dependent. We could modify the ``AnimalFilter`` above to +pre-filter animals owned by the authenticated user (set in ``context.user``). + +.. code:: python + + class AnimalFilter(django_filters.FilterSet): + # Do case-insensitive lookups on 'name' + name = django_filters.CharFilter(lookup_type='iexact') + + class Meta: + model = Animal + fields = ['name', 'genus', 'is_domesticated'] + + @property + def qs(self): + # The query context can be found in self.request. + return super(AnimalFilter, self).qs.filter(owner=self.request.user) diff --git a/docs/form-mutations.rst b/docs/form-mutations.rst new file mode 100644 index 0000000..e721a78 --- /dev/null +++ b/docs/form-mutations.rst @@ -0,0 +1,68 @@ +Integration with Django forms +============================= + +Graphene-Django comes with mutation classes that will convert the fields on Django forms into inputs on a mutation. +*Note: the API is experimental and will likely change in the future.* + +FormMutation +------------ + +.. code:: python + + class MyForm(forms.Form): + name = forms.CharField() + + class MyMutation(FormMutation): + class Meta: + form_class = MyForm + +``MyMutation`` will automatically receive an ``input`` argument. This argument should be a ``dict`` where the key is ``name`` and the value is a string. + +ModelFormMutation +----------------- + +``ModelFormMutation`` will pull the fields from a ``ModelForm``. + +.. code:: python + + class Pet(models.Model): + name = models.CharField() + + class PetForm(forms.ModelForm): + class Meta: + model = Pet + fields = ('name',) + + # This will get returned when the mutation completes successfully + class PetType(DjangoObjectType): + class Meta: + model = Pet + + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + +``PetMutation`` will grab the fields from ``PetForm`` and turn them into inputs. If the form is valid then the mutation +will lookup the ``DjangoObjectType`` for the ``Pet`` model and return that under the key ``pet``. Otherwise it will +return a list of errors. + +You can change the input name (default is ``input``) and the return field name (default is the model name lowercase). + +.. code:: python + + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + input_field_name = 'data' + return_field_name = 'my_pet' + +Form validation +--------------- + +Form mutations will call ``is_valid()`` on your forms. + +If the form is valid then ``form_valid(form, info)`` is called on the mutation. Override this method to change how +the form is saved or to return a different Graphene object type. + +If the form is *not* valid then a list of errors will be returned. These errors have two fields: ``field``, a string +containing the name of the invalid form field, and ``messages``, a list of strings with the validation messages. diff --git a/docs/index.rst b/docs/index.rst index 256da68..7c64ae7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,4 +12,5 @@ Contents: authorization debug rest-framework + form-mutations introspection diff --git a/docs/requirements.txt b/docs/requirements.txt index 2548604..220b7cf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,3 @@ sphinx # Docs template -https://github.com/graphql-python/graphene-python.org/archive/docs.zip +http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/docs/rest-framework.rst b/docs/rest-framework.rst index 5e5dd70..ce666de 100644 --- a/docs/rest-framework.rst +++ b/docs/rest-framework.rst @@ -19,3 +19,46 @@ You can create a Mutation based on a serializer by using the class Meta: serializer_class = MySerializer +Create/Update Operations +--------------------- + +By default ModelSerializers accept create and update operations. To +customize this use the `model_operations` attribute. The update +operation looks up models by the primary key by default. You can +customize the look up with the lookup attribute. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ['create', 'update'] + lookup_field = 'id' + +Overriding Update Queries +------------------------- + +Use the method `get_serializer_kwargs` to override how +updates are applied. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + if 'id' in input: + instance = Post.objects.filter(id=input['id'], owner=info.context.user).first() + if instance: + return {'instance': instance, 'data': input, 'partial': True} + + else: + raise http.Http404 + + return {'data': input, 'partial': True} diff --git a/docs/tutorial-plain.rst b/docs/tutorial-plain.rst index 592f244..a87b011 100644 --- a/docs/tutorial-plain.rst +++ b/docs/tutorial-plain.rst @@ -68,7 +68,8 @@ Let's get started with these models: class Ingredient(models.Model): name = models.CharField(max_length=100) notes = models.TextField() - category = models.ForeignKey(Category, related_name='ingredients') + category = models.ForeignKey( + Category, related_name='ingredients', on_delete=models.CASCADE) def __str__(self): return self.name @@ -80,9 +81,10 @@ Add ingredients as INSTALLED_APPS: INSTALLED_APPS = [ ... # Install the ingredients app - 'ingredients', + 'cookbook.ingredients', ] + Don't forget to create & run migrations: .. code:: bash @@ -111,6 +113,18 @@ Alternatively you can use the Django admin interface to create some data yourself. You'll need to run the development server (see below), and create a login for yourself too (``./manage.py createsuperuser``). +Register models with admin panel: + +.. code:: python + + # cookbook/ingredients/admin.py + from django.contrib import admin + from cookbook.ingredients.models import Category, Ingredient + + admin.site.register(Category) + admin.site.register(Ingredient) + + Hello GraphQL - Schema and Object Types --------------------------------------- @@ -153,7 +167,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following: model = Ingredient - class Query(graphene.AbstractType): + class Query(object): all_categories = graphene.List(CategoryType) all_ingredients = graphene.List(IngredientType) @@ -165,9 +179,9 @@ Create ``cookbook/ingredients/schema.py`` and type the following: return Ingredient.objects.select_related('category').all() -Note that the above ``Query`` class is marked as 'abstract'. This is -because we will now create a project-level query which will combine all -our app-level queries. +Note that the above ``Query`` class is a mixin, inheriting from +``object``. This is because we will now create a project-level query +class which will combine all our app-level mixins. Create the parent project-level ``cookbook/schema.py``: @@ -426,7 +440,7 @@ We can update our schema to support that, by adding new query for ``ingredient`` model = Ingredient - class Query(graphene.AbstractType): + class Query(object): category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String()) @@ -445,8 +459,8 @@ We can update our schema to support that, by adding new query for ``ingredient`` return Ingredient.objects.all() def resolve_category(self, info, **kwargs): - id = kargs.get('id') - name = kargs.get('name') + id = kwargs.get('id') + name = kwargs.get('name') if id is not None: return Category.objects.get(pk=id) @@ -457,8 +471,8 @@ We can update our schema to support that, by adding new query for ``ingredient`` return None def resolve_ingredient(self, info, **kwargs): - id = kargs.get('id') - name = kargs.get('name') + id = kwargs.get('id') + name = kwargs.get('name') if id is not None: return Ingredient.objects.get(pk=id) diff --git a/docs/tutorial-relay.rst b/docs/tutorial-relay.rst index 3ac4cec..f2502d7 100644 --- a/docs/tutorial-relay.rst +++ b/docs/tutorial-relay.rst @@ -10,7 +10,7 @@ app `__ -* `GraphQL Relay Specification `__ +* `GraphQL Relay Specification `__ Setup the Django project ------------------------ @@ -118,7 +118,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following: .. code:: python # cookbook/ingredients/schema.py - from graphene import relay, ObjectType, AbstractType + from graphene import relay, ObjectType from graphene_django import DjangoObjectType from graphene_django.filter import DjangoFilterConnectionField @@ -147,7 +147,7 @@ Create ``cookbook/ingredients/schema.py`` and type the following: interfaces = (relay.Node, ) - class Query(AbstractType): + class Query(object): category = relay.Node.Field(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode) diff --git a/examples/cookbook-plain/README.md b/examples/cookbook-plain/README.md index 018c584..0ec906b 100644 --- a/examples/cookbook-plain/README.md +++ b/examples/cookbook-plain/README.md @@ -3,7 +3,7 @@ Cookbook Example Django Project This example project demos integration between Graphene and Django. The project contains two apps, one named `ingredients` and another -named `recepies`. +named `recipes`. Getting started --------------- @@ -60,5 +60,5 @@ Now you should be ready to start the server: Now head on over to [http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql) and run some queries! -(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema) +(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema) for some example queries) diff --git a/examples/cookbook-plain/cookbook/ingredients/schema.py b/examples/cookbook-plain/cookbook/ingredients/schema.py index 895f216..1f3bb18 100644 --- a/examples/cookbook-plain/cookbook/ingredients/schema.py +++ b/examples/cookbook-plain/cookbook/ingredients/schema.py @@ -14,7 +14,7 @@ class IngredientType(DjangoObjectType): model = Ingredient -class Query(graphene.AbstractType): +class Query(object): category = graphene.Field(CategoryType, id=graphene.Int(), name=graphene.String()) diff --git a/examples/cookbook-plain/cookbook/ingredients/tests.py b/examples/cookbook-plain/cookbook/ingredients/tests.py new file mode 100644 index 0000000..4929020 --- /dev/null +++ b/examples/cookbook-plain/cookbook/ingredients/tests.py @@ -0,0 +1,2 @@ + +# Create your tests here. diff --git a/examples/cookbook-plain/cookbook/ingredients/views.py b/examples/cookbook-plain/cookbook/ingredients/views.py new file mode 100644 index 0000000..b8e4ee0 --- /dev/null +++ b/examples/cookbook-plain/cookbook/ingredients/views.py @@ -0,0 +1,2 @@ + +# Create your views here. diff --git a/examples/cookbook-plain/cookbook/recipes/models.py b/examples/cookbook-plain/cookbook/recipes/models.py index e688044..ca12fac 100644 --- a/examples/cookbook-plain/cookbook/recipes/models.py +++ b/examples/cookbook-plain/cookbook/recipes/models.py @@ -6,6 +6,7 @@ from cookbook.ingredients.models import Ingredient class Recipe(models.Model): title = models.CharField(max_length=100) instructions = models.TextField() + __unicode__ = lambda self: self.title class RecipeIngredient(models.Model): diff --git a/examples/cookbook-plain/cookbook/recipes/schema.py b/examples/cookbook-plain/cookbook/recipes/schema.py index 8ea1ccd..040c985 100644 --- a/examples/cookbook-plain/cookbook/recipes/schema.py +++ b/examples/cookbook-plain/cookbook/recipes/schema.py @@ -14,7 +14,7 @@ class RecipeIngredientType(DjangoObjectType): model = RecipeIngredient -class Query(graphene.AbstractType): +class Query(object): recipe = graphene.Field(RecipeType, id=graphene.Int(), title=graphene.String()) diff --git a/examples/cookbook-plain/cookbook/recipes/tests.py b/examples/cookbook-plain/cookbook/recipes/tests.py new file mode 100644 index 0000000..4929020 --- /dev/null +++ b/examples/cookbook-plain/cookbook/recipes/tests.py @@ -0,0 +1,2 @@ + +# Create your tests here. diff --git a/examples/cookbook-plain/cookbook/recipes/views.py b/examples/cookbook-plain/cookbook/recipes/views.py new file mode 100644 index 0000000..b8e4ee0 --- /dev/null +++ b/examples/cookbook-plain/cookbook/recipes/views.py @@ -0,0 +1,2 @@ + +# Create your views here. diff --git a/examples/cookbook-plain/requirements.txt b/examples/cookbook-plain/requirements.txt index a693bd1..362a39a 100644 --- a/examples/cookbook-plain/requirements.txt +++ b/examples/cookbook-plain/requirements.txt @@ -1,4 +1,4 @@ graphene graphene-django -graphql-core +graphql-core>=2.1rc1 django==1.9 diff --git a/examples/cookbook/README.md b/examples/cookbook/README.md index 1d3fc31..0ec906b 100644 --- a/examples/cookbook/README.md +++ b/examples/cookbook/README.md @@ -60,5 +60,5 @@ Now you should be ready to start the server: Now head on over to [http://127.0.0.1:8000/graphql](http://127.0.0.1:8000/graphql) and run some queries! -(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial#testing-our-graphql-schema) +(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema) for some example queries) diff --git a/examples/cookbook/cookbook/ingredients/admin.py b/examples/cookbook/cookbook/ingredients/admin.py index 2b16cdc..b57cbc3 100644 --- a/examples/cookbook/cookbook/ingredients/admin.py +++ b/examples/cookbook/cookbook/ingredients/admin.py @@ -2,9 +2,11 @@ from django.contrib import admin from cookbook.ingredients.models import Category, Ingredient + @admin.register(Ingredient) class IngredientAdmin(admin.ModelAdmin): - list_display = ("id","name","category") - list_editable = ("name","category") - + list_display = ('id', 'name', 'category') + list_editable = ('name', 'category') + + admin.site.register(Category) diff --git a/examples/cookbook/cookbook/ingredients/models.py b/examples/cookbook/cookbook/ingredients/models.py index a072bcf..2f0eba3 100644 --- a/examples/cookbook/cookbook/ingredients/models.py +++ b/examples/cookbook/cookbook/ingredients/models.py @@ -10,7 +10,7 @@ class Category(models.Model): class Ingredient(models.Model): name = models.CharField(max_length=100) - notes = models.TextField(null=True,blank=True) + notes = models.TextField(null=True, blank=True) category = models.ForeignKey(Category, related_name='ingredients') def __str__(self): diff --git a/examples/cookbook/cookbook/ingredients/schema.py b/examples/cookbook/cookbook/ingredients/schema.py index b8b3c12..5ad92e8 100644 --- a/examples/cookbook/cookbook/ingredients/schema.py +++ b/examples/cookbook/cookbook/ingredients/schema.py @@ -1,5 +1,5 @@ from cookbook.ingredients.models import Category, Ingredient -from graphene import AbstractType, Node +from graphene import Node from graphene_django.filter import DjangoFilterConnectionField from graphene_django.types import DjangoObjectType @@ -28,7 +28,7 @@ class IngredientNode(DjangoObjectType): } -class Query(AbstractType): +class Query(object): category = Node.Field(CategoryNode) all_categories = DjangoFilterConnectionField(CategoryNode) diff --git a/examples/cookbook/cookbook/recipes/admin.py b/examples/cookbook/cookbook/recipes/admin.py index 57e0418..10d568f 100644 --- a/examples/cookbook/cookbook/recipes/admin.py +++ b/examples/cookbook/cookbook/recipes/admin.py @@ -2,9 +2,11 @@ from django.contrib import admin from cookbook.recipes.models import Recipe, RecipeIngredient + class RecipeIngredientInline(admin.TabularInline): - model = RecipeIngredient + model = RecipeIngredient + @admin.register(Recipe) class RecipeAdmin(admin.ModelAdmin): - inlines = [RecipeIngredientInline] + inlines = [RecipeIngredientInline] diff --git a/examples/cookbook/cookbook/recipes/models.py b/examples/cookbook/cookbook/recipes/models.py index f666fe8..ca12fac 100644 --- a/examples/cookbook/cookbook/recipes/models.py +++ b/examples/cookbook/cookbook/recipes/models.py @@ -8,6 +8,7 @@ class Recipe(models.Model): instructions = models.TextField() __unicode__ = lambda self: self.title + class RecipeIngredient(models.Model): recipe = models.ForeignKey(Recipe, related_name='amounts') ingredient = models.ForeignKey(Ingredient, related_name='used_by') diff --git a/examples/cookbook/cookbook/recipes/schema.py b/examples/cookbook/cookbook/recipes/schema.py index 35a21de..8018322 100644 --- a/examples/cookbook/cookbook/recipes/schema.py +++ b/examples/cookbook/cookbook/recipes/schema.py @@ -1,5 +1,5 @@ from cookbook.recipes.models import Recipe, RecipeIngredient -from graphene import AbstractType, Node +from graphene import Node from graphene_django.filter import DjangoFilterConnectionField from graphene_django.types import DjangoObjectType @@ -24,7 +24,7 @@ class RecipeIngredientNode(DjangoObjectType): } -class Query(AbstractType): +class Query(object): recipe = Node.Field(RecipeNode) all_recipes = DjangoFilterConnectionField(RecipeNode) diff --git a/examples/cookbook/cookbook/schema.py b/examples/cookbook/cookbook/schema.py index 910e259..f8606a7 100644 --- a/examples/cookbook/cookbook/schema.py +++ b/examples/cookbook/cookbook/schema.py @@ -5,7 +5,9 @@ import graphene from graphene_django.debug import DjangoDebug -class Query(cookbook.recipes.schema.Query, cookbook.ingredients.schema.Query, graphene.ObjectType): +class Query(cookbook.ingredients.schema.Query, + cookbook.recipes.schema.Query, + graphene.ObjectType): debug = graphene.Field(DjangoDebug, name='__debug') diff --git a/examples/cookbook/cookbook/settings.py b/examples/cookbook/cookbook/settings.py index 1916201..948292d 100644 --- a/examples/cookbook/cookbook/settings.py +++ b/examples/cookbook/cookbook/settings.py @@ -1,3 +1,4 @@ +# flake8: noqa """ Django settings for cookbook project. diff --git a/examples/cookbook/cookbook/urls.py b/examples/cookbook/cookbook/urls.py index 9410ca5..9f8755b 100644 --- a/examples/cookbook/cookbook/urls.py +++ b/examples/cookbook/cookbook/urls.py @@ -3,6 +3,7 @@ from django.contrib import admin from graphene_django.views import GraphQLView + urlpatterns = [ url(r'^admin/', admin.site.urls), url(r'^graphql', GraphQLView.as_view(graphiql=True)), diff --git a/examples/cookbook/requirements.txt b/examples/cookbook/requirements.txt index 66fa629..b2ace1f 100644 --- a/examples/cookbook/requirements.txt +++ b/examples/cookbook/requirements.txt @@ -1,5 +1,5 @@ graphene graphene-django -graphql-core +graphql-core>=2.1rc1 django==1.9 -django-filter==0.11.0 +django-filter>=2 diff --git a/examples/cookbook/setup.cfg b/examples/cookbook/setup.cfg new file mode 100644 index 0000000..8c6a6e8 --- /dev/null +++ b/examples/cookbook/setup.cfg @@ -0,0 +1,2 @@ +[flake8] +exclude=migrations,.git,__pycache__ diff --git a/examples/starwars/models.py b/examples/starwars/models.py index 2f80e27..45741da 100644 --- a/examples/starwars/models.py +++ b/examples/starwars/models.py @@ -5,7 +5,7 @@ from django.db import models class Character(models.Model): name = models.CharField(max_length=50) - ship = models.ForeignKey('Ship', blank=True, null=True, related_name='characters') + ship = models.ForeignKey('Ship', on_delete=models.CASCADE, blank=True, null=True, related_name='characters') def __str__(self): return self.name @@ -13,7 +13,7 @@ class Character(models.Model): class Faction(models.Model): name = models.CharField(max_length=50) - hero = models.ForeignKey(Character) + hero = models.ForeignKey(Character, on_delete=models.CASCADE) def __str__(self): return self.name @@ -21,7 +21,7 @@ class Faction(models.Model): class Ship(models.Model): name = models.CharField(max_length=50) - faction = models.ForeignKey(Faction, related_name='ships') + faction = models.ForeignKey(Faction, on_delete=models.CASCADE, related_name='ships') def __str__(self): return self.name diff --git a/graphene_django/__init__.py b/graphene_django/__init__.py index 5ba360f..4538cb3 100644 --- a/graphene_django/__init__.py +++ b/graphene_django/__init__.py @@ -1,14 +1,6 @@ -from .types import ( - DjangoObjectType, -) -from .fields import ( - DjangoConnectionField, -) +from .types import DjangoObjectType +from .fields import DjangoConnectionField -__version__ = '2.0.0' +__version__ = "2.2.0" -__all__ = [ - '__version__', - 'DjangoObjectType', - 'DjangoConnectionField' -] +__all__ = ["__version__", "DjangoObjectType", "DjangoConnectionField"] diff --git a/graphene_django/compat.py b/graphene_django/compat.py index 0269e33..4a51de8 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -5,13 +5,7 @@ class MissingType(object): try: # Postgres fields are only available in Django with psycopg2 installed # and we cannot have psycopg2 on PyPy - from django.contrib.postgres.fields import ArrayField, HStoreField, RangeField + from django.contrib.postgres.fields import (ArrayField, HStoreField, + JSONField, RangeField) except ImportError: - ArrayField, HStoreField, JSONField, RangeField = (MissingType, ) * 4 - - -try: - # Postgres fields are only available in Django 1.9+ - from django.contrib.postgres.fields import JSONField -except ImportError: - JSONField = MissingType + ArrayField, HStoreField, JSONField, RangeField = (MissingType,) * 4 diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 3f8511c..c40313d 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,9 +1,22 @@ from django.db import models from django.utils.encoding import force_text -from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, - NonNull, String, UUID) -from graphene.types.datetime import DateTime, Time +from graphene import ( + ID, + Boolean, + Dynamic, + Enum, + Field, + Float, + Int, + List, + NonNull, + String, + UUID, + DateTime, + Date, + Time, +) from graphene.types.json import JSONString from graphene.utils.str_converters import to_camel_case, to_const from graphql import assert_valid_name @@ -33,37 +46,44 @@ def get_choices(choices): else: name = convert_choice_name(value) while name in converted_names: - name += '_' + str(len(converted_names)) + name += "_" + str(len(converted_names)) converted_names.append(name) description = help_text yield name, value, description def convert_django_field_with_choices(field, registry=None): - choices = getattr(field, 'choices', None) + if registry is not None: + converted = registry.get_converted_field(field) + if converted: + return converted + choices = getattr(field, "choices", None) if choices: meta = field.model._meta - name = to_camel_case('{}_{}'.format(meta.object_name, field.name)) + name = to_camel_case("{}_{}".format(meta.object_name, field.name)) choices = list(get_choices(choices)) named_choices = [(c[0], c[1]) for c in choices] named_choices_descriptions = {c[0]: c[2] for c in choices} class EnumWithDescriptionsType(object): - @property def description(self): return named_choices_descriptions[self.name] enum = Enum(name, list(named_choices), type=EnumWithDescriptionsType) - return enum(description=field.help_text, required=not field.null) - return convert_django_field(field, registry) + converted = enum(description=field.help_text, required=not field.null) + else: + converted = convert_django_field(field, registry) + if registry is not None: + registry.register_converted_field(field, converted) + return converted @singledispatch def convert_django_field(field, registry=None): raise Exception( - "Don't know how to convert the Django field %s (%s)" % - (field, field.__class__)) + "Don't know how to convert the Django field %s (%s)" % (field, field.__class__) + ) @convert_django_field.register(models.CharField) @@ -73,6 +93,7 @@ def convert_django_field(field, registry=None): @convert_django_field.register(models.URLField) @convert_django_field.register(models.GenericIPAddressField) @convert_django_field.register(models.FileField) +@convert_django_field.register(models.FilePathField) def convert_field_to_string(field, registry=None): return String(description=field.help_text, required=not field.null) @@ -113,9 +134,14 @@ def convert_field_to_float(field, registry=None): return Float(description=field.help_text, required=not field.null) +@convert_django_field.register(models.DateTimeField) +def convert_datetime_to_string(field, registry=None): + return DateTime(description=field.help_text, required=not field.null) + + @convert_django_field.register(models.DateField) def convert_date_to_string(field, registry=None): - return DateTime(description=field.help_text, required=not field.null) + return Date(description=field.help_text, required=not field.null) @convert_django_field.register(models.TimeField) @@ -134,7 +160,7 @@ def convert_onetoone_field_to_djangomodel(field, registry=None): # We do this for a bug in Django 1.8, where null attr # is not available in the OneToOneRel instance - null = getattr(field, 'null', True) + null = getattr(field, "null", True) return Field(_type, required=not null) return Dynamic(dynamic_type) @@ -158,6 +184,7 @@ def convert_field_to_list_or_connection(field, registry=None): # defined filter_fields in the DjangoObjectType Meta if _type._meta.filter_fields: from .filter.fields import DjangoFilterConnectionField + return DjangoFilterConnectionField(_type) return DjangoConnectionField(_type) diff --git a/graphene_django/debug/__init__.py b/graphene_django/debug/__init__.py index cd5015e..3e078da 100644 --- a/graphene_django/debug/__init__.py +++ b/graphene_django/debug/__init__.py @@ -1,4 +1,4 @@ from .middleware import DjangoDebugMiddleware from .types import DjangoDebug -__all__ = ['DjangoDebugMiddleware', 'DjangoDebug'] +__all__ = ["DjangoDebugMiddleware", "DjangoDebug"] diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index 2b11f7e..48d471f 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -7,7 +7,6 @@ from .types import DjangoDebug class DjangoDebugContext(object): - def __init__(self): self.debug_promise = None self.promises = [] @@ -38,20 +37,21 @@ class DjangoDebugContext(object): class DjangoDebugMiddleware(object): - def resolve(self, next, root, info, **args): context = info.context - django_debug = getattr(context, 'django_debug', None) + django_debug = getattr(context, "django_debug", None) if not django_debug: if context is None: - raise Exception('DjangoDebug cannot be executed in None contexts') + raise Exception("DjangoDebug cannot be executed in None contexts") try: context.django_debug = DjangoDebugContext() except Exception: - raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format( - context.__class__.__name__ - )) - if info.schema.get_type('DjangoDebug') == info.return_type: + raise Exception( + "DjangoDebug need the context to be writable, context received: {}.".format( + context.__class__.__name__ + ) + ) + if info.schema.get_type("DjangoDebug") == info.return_type: return context.django_debug.get_debug_promise() promise = next(root, info, **args) context.django_debug.add_promise(promise) diff --git a/graphene_django/debug/sql/tracking.py b/graphene_django/debug/sql/tracking.py index 9d14e4b..f96583b 100644 --- a/graphene_django/debug/sql/tracking.py +++ b/graphene_django/debug/sql/tracking.py @@ -16,7 +16,6 @@ class SQLQueryTriggered(Exception): class ThreadLocalState(local): - def __init__(self): self.enabled = True @@ -35,7 +34,7 @@ recording = state.recording # export function def wrap_cursor(connection, panel): - if not hasattr(connection, '_graphene_cursor'): + if not hasattr(connection, "_graphene_cursor"): connection._graphene_cursor = connection.cursor def cursor(): @@ -46,7 +45,7 @@ def wrap_cursor(connection, panel): def unwrap_cursor(connection): - if hasattr(connection, '_graphene_cursor'): + if hasattr(connection, "_graphene_cursor"): previous_cursor = connection._graphene_cursor connection.cursor = previous_cursor del connection._graphene_cursor @@ -87,15 +86,14 @@ class NormalCursorWrapper(object): if not params: return params if isinstance(params, dict): - return dict((key, self._quote_expr(value)) - for key, value in params.items()) + return dict((key, self._quote_expr(value)) for key, value in params.items()) return list(map(self._quote_expr, params)) def _decode(self, param): try: return force_text(param, strings_only=True) except UnicodeDecodeError: - return '(encoded string)' + return "(encoded string)" def _record(self, method, sql, params): start_time = time() @@ -103,45 +101,48 @@ class NormalCursorWrapper(object): return method(sql, params) finally: stop_time = time() - duration = (stop_time - start_time) - _params = '' + duration = stop_time - start_time + _params = "" try: _params = json.dumps(list(map(self._decode, params))) except Exception: pass # object not JSON serializable - alias = getattr(self.db, 'alias', 'default') + alias = getattr(self.db, "alias", "default") conn = self.db.connection - vendor = getattr(conn, 'vendor', 'unknown') + vendor = getattr(conn, "vendor", "unknown") params = { - 'vendor': vendor, - 'alias': alias, - 'sql': self.db.ops.last_executed_query( - self.cursor, sql, self._quote_params(params)), - 'duration': duration, - 'raw_sql': sql, - 'params': _params, - 'start_time': start_time, - 'stop_time': stop_time, - 'is_slow': duration > 10, - 'is_select': sql.lower().strip().startswith('select'), + "vendor": vendor, + "alias": alias, + "sql": self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params) + ), + "duration": duration, + "raw_sql": sql, + "params": _params, + "start_time": start_time, + "stop_time": stop_time, + "is_slow": duration > 10, + "is_select": sql.lower().strip().startswith("select"), } - if vendor == 'postgresql': + if vendor == "postgresql": # If an erroneous query was ran on the connection, it might # be in a state where checking isolation_level raises an # exception. try: iso_level = conn.isolation_level except conn.InternalError: - iso_level = 'unknown' - params.update({ - 'trans_id': self.logger.get_transaction_id(alias), - 'trans_status': conn.get_transaction_status(), - 'iso_level': iso_level, - 'encoding': conn.encoding, - }) + iso_level = "unknown" + params.update( + { + "trans_id": self.logger.get_transaction_id(alias), + "trans_status": conn.get_transaction_status(), + "iso_level": iso_level, + "encoding": conn.encoding, + } + ) _sql = DjangoDebugSQL(**params) # We keep `sql` to maintain backwards compatibility diff --git a/graphene_django/debug/sql/types.py b/graphene_django/debug/sql/types.py index 6ae4d31..850ced4 100644 --- a/graphene_django/debug/sql/types.py +++ b/graphene_django/debug/sql/types.py @@ -2,19 +2,53 @@ from graphene import Boolean, Float, ObjectType, String class DjangoDebugSQL(ObjectType): - vendor = String() - alias = String() - sql = String() - duration = Float() - raw_sql = String() - params = String() - start_time = Float() - stop_time = Float() - is_slow = Boolean() - is_select = Boolean() + class Meta: + description = ( + "Represents a single database query made to a Django managed DB." + ) + + vendor = String( + required=True, + description=( + "The type of database being used (e.g. postrgesql, mysql, sqlite)." + ), + ) + alias = String( + required=True, + description="The Django database alias (e.g. 'default').", + ) + sql = String(description="The actual SQL sent to this database.") + duration = Float( + required=True, + description="Duration of this database query in seconds.", + ) + raw_sql = String( + required=True, + description="The raw SQL of this query, without params.", + ) + params = String( + required=True, + description="JSON encoded database query parameters.", + ) + start_time = Float( + required=True, + description="Start time of this database query.", + ) + stop_time = Float( + required=True, + description="Stop time of this database query.", + ) + is_slow = Boolean( + required=True, + description="Whether this database query took more than 10 seconds.", + ) + is_select = Boolean( + required=True, + description="Whether this database query was a SELECT.", + ) # Postgres - trans_id = String() - trans_status = String() - iso_level = String() - encoding = String() + trans_id = String(description="Postgres transaction ID if available.") + trans_status = String(description="Postgres transaction status if available.") + iso_level = String(description="Postgres isolation level if available.") + encoding = String(description="Postgres connection encoding if available.") diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index 72747b2..f2ef096 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -12,31 +12,31 @@ from ..types import DjangoDebug class context(object): pass + # from examples.starwars_django.models import Character pytestmark = pytest.mark.django_db def test_should_query_field(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_reporter(self, info, **args): return Reporter.objects.first() - query = ''' + query = """ query ReporterQuery { reporter { lastName @@ -47,43 +47,40 @@ def test_should_query_field(): } } } - ''' + """ expected = { - 'reporter': { - 'lastName': 'ABA', + "reporter": {"lastName": "ABA"}, + "__debug": { + "sql": [{"rawSql": str(Reporter.objects.order_by("pk")[:1].query)}] }, - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) - }] - } } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_list(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = graphene.List(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters { lastName @@ -94,45 +91,38 @@ def test_should_query_list(): } } } - ''' + """ expected = { - 'allReporters': [{ - 'lastName': 'ABA', - }, { - 'lastName': 'Griffin', - }], - '__debug': { - 'sql': [{ - 'rawSql': str(Reporter.objects.all().query) - }] - } + "allReporters": [{"lastName": "ABA"}, {"lastName": "Griffin"}], + "__debug": {"sql": [{"rawSql": str(Reporter.objects.all().query)}]}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors assert result.data == expected def test_should_query_connection(): - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -147,48 +137,41 @@ def test_should_query_connection(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query def test_should_query_connectionfilter(): from ...filter import DjangoFilterConnectionField - r1 = Reporter(last_name='ABA') + r1 = Reporter(last_name="ABA") r1.save() - r2 = Reporter(last_name='Griffin') + r2 = Reporter(last_name="Griffin") r2.save() class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(graphene.ObjectType): - all_reporters = DjangoFilterConnectionField(ReporterType, fields=['last_name']) + all_reporters = DjangoFilterConnectionField(ReporterType, fields=["last_name"]) s = graphene.String(resolver=lambda *_: "S") - debug = graphene.Field(DjangoDebug, name='__debug') + debug = graphene.Field(DjangoDebug, name="__debug") def resolve_all_reporters(self, info, **args): return Reporter.objects.all() - query = ''' + query = """ query ReporterQuery { allReporters(first:1) { edges { @@ -203,20 +186,14 @@ def test_should_query_connectionfilter(): } } } - ''' - expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'lastName': 'ABA', - } - }] - }, - } + """ + expected = {"allReporters": {"edges": [{"node": {"lastName": "ABA"}}]}} schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value=context(), middleware=[DjangoDebugMiddleware()]) + result = schema.execute( + query, context_value=context(), middleware=[DjangoDebugMiddleware()] + ) assert not result.errors - assert result.data['allReporters'] == expected['allReporters'] - assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] + assert result.data["allReporters"] == expected["allReporters"] + assert "COUNT" in result.data["__debug"]["sql"][0]["rawSql"] query = str(Reporter.objects.all()[:1].query) - assert result.data['__debug']['sql'][1]['rawSql'] == query + assert result.data["__debug"]["sql"][1]["rawSql"] == query diff --git a/graphene_django/debug/types.py b/graphene_django/debug/types.py index 0d3701d..cda5725 100644 --- a/graphene_django/debug/types.py +++ b/graphene_django/debug/types.py @@ -4,4 +4,10 @@ from .sql.types import DjangoDebugSQL class DjangoDebug(ObjectType): - sql = List(DjangoDebugSQL) + class Meta: + description = "Debugging information for the current query." + + sql = List( + DjangoDebugSQL, + description="Executed SQL queries for this API query.", + ) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index aa7f124..1ecce45 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -13,7 +13,6 @@ from .utils import maybe_queryset class DjangoListField(Field): - def __init__(self, _type, *args, **kwargs): super(DjangoListField, self).__init__(List(_type), *args, **kwargs) @@ -30,25 +29,28 @@ class DjangoListField(Field): class DjangoConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): - self.on = kwargs.pop('on', False) + self.on = kwargs.pop("on", False) self.max_limit = kwargs.pop( - 'max_limit', - graphene_settings.RELAY_CONNECTION_MAX_LIMIT + "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT ) self.enforce_first_or_last = kwargs.pop( - 'enforce_first_or_last', - graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST + "enforce_first_or_last", + graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST, ) super(DjangoConnectionField, self).__init__(*args, **kwargs) @property def type(self): from .types import DjangoObjectType + _type = super(ConnectionField, self).type - assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) + assert issubclass( + _type, DjangoObjectType + ), "DjangoConnectionField only accepts DjangoObjectType types" + assert _type._meta.connection, "The type {} doesn't have a connection".format( + _type.__name__ + ) return _type._meta.connection @property @@ -67,6 +69,10 @@ class DjangoConnectionField(ConnectionField): @classmethod def merge_querysets(cls, default_queryset, queryset): + if default_queryset.query.distinct and not queryset.query.distinct: + queryset = queryset.distinct() + elif queryset.query.distinct and not default_queryset.query.distinct: + default_queryset = default_queryset.distinct() return queryset & default_queryset @classmethod @@ -96,28 +102,37 @@ class DjangoConnectionField(ConnectionField): return connection @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, info, **args): - first = args.get('first') - last = args.get('last') + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + root, + info, + **args + ): + first = args.get("first") + last = args.get("last") if enforce_first_or_last: assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + "You must provide a `first` or `last` value to properly paginate the `{}` connection." ).format(info.field_name) if max_limit: if first: assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records." ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) + args["first"] = min(first, max_limit) if last: assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['last'] = min(last, max_limit) + "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records." + ).format(last, info.field_name, max_limit) + args["last"] = min(last, max_limit) iterable = resolver(root, info, **args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args) @@ -134,5 +149,5 @@ class DjangoConnectionField(ConnectionField): self.type, self.get_manager(), self.max_limit, - self.enforce_first_or_last + self.enforce_first_or_last, ) diff --git a/graphene_django/filter/__init__.py b/graphene_django/filter/__init__.py index 24fae60..daafe56 100644 --- a/graphene_django/filter/__init__.py +++ b/graphene_django/filter/__init__.py @@ -4,11 +4,15 @@ from ..utils import DJANGO_FILTER_INSTALLED if not DJANGO_FILTER_INSTALLED: warnings.warn( "Use of django filtering requires the django-filter package " - "be installed. You can do so using `pip install django-filter`", ImportWarning + "be installed. You can do so using `pip install django-filter`", + ImportWarning, ) else: from .fields import DjangoFilterConnectionField from .filterset import GlobalIDFilter, GlobalIDMultipleChoiceFilter - __all__ = ['DjangoFilterConnectionField', - 'GlobalIDFilter', 'GlobalIDMultipleChoiceFilter'] + __all__ = [ + "DjangoFilterConnectionField", + "GlobalIDFilter", + "GlobalIDMultipleChoiceFilter", + ] diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index a80d8d7..cb42543 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -7,10 +7,16 @@ from .utils import get_filtering_args_from_filterset, get_filterset_class class DjangoFilterConnectionField(DjangoConnectionField): - - def __init__(self, type, fields=None, order_by=None, - extra_filter_meta=None, filterset_class=None, - *args, **kwargs): + def __init__( + self, + type, + fields=None, + order_by=None, + extra_filter_meta=None, + filterset_class=None, + *args, + **kwargs + ): self._fields = fields self._provided_filterset_class = filterset_class self._filterset_class = None @@ -30,12 +36,13 @@ class DjangoFilterConnectionField(DjangoConnectionField): def filterset_class(self): if not self._filterset_class: fields = self._fields or self.node_type._meta.filter_fields - meta = dict(model=self.model, - fields=fields) + meta = dict(model=self.model, fields=fields) if self._extra_filter_meta: meta.update(self._extra_filter_meta) - self._filterset_class = get_filterset_class(self._provided_filterset_class, **meta) + self._filterset_class = get_filterset_class( + self._provided_filterset_class, **meta + ) return self._filterset_class @@ -43,8 +50,8 @@ class DjangoFilterConnectionField(DjangoConnectionField): def filtering_args(self): return get_filtering_args_from_filterset(self.filterset_class, self.node_type) - @staticmethod - def merge_querysets(default_queryset, queryset): + @classmethod + def merge_querysets(cls, default_queryset, queryset): # There could be the case where the default queryset (returned from the filterclass) # and the resolver queryset have some limits on it. # We only would be able to apply one of those, but not both @@ -52,27 +59,40 @@ class DjangoFilterConnectionField(DjangoConnectionField): # See related PR: https://github.com/graphql-python/graphene-django/pull/126 - assert not (default_queryset.query.low_mark and queryset.query.low_mark), ( - 'Received two sliced querysets (low mark) in the connection, please slice only in one.' - ) - assert not (default_queryset.query.high_mark and queryset.query.high_mark), ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' - ) + assert not ( + default_queryset.query.low_mark and queryset.query.low_mark + ), "Received two sliced querysets (low mark) in the connection, please slice only in one." + assert not ( + default_queryset.query.high_mark and queryset.query.high_mark + ), "Received two sliced querysets (high mark) in the connection, please slice only in one." low = default_queryset.query.low_mark or queryset.query.low_mark high = default_queryset.query.high_mark or queryset.query.high_mark default_queryset.query.clear_limits() - queryset = default_queryset & queryset + queryset = super(DjangoFilterConnectionField, cls).merge_querysets( + default_queryset, queryset + ) queryset.query.set_limits(low, high) return queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, filterset_class, filtering_args, - root, info, **args): + def connection_resolver( + cls, + resolver, + connection, + default_manager, + max_limit, + enforce_first_or_last, + filterset_class, + filtering_args, + root, + info, + **args + ): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( data=filter_kwargs, - queryset=default_manager.get_queryset() + queryset=default_manager.get_queryset(), + request=info.context, ).qs return super(DjangoFilterConnectionField, cls).connection_resolver( @@ -95,5 +115,5 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.max_limit, self.enforce_first_or_last, self.filterset_class, - self.filtering_args + self.filtering_args, ) diff --git a/graphene_django/filter/filterset.py b/graphene_django/filter/filterset.py index c716b05..4059083 100644 --- a/graphene_django/filter/filterset.py +++ b/graphene_django/filter/filterset.py @@ -1,8 +1,7 @@ import itertools from django.db import models -from django.utils.text import capfirst -from django_filters import Filter, MultipleChoiceFilter +from django_filters import Filter, MultipleChoiceFilter, VERSION from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS @@ -15,7 +14,10 @@ class GlobalIDFilter(Filter): field_class = GlobalIDFormField def filter(self, qs, value): - _type, _id = from_global_id(value) + """ Convert the filter value to a primary key before filtering """ + _id = None + if value is not None: + _, _id = from_global_id(value) return super(GlobalIDFilter, self).filter(qs, _id) @@ -28,71 +30,76 @@ class GlobalIDMultipleChoiceFilter(MultipleChoiceFilter): GRAPHENE_FILTER_SET_OVERRIDES = { - models.AutoField: { - 'filter_class': GlobalIDFilter, - }, - models.OneToOneField: { - 'filter_class': GlobalIDFilter, - }, - models.ForeignKey: { - 'filter_class': GlobalIDFilter, - }, - models.ManyToManyField: { - 'filter_class': GlobalIDMultipleChoiceFilter, - } + models.AutoField: {"filter_class": GlobalIDFilter}, + models.OneToOneField: {"filter_class": GlobalIDFilter}, + models.ForeignKey: {"filter_class": GlobalIDFilter}, + models.ManyToManyField: {"filter_class": GlobalIDMultipleChoiceFilter}, + models.ManyToOneRel: {"filter_class": GlobalIDMultipleChoiceFilter}, + models.ManyToManyRel: {"filter_class": GlobalIDMultipleChoiceFilter}, } class GrapheneFilterSetMixin(BaseFilterSet): - FILTER_DEFAULTS = dict(itertools.chain( - FILTER_FOR_DBFIELD_DEFAULTS.items(), - GRAPHENE_FILTER_SET_OVERRIDES.items() - )) + """ A django_filters.filterset.BaseFilterSet with default filter overrides + to handle global IDs """ - @classmethod - def filter_for_reverse_field(cls, f, name): - """Handles retrieving filters for reverse relationships + FILTER_DEFAULTS = dict( + itertools.chain( + FILTER_FOR_DBFIELD_DEFAULTS.items(), + GRAPHENE_FILTER_SET_OVERRIDES.items() + ) + ) - We override the default implementation so that we can handle - Global IDs (the default implementation expects database - primary keys) - """ - rel = f.field.rel - default = { - 'name': name, - 'label': capfirst(rel.related_name) - } - if rel.multiple: - # For to-many relationships - return GlobalIDMultipleChoiceFilter(**default) - else: - # For to-one relationships - return GlobalIDFilter(**default) + +# To support a Django 1.11 + Python 2.7 combination django-filter must be +# < 2.x.x. To support the earlier version of django-filter, the +# filter_for_reverse_field method must be present on GrapheneFilterSetMixin and +# must not be present for later versions of django-filter. +if VERSION[0] < 2: + from django.utils.text import capfirst + + class GrapheneFilterSetMixinPython2(GrapheneFilterSetMixin): + + @classmethod + def filter_for_reverse_field(cls, f, name): + """Handles retrieving filters for reverse relationships + We override the default implementation so that we can handle + Global IDs (the default implementation expects database + primary keys) + """ + try: + rel = f.field.remote_field + except AttributeError: + rel = f.field.rel + default = {"name": name, "label": capfirst(rel.related_name)} + if rel.multiple: + # For to-many relationships + return GlobalIDMultipleChoiceFilter(**default) + else: + # For to-one relationships + return GlobalIDFilter(**default) + + GrapheneFilterSetMixin = GrapheneFilterSetMixinPython2 def setup_filterset(filterset_class): """ Wrap a provided filterset in Graphene-specific functionality """ return type( - 'Graphene{}'.format(filterset_class.__name__), + "Graphene{}".format(filterset_class.__name__), (filterset_class, GrapheneFilterSetMixin), {}, ) -def custom_filterset_factory(model, filterset_base_class=FilterSet, - **meta): +def custom_filterset_factory(model, filterset_base_class=FilterSet, **meta): """ Create a filterset for the given model using the provided meta data """ - meta.update({ - 'model': model, - }) - meta_class = type(str('Meta'), (object,), meta) + meta.update({"model": model}) + meta_class = type(str("Meta"), (object,), meta) filterset = type( - str('%sFilterSet' % model._meta.object_name), + str("%sFilterSet" % model._meta.object_name), (filterset_base_class, GrapheneFilterSetMixin), - { - 'Meta': meta_class - } + {"Meta": meta_class}, ) return filterset diff --git a/graphene_django/filter/tests/filters.py b/graphene_django/filter/tests/filters.py index 4a3fbaa..359d2ba 100644 --- a/graphene_django/filter/tests/filters.py +++ b/graphene_django/filter/tests/filters.py @@ -5,29 +5,26 @@ from graphene_django.tests.models import Article, Pet, Reporter class ArticleFilter(django_filters.FilterSet): - class Meta: model = Article fields = { - 'headline': ['exact', 'icontains'], - 'pub_date': ['gt', 'lt', 'exact'], - 'reporter': ['exact'], + "headline": ["exact", "icontains"], + "pub_date": ["gt", "lt", "exact"], + "reporter": ["exact"], } - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class ReporterFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['first_name', 'last_name', 'email', 'pets'] + fields = ["first_name", "last_name", "email", "pets"] - order_by = OrderingFilter(fields=('pub_date',)) + order_by = OrderingFilter(fields=("pub_date",)) class PetFilter(django_filters.FilterSet): - class Meta: model = Pet - fields = ['name'] + fields = ["name"] diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 9a0ba21..f9ef0ae 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -2,50 +2,60 @@ from datetime import datetime import pytest -from graphene import Field, ObjectType, Schema, Argument, Float +from graphene import Field, ObjectType, Schema, Argument, Float, Boolean, String from graphene.relay import Node from graphene_django import DjangoObjectType -from graphene_django.forms import (GlobalIDFormField, - GlobalIDMultipleChoiceField) +from graphene_django.forms import GlobalIDFormField, GlobalIDMultipleChoiceField from graphene_django.tests.models import Article, Pet, Reporter from graphene_django.utils import DJANGO_FILTER_INSTALLED +# for annotation test +from django.db.models import TextField, Value +from django.db.models.functions import Concat + pytestmark = [] if DJANGO_FILTER_INSTALLED: import django_filters from django_filters import FilterSet, NumberFilter - from graphene_django.filter import (GlobalIDFilter, DjangoFilterConnectionField, - GlobalIDMultipleChoiceFilter) - from graphene_django.filter.tests.filters import ArticleFilter, PetFilter, ReporterFilter + from graphene_django.filter import ( + GlobalIDFilter, + DjangoFilterConnectionField, + GlobalIDMultipleChoiceFilter, + ) + from graphene_django.filter.tests.filters import ( + ArticleFilter, + PetFilter, + ReporterFilter, + ) else: - pytestmark.append(pytest.mark.skipif(True, reason='django_filters not installed or not compatible')) + pytestmark.append( + pytest.mark.skipif( + True, reason="django_filters not installed or not compatible" + ) + ) pytestmark.append(pytest.mark.django_db) if DJANGO_FILTER_INSTALLED: - class ArticleNode(DjangoObjectType): + class ArticleNode(DjangoObjectType): class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('headline', ) - + interfaces = (Node,) + filter_fields = ("headline",) class ReporterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - + interfaces = (Node,) class PetNode(DjangoObjectType): - class Meta: model = Pet - interfaces = (Node, ) + interfaces = (Node,) # schema = Schema() @@ -55,58 +65,47 @@ def get_args(field): def assert_arguments(field, *arguments): - ignore = ('after', 'before', 'first', 'last', 'order_by') + ignore = ("after", "before", "first", "last", "order_by") args = get_args(field) - actual = [ - name - for name in args - if name not in ignore and not name.startswith('_') - ] - assert set(arguments) == set(actual), \ - 'Expected arguments ({}) did not match actual ({})'.format( - arguments, - actual - ) + actual = [name for name in args if name not in ignore and not name.startswith("_")] + assert set(arguments) == set( + actual + ), "Expected arguments ({}) did not match actual ({})".format(arguments, actual) def assert_orderable(field): args = get_args(field) - assert 'order_by' in args, \ - 'Field cannot be ordered' + assert "order_by" in args, "Field cannot be ordered" def assert_not_orderable(field): args = get_args(field) - assert 'order_by' not in args, \ - 'Field can be ordered' + assert "order_by" not in args, "Field can be ordered" def test_filter_explicit_filterset_arguments(): field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleFilter) - assert_arguments(field, - 'headline', 'headline__icontains', - 'pub_date', 'pub_date__gt', 'pub_date__lt', - 'reporter', - ) + assert_arguments( + field, + "headline", + "headline__icontains", + "pub_date", + "pub_date__gt", + "pub_date__lt", + "reporter", + ) def test_filter_shortcut_filterset_arguments_list(): - field = DjangoFilterConnectionField(ArticleNode, fields=['pub_date', 'reporter']) - assert_arguments(field, - 'pub_date', - 'reporter', - ) + field = DjangoFilterConnectionField(ArticleNode, fields=["pub_date", "reporter"]) + assert_arguments(field, "pub_date", "reporter") def test_filter_shortcut_filterset_arguments_dict(): - field = DjangoFilterConnectionField(ArticleNode, fields={ - 'headline': ['exact', 'icontains'], - 'reporter': ['exact'], - }) - assert_arguments(field, - 'headline', 'headline__icontains', - 'reporter', - ) + field = DjangoFilterConnectionField( + ArticleNode, fields={"headline": ["exact", "icontains"], "reporter": ["exact"]} + ) + assert_arguments(field, "headline", "headline__icontains", "reporter") def test_filter_explicit_filterset_orderable(): @@ -130,39 +129,91 @@ def test_filter_explicit_filterset_not_orderable(): def test_filter_shortcut_filterset_extra_meta(): - field = DjangoFilterConnectionField(ArticleNode, extra_filter_meta={ - 'exclude': ('headline', ) - }) - assert 'headline' not in field.filterset_class.get_fields() + field = DjangoFilterConnectionField( + ArticleNode, extra_filter_meta={"exclude": ("headline",)} + ) + assert "headline" not in field.filterset_class.get_fields() + + +def test_filter_shortcut_filterset_context(): + class ArticleContextFilter(django_filters.FilterSet): + class Meta: + model = Article + exclude = set() + + @property + def qs(self): + qs = super(ArticleContextFilter, self).qs + return qs.filter(reporter=self.request.reporter) + + class Query(ObjectType): + context_articles = DjangoFilterConnectionField( + ArticleNode, filterset_class=ArticleContextFilter + ) + + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + editor=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + editor=r2, + ) + + class context(object): + reporter = r2 + + query = """ + query { + contextArticles { + edges { + node { + headline + } + } + } + } + """ + schema = Schema(query=Query) + result = schema.execute(query, context_value=context()) + assert not result.errors + + assert len(result.data["contextArticles"]["edges"]) == 1 + assert result.data["contextArticles"]["edges"][0]["node"]["headline"] == "a2" def test_filter_filterset_information_on_meta(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] field = DjangoFilterConnectionField(ReporterFilterNode) - assert_arguments(field, 'first_name', 'articles') + assert_arguments(field, "first_name", "articles") assert_not_orderable(field) def test_filter_filterset_information_on_meta_related(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ['headline', 'reporter'] + interfaces = (Node,) + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -171,25 +222,23 @@ def test_filter_filterset_information_on_meta_related(): article = Field(ArticleFilterNode) schema = Schema(query=Query) - articles_field = ReporterFilterNode._meta.fields['articles'].get_type() - assert_arguments(articles_field, 'headline', 'reporter') + articles_field = ReporterFilterNode._meta.fields["articles"].get_type() + assert_arguments(articles_field, "headline", "reporter") assert_not_orderable(articles_field) def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = ['first_name', 'articles'] + interfaces = (Node,) + filter_fields = ["first_name", "articles"] class ArticleFilterNode(DjangoObjectType): - class Meta: - interfaces = (Node, ) + interfaces = (Node,) model = Article - filter_fields = ['headline', 'reporter'] + filter_fields = ["headline", "reporter"] class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) @@ -197,12 +246,22 @@ def test_filter_filterset_related_results(): reporter = Field(ReporterFilterNode) article = Field(ArticleFilterNode) - r1 = Reporter.objects.create(first_name='r1', last_name='r1', email='r1@test.com') - r2 = Reporter.objects.create(first_name='r2', last_name='r2', email='r2@test.com') - Article.objects.create(headline='a1', pub_date=datetime.now(), reporter=r1) - Article.objects.create(headline='a2', pub_date=datetime.now(), reporter=r2) + r1 = Reporter.objects.create(first_name="r1", last_name="r1", email="r1@test.com") + r2 = Reporter.objects.create(first_name="r2", last_name="r2", email="r2@test.com") + Article.objects.create( + headline="a1", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r1, + ) + Article.objects.create( + headline="a2", + pub_date=datetime.now(), + pub_date_time=datetime.now(), + reporter=r2, + ) - query = ''' + query = """ query { allReporters { edges { @@ -218,123 +277,134 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get back a single article for each reporter - assert len(result.data['allReporters']['edges'][0]['node']['articles']['edges']) == 1 - assert len(result.data['allReporters']['edges'][1]['node']['articles']['edges']) == 1 + assert ( + len(result.data["allReporters"]["edges"][0]["node"]["articles"]["edges"]) == 1 + ) + assert ( + len(result.data["allReporters"]["edges"][1]["node"]["articles"]["edges"]) == 1 + ) def test_global_id_field_implicit(): - field = DjangoFilterConnectionField(ArticleNode, fields=['id']) + field = DjangoFilterConnectionField(ArticleNode, fields=["id"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_field_explicit(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['id'] + id_filter = filterset_class.base_filters["id"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_filterset_descriptions(): class ArticleIdFilter(django_filters.FilterSet): - class Meta: model = Article - fields = ['id'] + fields = ["id"] - max_time = django_filters.NumberFilter(method='filter_max_time', label="The maximum time") + max_time = django_filters.NumberFilter( + method="filter_max_time", label="The maximum time" + ) field = DjangoFilterConnectionField(ArticleNode, filterset_class=ArticleIdFilter) - max_time = field.args['max_time'] + max_time = field.args["max_time"] assert isinstance(max_time, Argument) assert max_time.type == Float - assert max_time.description == 'The maximum time' + assert max_time.description == "The maximum time" def test_global_id_field_relation(): - field = DjangoFilterConnectionField(ArticleNode, fields=['reporter']) + field = DjangoFilterConnectionField(ArticleNode, fields=["reporter"]) filterset_class = field.filterset_class - id_filter = filterset_class.base_filters['reporter'] + id_filter = filterset_class.base_filters["reporter"] assert isinstance(id_filter, GlobalIDFilter) assert id_filter.field_class == GlobalIDFormField def test_global_id_multiple_field_implicit(): - field = DjangoFilterConnectionField(ReporterNode, fields=['pets']) + field = DjangoFilterConnectionField(ReporterNode, fields=["pets"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['pets'] + fields = ["pets"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['pets'] + multiple_filter = filterset_class.base_filters["pets"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_implicit_reverse(): - field = DjangoFilterConnectionField(ReporterNode, fields=['articles']) + field = DjangoFilterConnectionField(ReporterNode, fields=["articles"]) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_global_id_multiple_field_explicit_reverse(): class ReporterPetsFilter(django_filters.FilterSet): - class Meta: model = Reporter - fields = ['articles'] + fields = ["articles"] - field = DjangoFilterConnectionField(ReporterNode, filterset_class=ReporterPetsFilter) + field = DjangoFilterConnectionField( + ReporterNode, filterset_class=ReporterPetsFilter + ) filterset_class = field.filterset_class - multiple_filter = filterset_class.base_filters['articles'] + multiple_filter = filterset_class.base_filters["articles"] assert isinstance(multiple_filter, GlobalIDMultipleChoiceFilter) assert multiple_filter.field_class == GlobalIDMultipleChoiceField def test_filter_filterset_related_results(): class ReporterFilterNode(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - filter_fields = { - 'first_name': ['icontains'] - } + interfaces = (Node,) + filter_fields = {"first_name": ["icontains"]} class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - r1 = Reporter.objects.create(first_name='A test user', last_name='Last Name', email='test1@test.com') - r2 = Reporter.objects.create(first_name='Other test user', last_name='Other Last Name', email='test2@test.com') - r3 = Reporter.objects.create(first_name='Random', last_name='RandomLast', email='random@test.com') + r1 = Reporter.objects.create( + first_name="A test user", last_name="Last Name", email="test1@test.com" + ) + r2 = Reporter.objects.create( + first_name="Other test user", + last_name="Other Last Name", + email="test2@test.com", + ) + r3 = Reporter.objects.create( + first_name="Random", last_name="RandomLast", email="random@test.com" + ) - query = ''' + query = """ query { allReporters(firstName_Icontains: "test") { edges { @@ -344,12 +414,12 @@ def test_filter_filterset_related_results(): } } } - ''' + """ schema = Schema(query=Query) result = schema.execute(query) assert not result.errors # We should only get two reporters - assert len(result.data['allReporters']['edges']) == 2 + assert len(result.data["allReporters"]["edges"]) == 2 def test_recursive_filter_connection(): @@ -361,77 +431,73 @@ def test_recursive_filter_connection(): class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField(ReporterFilterNode) - assert ReporterFilterNode._meta.fields['child_reporters'].node_type == ReporterFilterNode + assert ( + ReporterFilterNode._meta.fields["child_reporters"].node_type + == ReporterFilterNode + ) def test_should_query_filter_node_limit(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class ArticleType(DjangoObjectType): - class Meta: model = Article - interfaces = (Node, ) - filter_fields = ('lang', ) + interfaces = (Node,) + filter_fields = ("lang",) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice') + return Reporter.objects.order_by("a_choice") Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) Article.objects.create( - headline='Article Node 1', + headline="Article Node 1", pub_date=datetime.now(), + pub_date_time=datetime.now(), reporter=r, editor=r, - lang='es' + lang="es", ) Article.objects.create( - headline='Article Node 2', + headline="Article Node 2", pub_date=datetime.now(), + pub_date_time=datetime.now(), reporter=r, editor=r, - lang='en' + lang="en", ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -450,24 +516,23 @@ def test_should_query_filter_node_limit(): } } } - ''' + """ expected = { - 'allReporters': { - 'edges': [{ - 'node': { - 'id': 'UmVwb3J0ZXJUeXBlOjI=', - 'firstName': 'John', - 'articles': { - 'edges': [{ - 'node': { - 'id': 'QXJ0aWNsZVR5cGU6MQ==', - 'lang': 'ES' - } - }] + "allReporters": { + "edges": [ + { + "node": { + "id": "UmVwb3J0ZXJUeXBlOjI=", + "firstName": "John", + "articles": { + "edges": [ + {"node": {"id": "QXJ0aWNsZVR5cGU6MQ==", "lang": "ES"}} + ] + }, } } - }] + ] } } @@ -478,45 +543,37 @@ def test_should_query_filter_node_limit(): def test_should_query_filter_node_double_limit_raises(): class ReporterFilter(FilterSet): - limit = NumberFilter(method='filter_limit') + limit = NumberFilter(method="filter_limit") def filter_limit(self, queryset, name, value): return queryset[:value] class Meta: model = Reporter - fields = ['first_name', ] + fields = ["first_name"] class ReporterType(DjangoObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) class Query(ObjectType): all_reporters = DjangoFilterConnectionField( - ReporterType, - filterset_class=ReporterFilter + ReporterType, filterset_class=ReporterFilter ) def resolve_all_reporters(self, info, **args): - return Reporter.objects.order_by('a_choice')[:2] + return Reporter.objects.order_by("a_choice")[:2] Reporter.objects.create( - first_name='Bob', - last_name='Doe', - email='bobdoe@example.com', - a_choice=2 + first_name="Bob", last_name="Doe", email="bobdoe@example.com", a_choice=2 ) r = Reporter.objects.create( - first_name='John', - last_name='Doe', - email='johndoe@example.com', - a_choice=1 + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 ) schema = Schema(query=Query) - query = ''' + query = """ query NodeFilteringQuery { allReporters(limit: 1) { edges { @@ -527,10 +584,116 @@ def test_should_query_filter_node_double_limit_raises(): } } } - ''' + """ result = schema.execute(query) assert len(result.errors) == 1 assert str(result.errors[0]) == ( - 'Received two sliced querysets (high mark) in the connection, please slice only in one.' + "Received two sliced querysets (high mark) in the connection, please slice only in one." ) + + +def test_order_by_is_perserved(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = () + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField( + ReporterType, reverse_order=Boolean() + ) + + def resolve_all_reporters(self, info, reverse_order=False, **args): + reporters = Reporter.objects.order_by("first_name") + + if reverse_order: + return reporters.reverse() + + return reporters + + Reporter.objects.create(first_name="b") + r = Reporter.objects.create(first_name="a") + + schema = Schema(query=Query) + query = """ + query NodeFilteringQuery { + allReporters(first: 1) { + edges { + node { + firstName + } + } + } + } + """ + expected = {"allReporters": {"edges": [{"node": {"firstName": "a"}}]}} + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + reverse_query = """ + query NodeFilteringQuery { + allReporters(first: 1, reverseOrder: true) { + edges { + node { + firstName + } + } + } + } + """ + + reverse_expected = {"allReporters": {"edges": [{"node": {"firstName": "b"}}]}} + + reverse_result = schema.execute(reverse_query) + + assert not reverse_result.errors + assert reverse_result.data == reverse_expected + + +def test_annotation_is_perserved(): + class ReporterType(DjangoObjectType): + full_name = String() + + def resolve_full_name(instance, info, **args): + return instance.full_name + + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = () + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterType) + + def resolve_all_reporters(self, info, **args): + return Reporter.objects.annotate( + full_name=Concat( + "first_name", Value(" "), "last_name", output_field=TextField() + ) + ) + + Reporter.objects.create(first_name="John", last_name="Doe") + + schema = Schema(query=Query) + + query = """ + query NodeFilteringQuery { + allReporters(first: 1) { + edges { + node { + fullName + } + } + } + } + """ + expected = {"allReporters": {"edges": [{"node": {"fullName": "John Doe"}}]}} + + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_django/filter/utils.py b/graphene_django/filter/utils.py index 6b938ce..cfa5621 100644 --- a/graphene_django/filter/utils.py +++ b/graphene_django/filter/utils.py @@ -8,7 +8,7 @@ def get_filtering_args_from_filterset(filterset_class, type): a Graphene Field. These arguments will be available to filter against in the GraphQL """ - from ..form_converter import convert_form_field + from ..forms.converter import convert_form_field args = {} for name, filter_field in six.iteritems(filterset_class.base_filters): diff --git a/graphene_django/forms/__init__.py b/graphene_django/forms/__init__.py new file mode 100644 index 0000000..066eec4 --- /dev/null +++ b/graphene_django/forms/__init__.py @@ -0,0 +1 @@ +from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField # noqa diff --git a/graphene_django/form_converter.py b/graphene_django/forms/converter.py similarity index 71% rename from graphene_django/form_converter.py rename to graphene_django/forms/converter.py index 46a38b3..87180b2 100644 --- a/graphene_django/form_converter.py +++ b/graphene_django/forms/converter.py @@ -1,30 +1,24 @@ from django import forms -from django.forms.fields import BaseTemporalField +from django.core.exceptions import ImproperlyConfigured -from graphene import ID, Boolean, Float, Int, List, String, UUID +from graphene import ID, Boolean, Float, Int, List, String, UUID, Date, DateTime, Time from .forms import GlobalIDFormField, GlobalIDMultipleChoiceField -from .utils import import_single_dispatch +from ..utils import import_single_dispatch + singledispatch = import_single_dispatch() -try: - UUIDField = forms.UUIDField -except AttributeError: - class UUIDField(object): - pass - @singledispatch def convert_form_field(field): - raise Exception( + raise ImproperlyConfigured( "Don't know how to convert the Django form field %s (%s) " - "to Graphene type" % - (field, field.__class__) + "to Graphene type" % (field, field.__class__) ) -@convert_form_field.register(BaseTemporalField) +@convert_form_field.register(forms.fields.BaseTemporalField) @convert_form_field.register(forms.CharField) @convert_form_field.register(forms.EmailField) @convert_form_field.register(forms.SlugField) @@ -36,7 +30,7 @@ def convert_form_field_to_string(field): return String(description=field.help_text, required=field.required) -@convert_form_field.register(UUIDField) +@convert_form_field.register(forms.UUIDField) def convert_form_field_to_uuid(field): return UUID(description=field.help_text, required=field.required) @@ -69,6 +63,21 @@ def convert_form_field_to_list(field): return List(ID, required=field.required) +@convert_form_field.register(forms.DateField) +def convert_form_field_to_date(field): + return Date(description=field.help_text, required=field.required) + + +@convert_form_field.register(forms.DateTimeField) +def convert_form_field_to_datetime(field): + return DateTime(description=field.help_text, required=field.required) + + +@convert_form_field.register(forms.TimeField) +def convert_form_field_to_time(field): + return Time(description=field.help_text, required=field.required) + + @convert_form_field.register(forms.ModelChoiceField) @convert_form_field.register(GlobalIDFormField) def convert_form_field_to_id(field): diff --git a/graphene_django/forms.py b/graphene_django/forms/forms.py similarity index 72% rename from graphene_django/forms.py rename to graphene_django/forms/forms.py index a54f0a5..14e68c8 100644 --- a/graphene_django/forms.py +++ b/graphene_django/forms/forms.py @@ -8,9 +8,7 @@ from graphql_relay import from_global_id class GlobalIDFormField(Field): - default_error_messages = { - 'invalid': _('Invalid ID specified.'), - } + default_error_messages = {"invalid": _("Invalid ID specified.")} def clean(self, value): if not value and not self.required: @@ -19,21 +17,21 @@ class GlobalIDFormField(Field): try: _type, _id = from_global_id(value) except (TypeError, ValueError, UnicodeDecodeError, binascii.Error): - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) try: CharField().clean(_id) CharField().clean(_type) except ValidationError: - raise ValidationError(self.error_messages['invalid']) + raise ValidationError(self.error_messages["invalid"]) return value class GlobalIDMultipleChoiceField(MultipleChoiceField): default_error_messages = { - 'invalid_choice': _('One of the specified IDs was invalid (%(value)s).'), - 'invalid_list': _('Enter a list of values.'), + "invalid_choice": _("One of the specified IDs was invalid (%(value)s)."), + "invalid_list": _("Enter a list of values."), } def valid_value(self, value): diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py new file mode 100644 index 0000000..63ea089 --- /dev/null +++ b/graphene_django/forms/mutation.py @@ -0,0 +1,192 @@ +# from django import forms +from collections import OrderedDict + +import graphene +from graphene import Field, InputField +from graphene.relay.mutation import ClientIDMutation +from graphene.types.mutation import MutationOptions + +# from graphene.types.inputobjecttype import ( +# InputObjectTypeOptions, +# InputObjectType, +# ) +from graphene.types.utils import yank_fields_from_attrs +from graphene_django.registry import get_global_registry + +from .converter import convert_form_field +from .types import ErrorType + + +def fields_for_form(form, only_fields, exclude_fields): + fields = OrderedDict() + for name, field in form.fields.items(): + is_not_in_only = only_fields and name not in only_fields + is_excluded = ( + name + in exclude_fields # or + # name in already_created_fields + ) + + if is_not_in_only or is_excluded: + continue + + fields[name] = convert_form_field(field) + return fields + + +class BaseDjangoFormMutation(ClientIDMutation): + class Meta: + abstract = True + + @classmethod + def mutate_and_get_payload(cls, root, info, **input): + form = cls.get_form(root, info, **input) + + if form.is_valid(): + return cls.perform_mutate(form, info) + else: + errors = [ + ErrorType(field=key, messages=value) + for key, value in form.errors.items() + ] + + return cls(errors=errors) + + @classmethod + def get_form(cls, root, info, **input): + form_kwargs = cls.get_form_kwargs(root, info, **input) + return cls._meta.form_class(**form_kwargs) + + @classmethod + def get_form_kwargs(cls, root, info, **input): + kwargs = {"data": input} + + pk = input.pop("id", None) + if pk: + instance = cls._meta.model._default_manager.get(pk=pk) + kwargs["instance"] = instance + + return kwargs + + +# class DjangoFormInputObjectTypeOptions(InputObjectTypeOptions): +# form_class = None + + +# class DjangoFormInputObjectType(InputObjectType): +# class Meta: +# abstract = True + +# @classmethod +# def __init_subclass_with_meta__(cls, form_class=None, +# only_fields=(), exclude_fields=(), _meta=None, **options): +# if not _meta: +# _meta = DjangoFormInputObjectTypeOptions(cls) +# assert isinstance(form_class, forms.Form), ( +# 'form_class must be an instance of django.forms.Form' +# ) +# _meta.form_class = form_class +# form = form_class() +# fields = fields_for_form(form, only_fields, exclude_fields) +# super(DjangoFormInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, fields=fields, **options) + + +class DjangoFormMutationOptions(MutationOptions): + form_class = None + + +class DjangoFormMutation(BaseDjangoFormMutation): + class Meta: + abstract = True + + errors = graphene.List(ErrorType) + + @classmethod + def __init_subclass_with_meta__( + cls, form_class=None, only_fields=(), exclude_fields=(), **options + ): + + if not form_class: + raise Exception("form_class is required for DjangoFormMutation") + + form = form_class() + input_fields = fields_for_form(form, only_fields, exclude_fields) + output_fields = fields_for_form(form, only_fields, exclude_fields) + + _meta = DjangoFormMutationOptions(cls) + _meta.form_class = form_class + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options + ) + + @classmethod + def perform_mutate(cls, form, info): + form.save() + return cls(errors=[]) + + +class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions): + model = None + return_field_name = None + + +class DjangoModelFormMutation(BaseDjangoFormMutation): + class Meta: + abstract = True + + errors = graphene.List(ErrorType) + + @classmethod + def __init_subclass_with_meta__( + cls, + form_class=None, + model=None, + return_field_name=None, + only_fields=(), + exclude_fields=(), + **options + ): + + if not form_class: + raise Exception("form_class is required for DjangoModelFormMutation") + + if not model: + model = form_class._meta.model + + if not model: + raise Exception("model is required for DjangoModelFormMutation") + + form = form_class() + input_fields = fields_for_form(form, only_fields, exclude_fields) + if "id" not in exclude_fields: + input_fields["id"] = graphene.ID() + + registry = get_global_registry() + model_type = registry.get_type_for_model(model) + return_field_name = return_field_name + if not return_field_name: + model_name = model.__name__ + return_field_name = model_name[:1].lower() + model_name[1:] + + output_fields = OrderedDict() + output_fields[return_field_name] = graphene.Field(model_type) + + _meta = DjangoModelDjangoFormMutationOptions(cls) + _meta.form_class = form_class + _meta.model = model + _meta.return_field_name = return_field_name + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(DjangoModelFormMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options + ) + + @classmethod + def perform_mutate(cls, form, info): + obj = form.save() + kwargs = {cls._meta.return_field_name: obj} + return cls(errors=[], **kwargs) diff --git a/graphene_django/forms/tests/__init__.py b/graphene_django/forms/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphene_django/forms/tests/test_converter.py b/graphene_django/forms/tests/test_converter.py new file mode 100644 index 0000000..955b952 --- /dev/null +++ b/graphene_django/forms/tests/test_converter.py @@ -0,0 +1,114 @@ +from django import forms +from py.test import raises + +import graphene +from graphene import ( + String, + Int, + Boolean, + Float, + ID, + UUID, + List, + NonNull, + DateTime, + Date, + Time, +) + +from ..converter import convert_form_field + + +def assert_conversion(django_field, graphene_field, *args): + field = django_field(*args, help_text="Custom Help Text") + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, graphene_field) + field = graphene_type.Field() + assert field.description == "Custom Help Text" + return field + + +def test_should_unknown_django_field_raise_exception(): + with raises(Exception) as excinfo: + convert_form_field(None) + assert "Don't know how to convert the Django form field" in str(excinfo.value) + + +def test_should_date_convert_date(): + assert_conversion(forms.DateField, Date) + + +def test_should_time_convert_time(): + assert_conversion(forms.TimeField, Time) + + +def test_should_date_time_convert_date_time(): + assert_conversion(forms.DateTimeField, DateTime) + + +def test_should_char_convert_string(): + assert_conversion(forms.CharField, String) + + +def test_should_email_convert_string(): + assert_conversion(forms.EmailField, String) + + +def test_should_slug_convert_string(): + assert_conversion(forms.SlugField, String) + + +def test_should_url_convert_string(): + assert_conversion(forms.URLField, String) + + +def test_should_choice_convert_string(): + assert_conversion(forms.ChoiceField, String) + + +def test_should_base_field_convert_string(): + assert_conversion(forms.Field, String) + + +def test_should_regex_convert_string(): + assert_conversion(forms.RegexField, String, "[0-9]+") + + +def test_should_uuid_convert_string(): + if hasattr(forms, "UUIDField"): + assert_conversion(forms.UUIDField, UUID) + + +def test_should_integer_convert_int(): + assert_conversion(forms.IntegerField, Int) + + +def test_should_boolean_convert_boolean(): + field = assert_conversion(forms.BooleanField, Boolean) + assert isinstance(field.type, NonNull) + + +def test_should_nullboolean_convert_boolean(): + field = assert_conversion(forms.NullBooleanField, Boolean) + assert not isinstance(field.type, NonNull) + + +def test_should_float_convert_float(): + assert_conversion(forms.FloatField, Float) + + +def test_should_decimal_convert_float(): + assert_conversion(forms.DecimalField, Float) + + +def test_should_multiple_choice_convert_connectionorlist(): + field = forms.ModelMultipleChoiceField(queryset=None) + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, List) + assert graphene_type.of_type == ID + + +def test_should_manytoone_convert_connectionorlist(): + field = forms.ModelChoiceField(queryset=None) + graphene_type = convert_form_field(field) + assert isinstance(graphene_type, ID) diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py new file mode 100644 index 0000000..df0ffd5 --- /dev/null +++ b/graphene_django/forms/tests/test_mutation.py @@ -0,0 +1,141 @@ +from django import forms +from django.test import TestCase +from py.test import raises + +from graphene_django.tests.models import Pet, Film, FilmDetails +from ..mutation import DjangoFormMutation, DjangoModelFormMutation + + +class MyForm(forms.Form): + text = forms.CharField() + + +class PetForm(forms.ModelForm): + class Meta: + model = Pet + fields = '__all__' + + +def test_needs_form_class(): + with raises(Exception) as exc: + + class MyMutation(DjangoFormMutation): + pass + + assert exc.value.args[0] == "form_class is required for DjangoFormMutation" + + +def test_has_output_fields(): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + + assert "errors" in MyMutation._meta.fields + + +def test_has_input_fields(): + class MyMutation(DjangoFormMutation): + class Meta: + form_class = MyForm + + assert "text" in MyMutation.Input._meta.fields + + +class ModelFormMutationTests(TestCase): + def test_default_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("pet", PetMutation._meta.fields) + + def test_default_input_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("name", PetMutation.Input._meta.fields) + self.assertIn("client_mutation_id", PetMutation.Input._meta.fields) + self.assertIn("id", PetMutation.Input._meta.fields) + + def test_exclude_fields_input_meta_fields(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + exclude_fields = ['id'] + + self.assertEqual(PetMutation._meta.model, Pet) + self.assertEqual(PetMutation._meta.return_field_name, "pet") + self.assertIn("name", PetMutation.Input._meta.fields) + self.assertIn("age", PetMutation.Input._meta.fields) + self.assertIn("client_mutation_id", PetMutation.Input._meta.fields) + self.assertNotIn("id", PetMutation.Input._meta.fields) + + def test_return_field_name_is_camelcased(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + model = FilmDetails + + self.assertEqual(PetMutation._meta.model, FilmDetails) + self.assertEqual(PetMutation._meta.return_field_name, "filmDetails") + + def test_custom_return_field_name(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + model = Film + return_field_name = "animal" + + self.assertEqual(PetMutation._meta.model, Film) + self.assertEqual(PetMutation._meta.return_field_name, "animal") + self.assertIn("animal", PetMutation._meta.fields) + + def test_model_form_mutation_mutate(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + pet = Pet.objects.create(name="Axel", age=10) + + result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name="Mia", age=10) + + self.assertEqual(Pet.objects.count(), 1) + pet.refresh_from_db() + self.assertEqual(pet.name, "Mia") + self.assertEqual(result.errors, []) + + def test_model_form_mutation_updates_existing_(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + result = PetMutation.mutate_and_get_payload(None, None, name="Mia", age=10) + + self.assertEqual(Pet.objects.count(), 1) + pet = Pet.objects.get() + self.assertEqual(pet.name, "Mia") + self.assertEqual(pet.age, 10) + self.assertEqual(result.errors, []) + + def test_model_form_mutation_mutate_invalid_form(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + result = PetMutation.mutate_and_get_payload(None, None) + + # A pet was not created + self.assertEqual(Pet.objects.count(), 0) + + + fields_w_error = [e.field for e in result.errors] + self.assertEqual(len(result.errors), 2) + self.assertIn("name", fields_w_error) + self.assertEqual(result.errors[0].messages, ["This field is required."]) + self.assertIn("age", fields_w_error) + self.assertEqual(result.errors[1].messages, ["This field is required."]) diff --git a/graphene_django/forms/types.py b/graphene_django/forms/types.py new file mode 100644 index 0000000..1fe33f3 --- /dev/null +++ b/graphene_django/forms/types.py @@ -0,0 +1,6 @@ +import graphene + + +class ErrorType(graphene.ObjectType): + field = graphene.String() + messages = graphene.List(graphene.String) diff --git a/graphene_django/management/commands/graphql_schema.py b/graphene_django/management/commands/graphql_schema.py index 7e2dbac..4e526ec 100644 --- a/graphene_django/management/commands/graphql_schema.py +++ b/graphene_django/management/commands/graphql_schema.py @@ -1,79 +1,51 @@ import importlib import json -from distutils.version import StrictVersion -from optparse import make_option -from django import get_version as get_django_version from django.core.management.base import BaseCommand, CommandError from graphene_django.settings import graphene_settings -LT_DJANGO_1_8 = StrictVersion(get_django_version()) < StrictVersion('1.8') -if LT_DJANGO_1_8: - class CommandArguments(BaseCommand): - option_list = BaseCommand.option_list + ( - make_option( - '--schema', - type=str, - dest='schema', - default='', - help='Django app containing schema to dump, e.g. myproject.core.schema.schema', - ), - make_option( - '--out', - type=str, - dest='out', - default='', - help='Output file (default: schema.json)' - ), - make_option( - '--indent', - type=int, - dest='indent', - default=None, - help='Output file indent (default: None)' - ), +class CommandArguments(BaseCommand): + def add_arguments(self, parser): + parser.add_argument( + "--schema", + type=str, + dest="schema", + default=graphene_settings.SCHEMA, + help="Django app containing schema to dump, e.g. myproject.core.schema.schema", ) -else: - class CommandArguments(BaseCommand): - def add_arguments(self, parser): - parser.add_argument( - '--schema', - type=str, - dest='schema', - default=graphene_settings.SCHEMA, - help='Django app containing schema to dump, e.g. myproject.core.schema.schema') + parser.add_argument( + "--out", + type=str, + dest="out", + default=graphene_settings.SCHEMA_OUTPUT, + help="Output file, --out=- prints to stdout (default: schema.json)", + ) - parser.add_argument( - '--out', - type=str, - dest='out', - default=graphene_settings.SCHEMA_OUTPUT, - help='Output file (default: schema.json)') - - parser.add_argument( - '--indent', - type=int, - dest='indent', - default=graphene_settings.SCHEMA_INDENT, - help='Output file indent (default: None)') + parser.add_argument( + "--indent", + type=int, + dest="indent", + default=graphene_settings.SCHEMA_INDENT, + help="Output file indent (default: None)", + ) class Command(CommandArguments): - help = 'Dump Graphene schema JSON to file' + help = "Dump Graphene schema JSON to file" can_import_settings = True def save_file(self, out, schema_dict, indent): - with open(out, 'w') as outfile: + with open(out, "w") as outfile: json.dump(schema_dict, outfile, indent=indent) def handle(self, *args, **options): - options_schema = options.get('schema') + options_schema = options.get("schema") if options_schema and type(options_schema) is str: - module_str, schema_name = options_schema.rsplit('.', 1) + module_str, schema_name = options_schema.rsplit(".", 1) mod = importlib.import_module(module_str) schema = getattr(mod, schema_name) @@ -83,16 +55,21 @@ class Command(CommandArguments): else: schema = graphene_settings.SCHEMA - out = options.get('out') or graphene_settings.SCHEMA_OUTPUT + out = options.get("out") or graphene_settings.SCHEMA_OUTPUT if not schema: - raise CommandError('Specify schema on GRAPHENE.SCHEMA setting or by using --schema') + raise CommandError( + "Specify schema on GRAPHENE.SCHEMA setting or by using --schema" + ) - indent = options.get('indent') - schema_dict = {'data': schema.introspect()} - self.save_file(out, schema_dict, indent) + indent = options.get("indent") + schema_dict = {"data": schema.introspect()} + if out == '-': + self.stdout.write(json.dumps(schema_dict, indent=indent)) + else: + self.save_file(out, schema_dict, indent) - style = getattr(self, 'style', None) - success = getattr(style, 'SUCCESS', lambda x: x) + style = getattr(self, "style", None) + success = getattr(style, "SUCCESS", lambda x: x) - self.stdout.write(success('Successfully dumped GraphQL schema to %s' % out)) + self.stdout.write(success("Successfully dumped GraphQL schema to %s" % out)) diff --git a/graphene_django/registry.py b/graphene_django/registry.py index 4e681cc..50a8ae5 100644 --- a/graphene_django/registry.py +++ b/graphene_django/registry.py @@ -1,25 +1,32 @@ - class Registry(object): - def __init__(self): self._registry = {} - self._registry_models = {} + self._field_registry = {} def register(self, cls): from .types import DjangoObjectType + assert issubclass( - cls, DjangoObjectType), 'Only DjangoObjectTypes can be registered, received "{}"'.format( - cls.__name__) - assert cls._meta.registry == self, 'Registry for a Model have to match.' + cls, DjangoObjectType + ), 'Only DjangoObjectTypes can be registered, received "{}"'.format( + cls.__name__ + ) + assert cls._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) == cls, ( # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # ) - if not getattr(cls._meta, 'skip_registry', False): + if not getattr(cls._meta, "skip_registry", False): self._registry[cls._meta.model] = cls def get_type_for_model(self, model): return self._registry.get(model) + def register_converted_field(self, field, converted): + self._field_registry[field] = converted + + def get_converted_field(self, field): + return self._field_registry.get(field) + registry = None diff --git a/graphene_django/rest_framework/models.py b/graphene_django/rest_framework/models.py new file mode 100644 index 0000000..848837b --- /dev/null +++ b/graphene_django/rest_framework/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class MyFakeModel(models.Model): + cool_name = models.CharField(max_length=50) + created = models.DateTimeField(auto_now_add=True) diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index 94d1e4b..5e343aa 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,20 +1,21 @@ from collections import OrderedDict +from django.shortcuts import get_object_or_404 + import graphene from graphene.types import Field, InputField from graphene.types.mutation import MutationOptions from graphene.relay.mutation import ClientIDMutation -from graphene.types.objecttype import ( - yank_fields_from_attrs -) +from graphene.types.objecttype import yank_fields_from_attrs -from .serializer_converter import ( - convert_serializer_field -) +from .serializer_converter import convert_serializer_field from .types import ErrorType class SerializerMutationOptions(MutationOptions): + lookup_field = None + model_class = None + model_operations = ["create", "update"] serializer_class = None @@ -23,7 +24,8 @@ def fields_for_serializer(serializer, only_fields, exclude_fields, is_input=Fals for name, field in serializer.fields.items(): is_not_in_only = only_fields and name not in only_fields is_excluded = ( - name in exclude_fields # or + name + in exclude_fields # or # name in already_created_fields ) @@ -39,37 +41,86 @@ class SerializerMutation(ClientIDMutation): abstract = True errors = graphene.List( - ErrorType, - description='May contain more than one error for same field.' + ErrorType, description="May contain more than one error for same field." ) @classmethod - def __init_subclass_with_meta__(cls, serializer_class=None, - only_fields=(), exclude_fields=(), **options): + def __init_subclass_with_meta__( + cls, + lookup_field=None, + serializer_class=None, + model_class=None, + model_operations=["create", "update"], + only_fields=(), + exclude_fields=(), + **options + ): if not serializer_class: - raise Exception('serializer_class is required for the SerializerMutation') + raise Exception("serializer_class is required for the SerializerMutation") + + if "update" not in model_operations and "create" not in model_operations: + raise Exception('model_operations must contain "create" and/or "update"') serializer = serializer_class() - input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) - output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) + if model_class is None: + serializer_meta = getattr(serializer_class, "Meta", None) + if serializer_meta: + model_class = getattr(serializer_meta, "model", None) + + if lookup_field is None and model_class: + lookup_field = model_class._meta.pk.name + + input_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=True + ) + output_fields = fields_for_serializer( + serializer, only_fields, exclude_fields, is_input=False + ) _meta = SerializerMutationOptions(cls) + _meta.lookup_field = lookup_field + _meta.model_operations = model_operations _meta.serializer_class = serializer_class - _meta.fields = yank_fields_from_attrs( - output_fields, - _as=Field, + _meta.model_class = model_class + _meta.fields = yank_fields_from_attrs(output_fields, _as=Field) + + input_fields = yank_fields_from_attrs(input_fields, _as=InputField) + super(SerializerMutation, cls).__init_subclass_with_meta__( + _meta=_meta, input_fields=input_fields, **options ) - input_fields = yank_fields_from_attrs( - input_fields, - _as=InputField, - ) - super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + lookup_field = cls._meta.lookup_field + model_class = cls._meta.model_class + + if model_class: + if "update" in cls._meta.model_operations and lookup_field in input: + instance = get_object_or_404( + model_class, **{lookup_field: input[lookup_field]} + ) + elif "create" in cls._meta.model_operations: + instance = None + else: + raise Exception( + 'Invalid update operation. Input parameter "{}" required.'.format( + lookup_field + ) + ) + + return { + "instance": instance, + "data": input, + "context": {"request": info.context}, + } + + return {"data": input, "context": {"request": info.context}} @classmethod def mutate_and_get_payload(cls, root, info, **input): - serializer = cls._meta.serializer_class(data=input) + kwargs = cls.get_serializer_kwargs(root, info, **input) + serializer = cls._meta.serializer_class(**kwargs) if serializer.is_valid(): return cls.perform_mutate(serializer, info) @@ -84,4 +135,9 @@ class SerializerMutation(ClientIDMutation): @classmethod def perform_mutate(cls, serializer, info): obj = serializer.save() - return cls(errors=None, **obj) + + kwargs = {} + for f, field in serializer.fields.items(): + kwargs[f] = field.get_attribute(obj) + + return cls(errors=None, **kwargs) diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py index 6a57f5f..9f8e516 100644 --- a/graphene_django/rest_framework/serializer_converter.py +++ b/graphene_django/rest_framework/serializer_converter.py @@ -28,15 +28,12 @@ def convert_serializer_field(field, is_input=True): graphql_type = get_graphene_type_from_serializer_field(field) args = [] - kwargs = { - 'description': field.help_text, - 'required': is_input and field.required, - } + kwargs = {"description": field.help_text, "required": is_input and field.required} # if it is a tuple or a list it means that we are returning # the graphql type and the child type if isinstance(graphql_type, (list, tuple)): - kwargs['of_type'] = graphql_type[1] + kwargs["of_type"] = graphql_type[1] graphql_type = graphql_type[0] if isinstance(field, serializers.ModelSerializer): @@ -46,6 +43,15 @@ def convert_serializer_field(field, is_input=True): global_registry = get_global_registry() field_model = field.Meta.model args = [global_registry.get_type_for_model(field_model)] + elif isinstance(field, serializers.ListSerializer): + field = field.child + if is_input: + kwargs["of_type"] = convert_serializer_to_input_type(field.__class__) + else: + del kwargs["of_type"] + global_registry = get_global_registry() + field_model = field.Meta.model + args = [global_registry.get_type_for_model(field_model)] return graphql_type(*args, **kwargs) @@ -59,9 +65,9 @@ def convert_serializer_to_input_type(serializer_class): } return type( - '{}Input'.format(serializer.__class__.__name__), + "{}Input".format(serializer.__class__.__name__), (graphene.InputObjectType,), - items + items, ) @@ -75,6 +81,12 @@ def convert_serializer_to_field(field): return graphene.Field +@get_graphene_type_from_serializer_field.register(serializers.ListSerializer) +def convert_list_serializer_to_field(field): + child_type = get_graphene_type_from_serializer_field(field.child) + return (graphene.List, child_type) + + @get_graphene_type_from_serializer_field.register(serializers.IntegerField) def convert_serializer_field_to_int(field): return graphene.Int @@ -92,9 +104,13 @@ def convert_serializer_field_to_float(field): @get_graphene_type_from_serializer_field.register(serializers.DateTimeField) +def convert_serializer_field_to_datetime_time(field): + return graphene.types.datetime.DateTime + + @get_graphene_type_from_serializer_field.register(serializers.DateField) def convert_serializer_field_to_date_time(field): - return graphene.types.datetime.DateTime + return graphene.types.datetime.Date @get_graphene_type_from_serializer_field.register(serializers.TimeField) diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py index 623cf58..6fa4ca8 100644 --- a/graphene_django/rest_framework/tests/test_field_converter.py +++ b/graphene_django/rest_framework/tests/test_field_converter.py @@ -1,8 +1,10 @@ import copy -from rest_framework import serializers -from py.test import raises import graphene +from django.db import models +from graphene import InputObjectType +from py.test import raises +from rest_framework import serializers from ..serializer_converter import convert_serializer_field from ..types import DictType @@ -14,8 +16,8 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): # Remove `source=` from the field declaration. # since we are reusing the same child in when testing the required attribute - if 'child' in kwargs: - kwargs['child'] = copy.deepcopy(kwargs['child']) + if "child" in kwargs: + kwargs["child"] = copy.deepcopy(kwargs["child"]) field = rest_framework_field(**kwargs) @@ -23,11 +25,13 @@ def _get_type(rest_framework_field, is_input=True, **kwargs): def assert_conversion(rest_framework_field, graphene_field, **kwargs): - graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) + graphene_type = _get_type( + rest_framework_field, help_text="Custom Help Text", **kwargs + ) assert isinstance(graphene_type, graphene_field) graphene_type_required = _get_type( - rest_framework_field, help_text='Custom Help Text', required=True, **kwargs + rest_framework_field, help_text="Custom Help Text", required=True, **kwargs ) assert isinstance(graphene_type_required, graphene_field) @@ -37,7 +41,7 @@ def assert_conversion(rest_framework_field, graphene_field, **kwargs): def test_should_unknown_rest_framework_field_raise_exception(): with raises(Exception) as excinfo: convert_serializer_field(None) - assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) + assert "Don't know how to convert the serializer field" in str(excinfo.value) def test_should_char_convert_string(): @@ -65,20 +69,19 @@ def test_should_base_field_convert_string(): def test_should_regex_convert_string(): - assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') + assert_conversion(serializers.RegexField, graphene.String, regex="[0-9]+") def test_should_uuid_convert_string(): - if hasattr(serializers, 'UUIDField'): + if hasattr(serializers, "UUIDField"): assert_conversion(serializers.UUIDField, graphene.String) def test_should_model_convert_field(): - class MyModelSerializer(serializers.ModelSerializer): class Meta: model = None - fields = '__all__' + fields = "__all__" assert_conversion(MyModelSerializer, graphene.Field, is_input=False) @@ -87,8 +90,8 @@ def test_should_date_time_convert_datetime(): assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) -def test_should_date_convert_datetime(): - assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) +def test_should_date_convert_date(): + assert_conversion(serializers.DateField, graphene.types.datetime.Date) def test_should_time_convert_time(): @@ -108,7 +111,9 @@ def test_should_float_convert_float(): def test_should_decimal_convert_float(): - assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) + assert_conversion( + serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2 + ) def test_should_list_convert_to_list(): @@ -118,7 +123,7 @@ def test_should_list_convert_to_list(): field_a = assert_conversion( serializers.ListField, graphene.List, - child=serializers.IntegerField(min_value=0, max_value=100) + child=serializers.IntegerField(min_value=0, max_value=100), ) assert field_a.of_type == graphene.Int @@ -128,6 +133,34 @@ def test_should_list_convert_to_list(): assert field_b.of_type == graphene.String +def test_should_list_serializer_convert_to_list(): + class FooModel(models.Model): + pass + + class ChildSerializer(serializers.ModelSerializer): + class Meta: + model = FooModel + fields = "__all__" + + class ParentSerializer(serializers.ModelSerializer): + child = ChildSerializer(many=True) + + class Meta: + model = FooModel + fields = "__all__" + + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=True + ) + assert isinstance(converted_type, graphene.List) + + converted_type = convert_serializer_field( + ParentSerializer().get_fields()["child"], is_input=False + ) + assert isinstance(converted_type, graphene.List) + assert converted_type.of_type is None + + def test_should_dict_convert_dict(): assert_conversion(serializers.DictField, DictType) @@ -141,7 +174,7 @@ def test_should_file_convert_string(): def test_should_filepath_convert_string(): - assert_conversion(serializers.FilePathField, graphene.String, path='/') + assert_conversion(serializers.FilePathField, graphene.String, path="/") def test_should_ip_convert_string(): @@ -157,6 +190,8 @@ def test_should_json_convert_jsonstring(): def test_should_multiplechoicefield_convert_to_list_of_string(): - field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3]) + field = assert_conversion( + serializers.MultipleChoiceField, graphene.List, choices=[1, 2, 3] + ) assert field.of_type == graphene.String diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 852265d..4dccc18 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -1,21 +1,40 @@ -from django.db import models -from graphene import Field +import datetime + +from graphene import Field, ResolveInfo from graphene.types.inputobjecttype import InputObjectType from py.test import raises +from py.test import mark from rest_framework import serializers from ...types import DjangoObjectType +from ..models import MyFakeModel from ..mutation import SerializerMutation -class MyFakeModel(models.Model): - cool_name = models.CharField(max_length=50) +def mock_info(): + return ResolveInfo( + None, + None, + None, + None, + schema=None, + fragments=None, + root_value=None, + operation=None, + variable_values=None, + context=None, + ) class MyModelSerializer(serializers.ModelSerializer): class Meta: model = MyFakeModel - fields = '__all__' + fields = "__all__" + + +class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer class MySerializer(serializers.Serializer): @@ -28,10 +47,11 @@ class MySerializer(serializers.Serializer): def test_needs_serializer_class(): with raises(Exception) as exc: + class MyMutation(SerializerMutation): pass - assert str(exc.value) == 'serializer_class is required for the SerializerMutation' + assert str(exc.value) == "serializer_class is required for the SerializerMutation" def test_has_fields(): @@ -39,9 +59,9 @@ def test_has_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation._meta.fields - assert 'model' in MyMutation._meta.fields - assert 'errors' in MyMutation._meta.fields + assert "text" in MyMutation._meta.fields + assert "model" in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields def test_has_input_fields(): @@ -49,12 +69,24 @@ def test_has_input_fields(): class Meta: serializer_class = MySerializer - assert 'text' in MyMutation.Input._meta.fields - assert 'model' in MyMutation.Input._meta.fields + assert "text" in MyMutation.Input._meta.fields + assert "model" in MyMutation.Input._meta.fields + + +def test_exclude_fields(): + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + exclude_fields = ["created"] + + assert "cool_name" in MyMutation._meta.fields + assert "created" not in MyMutation._meta.fields + assert "errors" in MyMutation._meta.fields + assert "cool_name" in MyMutation.Input._meta.fields + assert "created" not in MyMutation.Input._meta.fields def test_nested_model(): - class MyFakeModelGrapheneType(DjangoObjectType): class Meta: model = MyFakeModel @@ -63,37 +95,85 @@ def test_nested_model(): class Meta: serializer_class = MySerializer - model_field = MyMutation._meta.fields['model'] + model_field = MyMutation._meta.fields["model"] assert isinstance(model_field, Field) assert model_field.type == MyFakeModelGrapheneType - model_input = MyMutation.Input._meta.fields['model'] + model_input = MyMutation.Input._meta.fields["model"] model_input_type = model_input._type.of_type assert issubclass(model_input_type, InputObjectType) - assert 'cool_name' in model_input_type._meta.fields + assert "cool_name" in model_input_type._meta.fields + assert "created" in model_input_type._meta.fields def test_mutate_and_get_payload_success(): - class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer - result = MyMutation.mutate_and_get_payload(None, None, **{ - 'text': 'value', - 'model': { - 'cool_name': 'other_value' - } - }) + result = MyMutation.mutate_and_get_payload( + None, mock_info(), **{"text": "value", "model": {"cool_name": "other_value"}} + ) assert result.errors is None -def test_mutate_and_get_payload_error(): +@mark.django_db +def test_model_add_mutate_and_get_payload_success(): + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) + assert result.errors is None + assert result.cool_name == "Narf" + assert isinstance(result.created, datetime.datetime) + +@mark.django_db +def test_model_update_mutate_and_get_payload_success(): + instance = MyFakeModel.objects.create(cool_name="Narf") + result = MyModelMutation.mutate_and_get_payload( + None, mock_info(), **{"id": instance.id, "cool_name": "New Narf"} + ) + assert result.errors is None + assert result.cool_name == "New Narf" + + +@mark.django_db +def test_model_invalid_update_mutate_and_get_payload_success(): + class InvalidModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ["update"] + + with raises(Exception) as exc: + result = InvalidModelMutation.mutate_and_get_payload( + None, mock_info(), **{"cool_name": "Narf"} + ) + + assert '"id" required' in str(exc.value) + + +def test_mutate_and_get_payload_error(): class MyMutation(SerializerMutation): class Meta: serializer_class = MySerializer # missing required fields - result = MyMutation.mutate_and_get_payload(None, None, **{}) - assert len(result.errors) > 0 \ No newline at end of file + result = MyMutation.mutate_and_get_payload(None, mock_info(), **{}) + assert len(result.errors) > 0 + + +def test_model_mutate_and_get_payload_error(): + # missing required fields + result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{}) + assert len(result.errors) > 0 + + +def test_invalid_serializer_operations(): + with raises(Exception) as exc: + + class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ["Add"] + + assert "model_operations" in str(exc.value) diff --git a/graphene_django/rest_framework/types.py b/graphene_django/rest_framework/types.py index 956dc43..4c84c69 100644 --- a/graphene_django/rest_framework/types.py +++ b/graphene_django/rest_framework/types.py @@ -3,8 +3,8 @@ from graphene.types.unmountedtype import UnmountedType class ErrorType(graphene.ObjectType): - field = graphene.String() - messages = graphene.List(graphene.String) + field = graphene.String(required=True) + messages = graphene.List(graphene.NonNull(graphene.String), required=True) class DictType(UnmountedType): diff --git a/graphene_django/settings.py b/graphene_django/settings.py index 46d70ee..7cd750a 100644 --- a/graphene_django/settings.py +++ b/graphene_django/settings.py @@ -26,27 +26,22 @@ except ImportError: # Copied shamelessly from Django REST Framework DEFAULTS = { - 'SCHEMA': None, - 'SCHEMA_OUTPUT': 'schema.json', - 'SCHEMA_INDENT': None, - 'MIDDLEWARE': (), + "SCHEMA": None, + "SCHEMA_OUTPUT": "schema.json", + "SCHEMA_INDENT": None, + "MIDDLEWARE": (), # Set to True if the connection fields must have # either the first or last argument - 'RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST': False, + "RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST": False, # Max items returned in ConnectionFields / FilterConnectionFields - 'RELAY_CONNECTION_MAX_LIMIT': 100, + "RELAY_CONNECTION_MAX_LIMIT": 100, } if settings.DEBUG: - DEFAULTS['MIDDLEWARE'] += ( - 'graphene_django.debug.DjangoDebugMiddleware', - ) + DEFAULTS["MIDDLEWARE"] += ("graphene_django.debug.DjangoDebugMiddleware",) # List of settings that may be in string import notation. -IMPORT_STRINGS = ( - 'MIDDLEWARE', - 'SCHEMA', -) +IMPORT_STRINGS = ("MIDDLEWARE", "SCHEMA") def perform_import(val, setting_name): @@ -69,12 +64,17 @@ def import_from_string(val, setting_name): """ try: # Nod to tastypie's use of importlib. - parts = val.split('.') - module_path, class_name = '.'.join(parts[:-1]), parts[-1] + parts = val.split(".") + module_path, class_name = ".".join(parts[:-1]), parts[-1] module = importlib.import_module(module_path) return getattr(module, class_name) except (ImportError, AttributeError) as e: - msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import '%s' for Graphene setting '%s'. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -96,8 +96,8 @@ class GrapheneSettings(object): @property def user_settings(self): - if not hasattr(self, '_user_settings'): - self._user_settings = getattr(settings, 'GRAPHENE', {}) + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "GRAPHENE", {}) return self._user_settings def __getattr__(self, attr): @@ -125,8 +125,8 @@ graphene_settings = GrapheneSettings(None, DEFAULTS, IMPORT_STRINGS) def reload_graphene_settings(*args, **kwargs): global graphene_settings - setting, value = kwargs['setting'], kwargs['value'] - if setting == 'GRAPHENE': + setting, value = kwargs["setting"], kwargs["value"] + if setting == "GRAPHENE": graphene_settings = GrapheneSettings(value, DEFAULTS, IMPORT_STRINGS) diff --git a/graphene_django/templates/graphene/graphiql.html b/graphene_django/templates/graphene/graphiql.html index 949b850..1ba0613 100644 --- a/graphene_django/templates/graphene/graphiql.html +++ b/graphene_django/templates/graphene/graphiql.html @@ -16,11 +16,11 @@ add "&raw" to the end of the URL within a browser. width: 100%; } - - - - - + + + + +