Compare commits

..

No commits in common. "main" and "v3.1.5" have entirely different histories.
main ... v3.1.5

61 changed files with 344 additions and 1445 deletions

View File

@ -6,18 +6,13 @@ on:
- 'v*' - 'v*'
jobs: jobs:
lint: build:
uses: ./.github/workflows/lint.yml
tests:
uses: ./.github/workflows/tests.yml
release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [lint, tests]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Set up Python 3.11 - name: Set up Python 3.11
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: '3.11' python-version: '3.11'
- name: Build wheel and source tarball - name: Build wheel and source tarball

View File

@ -4,16 +4,15 @@ on:
push: push:
branches: ["main"] branches: ["main"]
pull_request: pull_request:
workflow_call:
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Set up Python 3.11 - name: Set up Python 3.11
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: '3.11' python-version: '3.11'
- name: Install dependencies - name: Install dependencies

View File

@ -4,7 +4,6 @@ on:
push: push:
branches: ["main"] branches: ["main"]
pull_request: pull_request:
workflow_call:
jobs: jobs:
build: build:
@ -12,29 +11,17 @@ jobs:
strategy: strategy:
max-parallel: 4 max-parallel: 4
matrix: matrix:
django: ["3.2", "4.2", "5.0", "5.1", "5.2"] django: ["3.2", "4.1", "4.2"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python-version: ["3.8", "3.9", "3.10"]
exclude: include:
- django: "3.2" - django: "4.1"
python-version: "3.11"
- django: "4.2"
python-version: "3.11" python-version: "3.11"
- django: "3.2"
python-version: "3.12"
- django: "5.0"
python-version: "3.8"
- django: "5.0"
python-version: "3.9"
- django: "5.1"
python-version: "3.8"
- django: "5.1"
python-version: "3.9"
- django: "5.2"
python-version: "3.8"
- django: "5.2"
python-version: "3.9"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies

View File

@ -2,7 +2,7 @@ default_language_version:
python: python3.11 python: python3.11
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 rev: v4.4.0
hooks: hooks:
- id: check-merge-conflict - id: check-merge-conflict
- id: check-json - id: check-json
@ -15,9 +15,12 @@ repos:
- --autofix - --autofix
- id: trailing-whitespace - id: trailing-whitespace
exclude: README.md exclude: README.md
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.2 rev: v0.0.283
hooks: hooks:
- id: ruff - id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes] args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- id: ruff-format

View File

@ -1,18 +0,0 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/conf.py
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt

View File

@ -13,7 +13,6 @@ ignore = [
"B017", # pytest.raises(Exception) should be considered evil "B017", # pytest.raises(Exception) should be considered evil
"B028", # warnings.warn called without an explicit stacklevel keyword argument "B028", # warnings.warn called without an explicit stacklevel keyword argument
"B904", # check for raise statements in exception handlers that lack a from clause "B904", # check for raise statements in exception handlers that lack a from clause
"W191", # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
] ]
exclude = [ exclude = [
@ -25,8 +24,10 @@ target-version = "py38"
[per-file-ignores] [per-file-ignores]
# Ignore unused imports (F401) in these files # Ignore unused imports (F401) in these files
"__init__.py" = ["F401"] "__init__.py" = ["F401"]
"graphene_django/compat.py" = ["F401"]
[isort] [isort]
known-first-party = ["graphene", "graphene-django"] known-first-party = ["graphene", "graphene-django"]
known-local-folder = ["cookbook"] known-local-folder = ["cookbook"]
force-wrap-aliases = true
combine-as-imports = true combine-as-imports = true

View File

@ -33,7 +33,7 @@ make tests
## Opening Pull Requests ## Opening Pull Requests
Please fork the project and open a pull request against the `main` branch. Please fork the project and open a pull request against the master branch.
This will trigger a series of test and lint checks. This will trigger a series of test and lint checks.

View File

@ -14,7 +14,7 @@ tests:
.PHONY: format ## Format code .PHONY: format ## Format code
format: format:
ruff format graphene_django examples setup.py black graphene_django examples setup.py
.PHONY: lint ## Lint code .PHONY: lint ## Lint code
lint: lint:

View File

@ -30,7 +30,7 @@ Graphene-Django is an open-source library that provides seamless integration bet
To install Graphene-Django, run the following command: To install Graphene-Django, run the following command:
```sh ```
pip install graphene-django pip install graphene-django
``` ```
@ -114,11 +114,11 @@ class MyModelAPITestCase(GraphQLTestCase):
## Contributing ## Contributing
Contributions to Graphene-Django are always welcome! To get started, check the repository's [issue tracker](https://github.com/graphql-python/graphene-django/issues) and [contribution guidelines](https://github.com/graphql-python/graphene-django/blob/main/CONTRIBUTING.md). Contributions to Graphene-Django are always welcome! To get started, check the repository's [issue tracker](https://github.com/graphql-python/graphene-django/issues) and [contribution guidelines](https://github.com/graphql-python/graphene-django/blob/master/CONTRIBUTING.md).
## License ## License
Graphene-Django is released under the [MIT License](https://github.com/graphql-python/graphene-django/blob/main/LICENSE). Graphene-Django is released under the [MIT License](https://github.com/graphql-python/graphene-django/blob/master/LICENSE).
## Resources ## Resources

View File

@ -33,6 +33,5 @@ For more advanced use, check out the Relay tutorial.
authorization authorization
debug debug
introspection introspection
validation
testing testing
settings settings

View File

@ -6,7 +6,7 @@ Graphene-Django can be customised using settings. This page explains each settin
Usage Usage
----- -----
Add settings to your Django project by creating a Dictionary with name ``GRAPHENE`` in the project's ``settings.py``: Add settings to your Django project by creating a Dictonary with name ``GRAPHENE`` in the project's ``settings.py``:
.. code:: python .. code:: python
@ -142,15 +142,6 @@ Default: ``False``
# ] # ]
``DJANGO_CHOICE_FIELD_ENUM_CONVERT``
--------------------------------------
When set to ``True`` Django choice fields are automatically converted into Enum types.
Can be disabled globally by setting it to ``False``.
Default: ``True``
``DJANGO_CHOICE_FIELD_ENUM_V2_NAMING`` ``DJANGO_CHOICE_FIELD_ENUM_V2_NAMING``
-------------------------------------- --------------------------------------
@ -206,6 +197,9 @@ Set to ``False`` if you want to disable GraphiQL headers editor tab for some rea
This setting is passed to ``headerEditorEnabled`` GraphiQL options, for details refer to GraphiQLDocs_. This setting is passed to ``headerEditorEnabled`` GraphiQL options, for details refer to GraphiQLDocs_.
.. _GraphiQLDocs: https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
Default: ``True`` Default: ``True``
.. code:: python .. code:: python
@ -236,6 +230,8 @@ Set to ``True`` if you want to persist GraphiQL headers after refreshing the pag
This setting is passed to ``shouldPersistHeaders`` GraphiQL options, for details refer to GraphiQLDocs_. This setting is passed to ``shouldPersistHeaders`` GraphiQL options, for details refer to GraphiQLDocs_.
.. _GraphiQLDocs: https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
Default: ``False`` Default: ``False``
@ -244,48 +240,3 @@ Default: ``False``
GRAPHENE = { GRAPHENE = {
'GRAPHIQL_SHOULD_PERSIST_HEADERS': False, 'GRAPHIQL_SHOULD_PERSIST_HEADERS': False,
} }
``GRAPHIQL_INPUT_VALUE_DEPRECATION``
------------------------------------
Set to ``True`` if you want GraphiQL to show any deprecated fields on input object types' docs.
For example, having this schema:
.. code:: python
class MyMutationInputType(graphene.InputObjectType):
old_field = graphene.String(deprecation_reason="You should now use 'newField' instead.")
new_field = graphene.String()
class MyMutation(graphene.Mutation):
class Arguments:
input = types.MyMutationInputType()
GraphiQL will add a ``Show Deprecated Fields`` button to toggle information display on ``oldField`` and its deprecation
reason. Otherwise, you would get neither a button nor any information at all on ``oldField``.
This setting is passed to ``inputValueDeprecation`` GraphiQL options, for details refer to GraphiQLDocs_.
Default: ``False``
.. code:: python
GRAPHENE = {
'GRAPHIQL_INPUT_VALUE_DEPRECATION': False,
}
.. _GraphiQLDocs: https://graphiql-test.netlify.app/typedoc/modules/graphiql_react#graphiqlprovider-2
``MAX_VALIDATION_ERRORS``
------------------------------------
In case ``validation_rules`` are provided to ``GraphQLView``, if this is set to a non-negative ``int`` value,
``graphql.validation.validate`` will stop validation after this number of errors has been reached.
If not set or set to ``None``, the maximum number of errors will follow ``graphql.validation.validate`` default
*i.e.* 100.
Default: ``None``

View File

@ -104,7 +104,7 @@ Load some test data
Now is a good time to load up some test data. The easiest option will be Now is a good time to load up some test data. The easiest option will be
to `download the to `download the
ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/main/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__ ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/master/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__
fixture and place it in fixture and place it in
``cookbook/ingredients/fixtures/ingredients.json``. You can then run the ``cookbook/ingredients/fixtures/ingredients.json``. You can then run the
following: following:

View File

@ -7,7 +7,7 @@ Graphene has a number of additional features that are designed to make
working with Django *really simple*. working with Django *really simple*.
Note: The code in this quickstart is pulled from the `cookbook example Note: The code in this quickstart is pulled from the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/main/examples/cookbook>`__. app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__.
A good idea is to check the following things first: A good idea is to check the following things first:
@ -87,7 +87,7 @@ Load some test data
Now is a good time to load up some test data. The easiest option will be Now is a good time to load up some test data. The easiest option will be
to `download the to `download the
ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/main/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__ ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/master/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__
fixture and place it in fixture and place it in
``cookbook/ingredients/fixtures/ingredients.json``. You can then run the ``cookbook/ingredients/fixtures/ingredients.json``. You can then run the
following: following:

View File

@ -1,29 +0,0 @@
Query Validation
================
Graphene-Django supports query validation by allowing passing a list of validation rules (subclasses of `ValidationRule <https://github.com/graphql-python/graphql-core/blob/v3.2.3/src/graphql/validation/rules/__init__.py>`_ from graphql-core) to the ``validation_rules`` option in ``GraphQLView``.
.. code:: python
from django.urls import path
from graphene.validation import DisableIntrospection
from graphene_django.views import GraphQLView
urlpatterns = [
path("graphql", GraphQLView.as_view(validation_rules=(DisableIntrospection,))),
]
or
.. code:: python
from django.urls import path
from graphene.validation import DisableIntrospection
from graphene_django.views import GraphQLView
class View(GraphQLView):
validation_rules = (DisableIntrospection,)
urlpatterns = [
path("graphql", View.as_view()),
]

View File

@ -231,7 +231,7 @@
"fields": { "fields": {
"category": 3, "category": 3,
"name": "Newt", "name": "Newt",
"notes": "Braised and Confused" "notes": "Braised and Confuesd"
}, },
"model": "ingredients.ingredient", "model": "ingredients.ingredient",
"pk": 5 "pk": 5

View File

@ -1,5 +1,5 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==4.2.18 django==3.1.14
django-filter>=2 django-filter>=2

View File

@ -28,5 +28,3 @@ TEMPLATES = [
GRAPHENE = {"SCHEMA": "graphene_django.tests.schema_view.schema"} GRAPHENE = {"SCHEMA": "graphene_django.tests.schema_view.schema"}
ROOT_URLCONF = "graphene_django.tests.urls" ROOT_URLCONF = "graphene_django.tests.urls"
USE_TZ = True

View File

@ -28,7 +28,7 @@ def initialize():
# Yeah, technically it's Corellian. But it flew in the service of the rebels, # Yeah, technically it's Corellian. But it flew in the service of the rebels,
# so for the purposes of this demo it's a rebel ship. # so for the purposes of this demo it's a rebel ship.
falcon = Ship(id="4", name="Millennium Falcon", faction=rebels) falcon = Ship(id="4", name="Millenium Falcon", faction=rebels)
falcon.save() falcon.save()
homeOne = Ship(id="5", name="Home One", faction=rebels) homeOne = Ship(id="5", name="Home One", faction=rebels)

View File

@ -1,5 +1,5 @@
import graphene import graphene
from graphene import Schema, relay from graphene import Schema, relay, resolve_only_args
from graphene_django import DjangoConnectionField, DjangoObjectType from graphene_django import DjangoConnectionField, DjangoObjectType
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship, get_ships from .data import create_ship, get_empire, get_faction, get_rebels, get_ship, get_ships
@ -62,13 +62,16 @@ class Query(graphene.ObjectType):
node = relay.Node.Field() node = relay.Node.Field()
ships = DjangoConnectionField(Ship, description="All the ships.") ships = DjangoConnectionField(Ship, description="All the ships.")
def resolve_ships(self, info): @resolve_only_args
def resolve_ships(self):
return get_ships() return get_ships()
def resolve_rebels(self, info): @resolve_only_args
def resolve_rebels(self):
return get_rebels() return get_rebels()
def resolve_empire(self, info): @resolve_only_args
def resolve_empire(self):
return get_empire() return get_empire()

View File

@ -40,7 +40,7 @@ def test_mutations():
{"node": {"id": "U2hpcDox", "name": "X-Wing"}}, {"node": {"id": "U2hpcDox", "name": "X-Wing"}},
{"node": {"id": "U2hpcDoy", "name": "Y-Wing"}}, {"node": {"id": "U2hpcDoy", "name": "Y-Wing"}},
{"node": {"id": "U2hpcDoz", "name": "A-Wing"}}, {"node": {"id": "U2hpcDoz", "name": "A-Wing"}},
{"node": {"id": "U2hpcDo0", "name": "Millennium Falcon"}}, {"node": {"id": "U2hpcDo0", "name": "Millenium Falcon"}},
{"node": {"id": "U2hpcDo1", "name": "Home One"}}, {"node": {"id": "U2hpcDo1", "name": "Home One"}},
{"node": {"id": "U2hpcDo5", "name": "Peter"}}, {"node": {"id": "U2hpcDo5", "name": "Peter"}},
] ]

View File

@ -2,7 +2,7 @@ from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType from .types import DjangoObjectType
from .utils import bypass_get_queryset from .utils import bypass_get_queryset
__version__ = "3.2.3" __version__ = "3.1.5"
__all__ = [ __all__ = [
"__version__", "__version__",

View File

@ -1,11 +1,7 @@
import sys
from collections.abc import Callable
from pathlib import PurePath
# For backwards compatibility, we import JSONField to have it available for import via # For backwards compatibility, we import JSONField to have it available for import via
# this compat module (https://github.com/graphql-python/graphene-django/issues/1428). # this compat module (https://github.com/graphql-python/graphene-django/issues/1428).
# Django's JSONField is available in Django 3.2+ (the minimum version we support) # Django's JSONField is available in Django 3.2+ (the minimum version we support)
from django.db.models import Choices, JSONField from django.db.models import JSONField
class MissingType: class MissingType:
@ -23,43 +19,4 @@ try:
RangeField, RangeField,
) )
except ImportError: except ImportError:
IntegerRangeField, HStoreField, RangeField = (MissingType,) * 3 IntegerRangeField, ArrayField, HStoreField, RangeField = (MissingType,) * 4
# For unit tests we fake ArrayField using JSONFields
if any(
PurePath(sys.argv[0]).match(p)
for p in [
"**/pytest",
"**/py.test",
"**/pytest/__main__.py",
]
):
class ArrayField(JSONField):
def __init__(self, *args, **kwargs):
if len(args) > 0:
self.base_field = args[0]
super().__init__(**kwargs)
else:
ArrayField = MissingType
try:
from django.utils.choices import normalize_choices
except ImportError:
def normalize_choices(choices):
if isinstance(choices, type) and issubclass(choices, Choices):
choices = choices.choices
if isinstance(choices, Callable):
choices = choices()
# In restframework==3.15.0, choices are not passed
# as OrderedDict anymore, so it's safer to check
# for a dict
if isinstance(choices, dict):
choices = choices.items()
return choices

View File

@ -1,4 +1,5 @@
import inspect import inspect
from collections import OrderedDict
from functools import partial, singledispatch, wraps from functools import partial, singledispatch, wraps
from django.db import models from django.db import models
@ -36,7 +37,7 @@ except ImportError:
from graphql import assert_valid_name as assert_name from graphql import assert_valid_name as assert_name
from graphql.pyutils import register_description from graphql.pyutils import register_description
from .compat import ArrayField, HStoreField, RangeField, normalize_choices from .compat import ArrayField, HStoreField, RangeField
from .fields import DjangoConnectionField, DjangoListField from .fields import DjangoConnectionField, DjangoListField
from .settings import graphene_settings from .settings import graphene_settings
from .utils.str_converters import to_const from .utils.str_converters import to_const
@ -60,24 +61,6 @@ class BlankValueField(Field):
return blank_field_wrapper(resolver) return blank_field_wrapper(resolver)
class EnumValueField(BlankValueField):
def wrap_resolve(self, parent_resolver):
resolver = super().wrap_resolve(parent_resolver)
# create custom resolver
def enum_field_wrapper(func):
@wraps(func)
def wrapped_resolver(*args, **kwargs):
return_value = func(*args, **kwargs)
if isinstance(return_value, models.Choices):
return_value = return_value.value
return return_value
return wrapped_resolver
return enum_field_wrapper(resolver)
def convert_choice_name(name): def convert_choice_name(name):
name = to_const(force_str(name)) name = to_const(force_str(name))
try: try:
@ -89,7 +72,8 @@ def convert_choice_name(name):
def get_choices(choices): def get_choices(choices):
converted_names = [] converted_names = []
choices = normalize_choices(choices) if isinstance(choices, OrderedDict):
choices = choices.items()
for value, help_text in choices: for value, help_text in choices:
if isinstance(help_text, (tuple, list)): if isinstance(help_text, (tuple, list)):
yield from get_choices(help_text) yield from get_choices(help_text)
@ -149,24 +133,20 @@ def convert_choice_field_to_enum(field, name=None):
def convert_django_field_with_choices( def convert_django_field_with_choices(
field, registry=None, convert_choices_to_enum=None field, registry=None, convert_choices_to_enum=True
): ):
if registry is not None: if registry is not None:
converted = registry.get_converted_field(field) converted = registry.get_converted_field(field)
if converted: if converted:
return converted return converted
choices = getattr(field, "choices", None) choices = getattr(field, "choices", None)
if convert_choices_to_enum is None:
convert_choices_to_enum = bool(
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT
)
if choices and convert_choices_to_enum: if choices and convert_choices_to_enum:
EnumCls = convert_choice_field_to_enum(field) EnumCls = convert_choice_field_to_enum(field)
required = not (field.blank or field.null) required = not (field.blank or field.null)
converted = EnumCls( converted = EnumCls(
description=get_django_field_description(field), required=required description=get_django_field_description(field), required=required
).mount_as(EnumValueField) ).mount_as(BlankValueField)
else: else:
converted = convert_django_field(field, registry) converted = convert_django_field(field, registry)
if registry is not None: if registry is not None:
@ -199,13 +179,19 @@ def convert_field_to_string(field, registry=None):
) )
@convert_django_field.register(models.AutoField)
@convert_django_field.register(models.BigAutoField) @convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.SmallAutoField) @convert_django_field.register(models.AutoField)
def convert_field_to_id(field, registry=None): def convert_field_to_id(field, registry=None):
return ID(description=get_django_field_description(field), required=not field.null) return ID(description=get_django_field_description(field), required=not field.null)
if hasattr(models, "SmallAutoField"):
@convert_django_field.register(models.SmallAutoField)
def convert_field_small_to_id(field, registry=None):
return convert_field_to_id(field, registry)
@convert_django_field.register(models.UUIDField) @convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None): def convert_field_to_uuid(field, registry=None):
return UUID( return UUID(

View File

@ -20,20 +20,17 @@ from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs): def __init__(self, _type, *args, **kwargs):
from .types import DjangoObjectType
if isinstance(_type, NonNull): if isinstance(_type, NonNull):
_type = _type.of_type _type = _type.of_type
# Django would never return a Set of None vvvvvvv # Django would never return a Set of None vvvvvvv
super().__init__(List(NonNull(_type)), *args, **kwargs) super().__init__(List(NonNull(_type)), *args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
assert issubclass( assert issubclass(
self._underlying_type, DjangoObjectType self._underlying_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types as underlying type" ), "DjangoListField only accepts DjangoObjectType types"
return super().type
@property @property
def _underlying_type(self): def _underlying_type(self):
@ -197,7 +194,7 @@ class DjangoConnectionField(ConnectionField):
enforce_first_or_last, enforce_first_or_last,
root, root,
info, info,
**args, **args
): ):
first = args.get("first") first = args.get("first")
last = args.get("last") last = args.get("last")
@ -247,7 +244,7 @@ class DjangoConnectionField(ConnectionField):
def wrap_resolve(self, parent_resolver): def wrap_resolve(self, parent_resolver):
return partial( return partial(
self.connection_resolver, self.connection_resolver,
self.resolver or parent_resolver, parent_resolver,
self.connection_type, self.connection_type,
self.get_manager(), self.get_manager(),
self.get_queryset_resolver(), self.get_queryset_resolver(),

View File

@ -36,7 +36,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
extra_filter_meta=None, extra_filter_meta=None,
filterset_class=None, filterset_class=None,
*args, *args,
**kwargs, **kwargs
): ):
self._fields = fields self._fields = fields
self._provided_filterset_class = filterset_class self._provided_filterset_class = filterset_class

View File

@ -1,36 +1,13 @@
from django_filters.constants import EMPTY_VALUES from django_filters.constants import EMPTY_VALUES
from django_filters.filters import FilterMethod
from .typed_filter import TypedFilter from .typed_filter import TypedFilter
class ArrayFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)
class ArrayFilter(TypedFilter): class ArrayFilter(TypedFilter):
""" """
Filter made for PostgreSQL ArrayField. Filter made for PostgreSQL ArrayField.
""" """
@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ArrayFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ArrayFilterMethod(self)
def filter(self, qs, value): def filter(self, qs, value):
""" """
Override the default filter class to check first whether the list is Override the default filter class to check first whether the list is

View File

@ -1,36 +1,12 @@
from django_filters.filters import FilterMethod
from .typed_filter import TypedFilter from .typed_filter import TypedFilter
class ListFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)
class ListFilter(TypedFilter): class ListFilter(TypedFilter):
""" """
Filter that takes a list of value as input. Filter that takes a list of value as input.
It is for example used for `__in` filters. It is for example used for `__in` filters.
""" """
@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ListFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ListFilterMethod(self)
def filter(self, qs, value): def filter(self, qs, value):
""" """
Override the default filter class to check first whether the list is Override the default filter class to check first whether the list is

View File

@ -1,4 +1,4 @@
from functools import reduce from unittest.mock import MagicMock
import pytest import pytest
from django.db import models from django.db import models
@ -25,15 +25,15 @@ else:
) )
STORE = {"events": []}
class Event(models.Model): class Event(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50)) tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField()) tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField()) random_field = ArrayField(models.BooleanField())
def __repr__(self):
return f"Event [{self.name}]"
@pytest.fixture @pytest.fixture
def EventFilterSet(): def EventFilterSet():
@ -44,18 +44,10 @@ def EventFilterSet():
"name": ["exact", "contains"], "name": ["exact", "contains"],
} }
# Those are actually usable with our Query fixture below # Those are actually usable with our Query fixture bellow
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact") tags = ArrayFilter(field_name="tags", lookup_expr="exact")
tags__len = ArrayFilter(
field_name="tags", lookup_expr="len", input_type=graphene.Int
)
tags__len__in = ArrayFilter(
field_name="tags",
method="tags__len__in_filter",
input_type=graphene.List(graphene.Int),
)
# Those are actually not usable and only to check type declarations # Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains") tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
@ -69,14 +61,6 @@ def EventFilterSet():
) )
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact") random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
def tags__len__in_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return reduce(
lambda q1, q2: q1.union(q2),
[queryset.filter(tags__len=v) for v in value],
).distinct()
return EventFilterSet return EventFilterSet
@ -99,94 +83,68 @@ def Query(EventType):
we are running unit tests in sqlite which does not have ArrayFields. we are running unit tests in sqlite which does not have ArrayFields.
""" """
events = [
Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]),
Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType) events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs): def resolve_events(self, info, **kwargs):
class FakeQuerySet(QuerySet): events = [
def __init__(self, model=None): Event(name="Live Show", tags=["concert", "music", "rock"]),
self.model = Event Event(name="Musical", tags=["movie", "music"]),
self.__store = list(events) Event(name="Ballet", tags=["concert", "dance"]),
Event(name="Speech", tags=[]),
]
def all(self): STORE["events"] = events
return self
def filter(self, **kwargs): m_queryset = MagicMock(spec=QuerySet)
queryset = FakeQuerySet() m_queryset.model = Event
queryset.__store = list(self.__store)
if "tags__contains" in kwargs: def filter_events(**kwargs):
queryset.__store = list( if "tags__contains" in kwargs:
filter( STORE["events"] = list(
lambda e: set(kwargs["tags__contains"]).issubset( filter(
set(e.tags) lambda e: set(kwargs["tags__contains"]).issubset(
), set(e.tags)
queryset.__store, ),
) STORE["events"],
) )
if "tags__overlap" in kwargs: )
queryset.__store = list( if "tags__overlap" in kwargs:
filter( STORE["events"] = list(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint( filter(
set(e.tags) lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
), set(e.tags)
queryset.__store, ),
) STORE["events"],
) )
if "tags__exact" in kwargs: )
queryset.__store = list( if "tags__exact" in kwargs:
filter( STORE["events"] = list(
lambda e: set(kwargs["tags__exact"]) == set(e.tags), filter(
queryset.__store, lambda e: set(kwargs["tags__exact"]) == set(e.tags),
) STORE["events"],
) )
if "tags__len" in kwargs: )
queryset.__store = list(
filter(
lambda e: len(e.tags) == kwargs["tags__len"],
queryset.__store,
)
)
return queryset
def union(self, *args): def mock_queryset_filter(*args, **kwargs):
queryset = FakeQuerySet() filter_events(**kwargs)
queryset.__store = self.__store return m_queryset
for arg in args:
queryset.__store += arg.__store
return queryset
def none(self): def mock_queryset_none(*args, **kwargs):
queryset = FakeQuerySet() STORE["events"] = []
queryset.__store = [] return m_queryset
return queryset
def count(self): def mock_queryset_count(*args, **kwargs):
return len(self.__store) return len(STORE["events"])
def distinct(self): m_queryset.all.return_value = m_queryset
queryset = FakeQuerySet() m_queryset.filter.side_effect = mock_queryset_filter
queryset.__store = [] m_queryset.none.side_effect = mock_queryset_none
for event in self.__store: m_queryset.count.side_effect = mock_queryset_count
if event not in queryset.__store: m_queryset.__getitem__.side_effect = lambda index: STORE[
queryset.__store.append(event) "events"
queryset.__store = sorted(queryset.__store, key=lambda e: e.name) ].__getitem__(index)
return queryset
def __getitem__(self, index): return m_queryset
return self.__store[index]
return FakeQuerySet()
return Query return Query
@pytest.fixture
def schema(Query):
return graphene.Schema(query=Query)

View File

@ -1,14 +1,18 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_multiple(schema): def test_array_field_contains_multiple(Query):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: ["concert", "music"]) { events (tags_Contains: ["concert", "music"]) {
@ -28,11 +32,13 @@ def test_array_field_contains_multiple(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_one(schema): def test_array_field_contains_one(Query):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: ["music"]) { events (tags_Contains: ["music"]) {
@ -53,11 +59,13 @@ def test_array_field_contains_one(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_empty_list(schema): def test_array_field_contains_empty_list(Query):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: []) { events (tags_Contains: []) {

View File

@ -1,186 +0,0 @@
import pytest
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_len_filter(schema):
query = """
query {
events (tags_Len: 2) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
]
query = """
query {
events (tags_Len: 0) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
query = """
query {
events (tags_Len: 10) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len: "2") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == 'Int cannot represent non-integer value: "2"'
query = """
query {
events (tags_Len: True) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == "Int cannot represent non-integer value: True"
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_custom_filter(schema):
query = """
query {
events (tags_Len_In: 2) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Ballet"}},
{"node": {"name": "Musical"}},
]
query = """
query {
events (tags_Len_In: [0, 2]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Ballet"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Speech"}},
]
query = """
query {
events (tags_Len_In: [10]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len_In: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len_In: "12") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == 'Int cannot represent non-integer value: "12"'
query = """
query {
events (tags_Len_In: True) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == "Int cannot represent non-integer value: True"

View File

@ -1,14 +1,18 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(schema): def test_array_field_exact_no_match(Query):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: ["concert", "music"]) { events (tags: ["concert", "music"]) {
@ -26,11 +30,13 @@ def test_array_field_exact_no_match(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(schema): def test_array_field_exact_match(Query):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: ["movie", "music"]) { events (tags: ["movie", "music"]) {
@ -50,11 +56,13 @@ def test_array_field_exact_match(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(schema): def test_array_field_exact_empty_list(Query):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: []) { events (tags: []) {
@ -74,10 +82,11 @@ def test_array_field_exact_empty_list(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_filter_schema_type(schema): def test_array_field_filter_schema_type(Query):
""" """
Check that the type in the filter is an array field like on the object type. Check that the type in the filter is an array field like on the object type.
""" """
schema = Schema(query=Query)
schema_str = str(schema) schema_str = str(schema)
assert ( assert (
@ -103,8 +112,6 @@ def test_array_field_filter_schema_type(schema):
"tags_Contains": "[String!]", "tags_Contains": "[String!]",
"tags_Overlap": "[String!]", "tags_Overlap": "[String!]",
"tags": "[String!]", "tags": "[String!]",
"tags_Len": "Int",
"tags_Len_In": "[Int]",
"tagsIds_Contains": "[Int!]", "tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]", "tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]", "tagsIds": "[Int!]",

View File

@ -1,14 +1,18 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_multiple(schema): def test_array_field_overlap_multiple(Query):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: ["concert", "music"]) { events (tags_Overlap: ["concert", "music"]) {
@ -30,11 +34,13 @@ def test_array_field_overlap_multiple(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_one(schema): def test_array_field_overlap_one(Query):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: ["music"]) { events (tags_Overlap: ["music"]) {
@ -55,11 +61,13 @@ def test_array_field_overlap_one(schema):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_empty_list(schema): def test_array_field_overlap_empty_list(Query):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: []) { events (tags_Overlap: []) {

View File

@ -789,7 +789,7 @@ def test_order_by():
query = """ query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(orderBy: "-firstname") { allReporters(orderBy: "-firtsnaMe") {
edges { edges {
node { node {
firstName firstName
@ -802,7 +802,7 @@ def test_order_by():
assert result.errors assert result.errors
def test_order_by_is_preserved(): def test_order_by_is_perserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter

View File

@ -1,8 +1,4 @@
import operator
from functools import reduce
import pytest import pytest
from django.db.models import Q
from django_filters import FilterSet from django_filters import FilterSet
import graphene import graphene
@ -48,10 +44,6 @@ def schema():
only_first = TypedFilter( only_first = TypedFilter(
input_type=graphene.Boolean, method="only_first_filter" input_type=graphene.Boolean, method="only_first_filter"
) )
headline_search = ListFilter(
method="headline_search_filter",
input_type=graphene.List(graphene.String),
)
def first_n_filter(self, queryset, _name, value): def first_n_filter(self, queryset, _name, value):
return queryset[:value] return queryset[:value]
@ -62,13 +54,6 @@ def schema():
else: else:
return queryset return queryset
def headline_search_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return queryset.filter(
reduce(operator.or_, [Q(headline__icontains=v) for v in value])
)
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
@ -102,7 +87,6 @@ def test_typed_filter_schema(schema):
"lang_InStr": "[String]", "lang_InStr": "[String]",
"firstN": "Int", "firstN": "Int",
"onlyFirst": "Boolean", "onlyFirst": "Boolean",
"headlineSearch": "[String]",
} }
all_articles_filters = ( all_articles_filters = (
@ -120,7 +104,24 @@ def test_typed_filters_work(schema):
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en") Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }' query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
@ -136,7 +137,7 @@ def test_typed_filters_work(schema):
assert not result.errors assert not result.errors
assert result.data["articles"]["edges"] == [ assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}}, {"node": {"headline": "A"}},
{"node": {"headline": "AB"}}, {"node": {"headline": "B"}},
] ]
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }" query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
@ -146,86 +147,3 @@ def test_typed_filters_work(schema):
assert result.data["articles"]["edges"] == [ assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}}, {"node": {"headline": "A"}},
] ]
def test_list_filters_work(schema):
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = "query { articles (lang_InStr: []) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == []
query = "query { articles (lang_InStr: null) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]
query = 'query { articles (headlineSearch: ["a", "B"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = "query { articles (headlineSearch: []) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == []
query = "query { articles (headlineSearch: null) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]
query = 'query { articles (headlineSearch: [""]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]

View File

@ -43,7 +43,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
isinstance(filter_field, TypedFilter) isinstance(filter_field, TypedFilter)
and filter_field.input_type is not None and filter_field.input_type is not None
): ):
# First check if the filter input type has been explicitly given # First check if the filter input type has been explicitely given
field_type = filter_field.input_type field_type = filter_field.input_type
else: else:
if name not in filterset_class.declared_filters or isinstance( if name not in filterset_class.declared_filters or isinstance(
@ -145,7 +145,7 @@ def replace_csv_filters(filterset_class):
label=filter_field.label, label=filter_field.label,
method=filter_field.method, method=filter_field.method,
exclude=filter_field.exclude, exclude=filter_field.exclude,
**filter_field.extra, **filter_field.extra
) )
elif filter_type == "range": elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter( filterset_class.base_filters[name] = RangeFilter(
@ -154,5 +154,5 @@ def replace_csv_filters(filterset_class):
label=filter_field.label, label=filter_field.label,
method=filter_field.method, method=filter_field.method,
exclude=filter_field.exclude, exclude=filter_field.exclude,
**filter_field.extra, **filter_field.extra
) )

View File

@ -23,7 +23,8 @@ def fields_for_form(form, only_fields, exclude_fields):
for name, field in form.fields.items(): for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_excluded = ( is_excluded = (
name in exclude_fields # or name
in exclude_fields # or
# name in already_created_fields # name in already_created_fields
) )

View File

@ -1,4 +1,4 @@
from django import VERSION as DJANGO_VERSION, forms from django import forms
from pytest import raises from pytest import raises
from graphene import ( from graphene import (
@ -19,16 +19,12 @@ from graphene import (
from ..converter import convert_form_field from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args):
# Arrange field = django_field(*args, help_text="Custom Help Text")
help_text = kwargs.setdefault("help_text", "Custom Help Text")
field = django_field(*args, **kwargs)
# Act
graphene_type = convert_form_field(field) graphene_type = convert_form_field(field)
# Assert
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
assert field.description == help_text assert field.description == "Custom Help Text"
return field return field
@ -63,12 +59,7 @@ def test_should_slug_convert_string():
def test_should_url_convert_string(): def test_should_url_convert_string():
kwargs = {} assert_conversion(forms.URLField, String)
if DJANGO_VERSION >= (5, 0):
# silence RemovedInDjango60Warning
kwargs["assume_scheme"] = "https"
assert_conversion(forms.URLField, String, **kwargs)
def test_should_choice_convert_string(): def test_should_choice_convert_string():
@ -84,7 +75,8 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
assert_conversion(forms.UUIDField, UUID) if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID)
def test_should_integer_convert_int(): def test_should_integer_convert_int():

View File

@ -3,8 +3,8 @@ from graphene import ID
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from graphene.utils.str_converters import to_camel_case from graphene.utils.str_converters import to_camel_case
from ..converter import EnumValueField from ..converter import BlankValueField
from ..types import ErrorType # noqa Import ErrorType for backwards compatibility from ..types import ErrorType # noqa Import ErrorType for backwards compatability
from .mutation import fields_for_form from .mutation import fields_for_form
@ -57,10 +57,11 @@ class DjangoFormInputObjectType(InputObjectType):
if ( if (
object_type object_type
and name in object_type._meta.fields and name in object_type._meta.fields
and isinstance(object_type._meta.fields[name], EnumValueField) and isinstance(object_type._meta.fields[name], BlankValueField)
): ):
# Field type EnumValueField here means that field # Field type BlankValueField here means that field
# with choices have been converted to enum # with choises have been converted to enum
# (BlankValueField is using only for that task ?)
setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type)) setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type))
elif ( elif (
object_type object_type

View File

@ -19,7 +19,6 @@ class SerializerMutationOptions(MutationOptions):
model_class = None model_class = None
model_operations = ["create", "update"] model_operations = ["create", "update"]
serializer_class = None serializer_class = None
optional_fields = ()
def fields_for_serializer( def fields_for_serializer(
@ -29,7 +28,6 @@ def fields_for_serializer(
is_input=False, is_input=False,
convert_choices_to_enum=True, convert_choices_to_enum=True,
lookup_field=None, lookup_field=None,
optional_fields=(),
): ):
fields = OrderedDict() fields = OrderedDict()
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
@ -50,13 +48,9 @@ def fields_for_serializer(
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
continue continue
is_optional = name in optional_fields or "__all__" in optional_fields
fields[name] = convert_serializer_field( fields[name] = convert_serializer_field(
field, field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum
is_input=is_input,
convert_choices_to_enum=convert_choices_to_enum,
force_optional=is_optional,
) )
return fields return fields
@ -80,8 +74,7 @@ class SerializerMutation(ClientIDMutation):
exclude_fields=(), exclude_fields=(),
convert_choices_to_enum=True, convert_choices_to_enum=True,
_meta=None, _meta=None,
optional_fields=(), **options
**options,
): ):
if not serializer_class: if not serializer_class:
raise Exception("serializer_class is required for the SerializerMutation") raise Exception("serializer_class is required for the SerializerMutation")
@ -105,7 +98,6 @@ class SerializerMutation(ClientIDMutation):
is_input=True, is_input=True,
convert_choices_to_enum=convert_choices_to_enum, convert_choices_to_enum=convert_choices_to_enum,
lookup_field=lookup_field, lookup_field=lookup_field,
optional_fields=optional_fields,
) )
output_fields = fields_for_serializer( output_fields = fields_for_serializer(
serializer, serializer,

View File

@ -18,9 +18,7 @@ def get_graphene_type_from_serializer_field(field):
) )
def convert_serializer_field( def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True):
field, is_input=True, convert_choices_to_enum=True, force_optional=False
):
""" """
Converts a django rest frameworks field to a graphql field Converts a django rest frameworks field to a graphql field
and marks the field as required if we are creating an input type and marks the field as required if we are creating an input type
@ -33,10 +31,7 @@ def convert_serializer_field(
graphql_type = get_graphene_type_from_serializer_field(field) graphql_type = get_graphene_type_from_serializer_field(field)
args = [] args = []
kwargs = { kwargs = {"description": field.help_text, "required": is_input and field.required}
"description": field.help_text,
"required": is_input and field.required and not force_optional,
}
# if it is a tuple or a list it means that we are returning # if it is a tuple or a list it means that we are returning
# the graphql type and the child type # the graphql type and the child type

View File

@ -96,7 +96,8 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
assert_conversion(serializers.UUIDField, graphene.String) if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String)
def test_should_model_convert_field(): def test_should_model_convert_field():

View File

@ -3,7 +3,7 @@ import datetime
from pytest import raises from pytest import raises
from rest_framework import serializers from rest_framework import serializers
from graphene import Field, ResolveInfo, String from graphene import Field, ResolveInfo
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from ...types import DjangoObjectType from ...types import DjangoObjectType
@ -105,16 +105,6 @@ def test_exclude_fields():
assert "created" not in MyMutation.Input._meta.fields assert "created" not in MyMutation.Input._meta.fields
def test_model_serializer_optional_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
optional_fields = ("cool_name",)
assert "cool_name" in MyMutation.Input._meta.fields
assert MyMutation.Input._meta.fields["cool_name"].type == String
def test_write_only_field(): def test_write_only_field():
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer): class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
password = serializers.CharField(write_only=True) password = serializers.CharField(write_only=True)
@ -275,7 +265,7 @@ def test_perform_mutate_success():
result = MyMethodMutation.mutate_and_get_payload( result = MyMethodMutation.mutate_and_get_payload(
None, None,
mock_info(), mock_info(),
**{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)}, **{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)}
) )
assert result.errors is None assert result.errors is None

View File

@ -30,8 +30,6 @@ DEFAULTS = {
# Max items returned in ConnectionFields / FilterConnectionFields # Max items returned in ConnectionFields / FilterConnectionFields
"RELAY_CONNECTION_MAX_LIMIT": 100, "RELAY_CONNECTION_MAX_LIMIT": 100,
"CAMELCASE_ERRORS": True, "CAMELCASE_ERRORS": True,
# Automatically convert Choice fields of Django into Enum fields
"DJANGO_CHOICE_FIELD_ENUM_CONVERT": True,
# Set to True to enable v2 naming convention for choice field Enum's # Set to True to enable v2 naming convention for choice field Enum's
"DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False, "DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False,
"DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None, "DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None,
@ -42,10 +40,8 @@ DEFAULTS = {
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options # https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
"GRAPHIQL_HEADER_EDITOR_ENABLED": True, "GRAPHIQL_HEADER_EDITOR_ENABLED": True,
"GRAPHIQL_SHOULD_PERSIST_HEADERS": False, "GRAPHIQL_SHOULD_PERSIST_HEADERS": False,
"GRAPHIQL_INPUT_VALUE_DEPRECATION": False,
"ATOMIC_MUTATIONS": False, "ATOMIC_MUTATIONS": False,
"TESTING_ENDPOINT": "/graphql", "TESTING_ENDPOINT": "/graphql",
"MAX_VALIDATION_ERRORS": None,
} }
if settings.DEBUG: if settings.DEBUG:

View File

@ -122,7 +122,6 @@
onEditOperationName: onEditOperationName, onEditOperationName: onEditOperationName,
isHeadersEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled, isHeadersEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled,
shouldPersistHeaders: GRAPHENE_SETTINGS.graphiqlShouldPersistHeaders, shouldPersistHeaders: GRAPHENE_SETTINGS.graphiqlShouldPersistHeaders,
inputValueDeprecation: GRAPHENE_SETTINGS.graphiqlInputValueDeprecation,
query: query, query: query,
}; };
if (parameters.variables) { if (parameters.variables) {

View File

@ -54,7 +54,6 @@ add "&raw" to the end of the URL within a browser.
{% endif %} {% endif %}
graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }}, graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }},
graphiqlShouldPersistHeaders: {{ graphiql_should_persist_headers|yesno:"true,false" }}, graphiqlShouldPersistHeaders: {{ graphiql_should_persist_headers|yesno:"true,false" }},
graphiqlInputValueDeprecation: {{ graphiql_input_value_deprecation|yesno:"true,false" }},
}; };
</script> </script>
<script src="{% static 'graphene_django/graphiql.js' %}"></script> <script src="{% static 'graphene_django/graphiql.js' %}"></script>

View File

@ -1,43 +1,11 @@
import django
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
CHOICES = ((1, "this"), (2, _("that"))) CHOICES = ((1, "this"), (2, _("that")))
def get_choices_as_class(choices_class):
if django.VERSION >= (5, 0):
return choices_class
else:
return choices_class.choices
def get_choices_as_callable(choices_class):
if django.VERSION >= (5, 0):
def choices():
return choices_class.choices
return choices
else:
return choices_class.choices
class TypedIntChoice(models.IntegerChoices):
CHOICE_THIS = 1
CHOICE_THAT = 2
class TypedStrChoice(models.TextChoices):
CHOICE_THIS = "this"
CHOICE_THAT = "that"
class Person(models.Model): class Person(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
parent = models.ForeignKey(
"self", on_delete=models.CASCADE, null=True, blank=True, related_name="children"
)
class Pet(models.Model): class Pet(models.Model):
@ -80,21 +48,6 @@ class Reporter(models.Model):
email = models.EmailField() email = models.EmailField()
pets = models.ManyToManyField("self") pets = models.ManyToManyField("self")
a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True)
typed_choice = models.IntegerField(
choices=TypedIntChoice.choices,
null=True,
blank=True,
)
class_choice = models.IntegerField(
choices=get_choices_as_class(TypedIntChoice),
null=True,
blank=True,
)
callable_choice = models.IntegerField(
choices=get_choices_as_callable(TypedStrChoice),
null=True,
blank=True,
)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
fans = models.ManyToManyField(Person) fans = models.ManyToManyField(Person)
@ -144,7 +97,7 @@ class CNNReporter(Reporter):
class APNewsReporter(Reporter): class APNewsReporter(Reporter):
""" """
This class only inherits from Reporter for testing multi table inheritance This class only inherits from Reporter for testing multi table inheritence
similar to what you'd see in django-polymorphic similar to what you'd see in django-polymorphic
""" """

View File

@ -25,7 +25,7 @@ from ..converter import (
) )
from ..registry import Registry from ..registry import Registry
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Article, Film, FilmDetails, Reporter, TypedIntChoice, TypedStrChoice from .models import Article, Film, FilmDetails, Reporter
# from graphene.core.types.custom_scalars import DateTime, Time, JSONString # from graphene.core.types.custom_scalars import DateTime, Time, JSONString
@ -53,8 +53,9 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception(): def test_should_unknown_django_field_raise_exception():
with raises(Exception, match="Don't know how to convert the Django field"): with raises(Exception) as excinfo:
convert_django_field(None) convert_django_field(None)
assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string(): def test_should_date_time_convert_string():
@ -114,7 +115,8 @@ def test_should_big_auto_convert_id():
def test_should_small_auto_convert_id(): def test_should_small_auto_convert_id():
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True) if hasattr(models, "SmallAutoField"):
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
def test_should_uuid_convert_id(): def test_should_uuid_convert_id():
@ -164,34 +166,14 @@ def test_field_with_choices_convert_enum():
help_text="Language", choices=(("es", "Spanish"), ("en", "English")) help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
) )
class ChoicesModel(models.Model): class TranslatedModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = "test" app_label = "test"
graphene_type = convert_django_field_with_choices(field).type.of_type graphene_type = convert_django_field_with_choices(field).type.of_type
assert graphene_type._meta.name == "TestChoicesModelLanguageChoices" assert graphene_type._meta.name == "TestTranslatedModelLanguageChoices"
assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_callable_choices_convert_enum():
def get_choices():
return ("es", "Spanish"), ("en", "English")
field = models.CharField(help_text="Language", choices=get_choices)
class CallableChoicesModel(models.Model):
language = field
class Meta:
app_label = "test"
graphene_type = convert_django_field_with_choices(field).type.of_type
assert graphene_type._meta.name == "TestCallableChoicesModelLanguageChoices"
assert graphene_type._meta.enum.__members__["ES"].value == "es" assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish" assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en" assert graphene_type._meta.enum.__members__["EN"].value == "en"
@ -441,102 +423,35 @@ def test_choice_enum_blank_value():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ("callable_choice",) fields = (
"first_name",
"a_choice",
)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(root, info): def resolve_reporter(root, info):
# return a model instance with blank choice field value return Reporter.objects.first()
return Reporter(callable_choice="")
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
# Create model with empty choice option
Reporter.objects.create(
first_name="Bridget", last_name="Jones", email="bridget@example.com"
)
result = schema.execute( result = schema.execute(
""" """
query { query {
reporter { reporter {
callableChoice firstName
aChoice
} }
} }
""" """
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"reporter": {"callableChoice": None}, "reporter": {"firstName": "Bridget", "aChoice": None},
} }
def test_typed_choice_value():
"""Test that typed choices fields are resolved correctly to the enum values"""
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
fields = ("typed_choice", "class_choice", "callable_choice")
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(root, info):
# assign choice values to the fields instead of their str or int values
return Reporter(
typed_choice=TypedIntChoice.CHOICE_THIS,
class_choice=TypedIntChoice.CHOICE_THAT,
callable_choice=TypedStrChoice.CHOICE_THIS,
)
class CreateReporter(graphene.Mutation):
reporter = graphene.Field(ReporterType)
def mutate(root, info, **kwargs):
return CreateReporter(
reporter=Reporter(
typed_choice=TypedIntChoice.CHOICE_THIS,
class_choice=TypedIntChoice.CHOICE_THAT,
callable_choice=TypedStrChoice.CHOICE_THIS,
),
)
class Mutation(graphene.ObjectType):
create_reporter = CreateReporter.Field()
schema = graphene.Schema(query=Query, mutation=Mutation)
reporter_fragment = """
fragment reporter on ReporterType {
typedChoice
classChoice
callableChoice
}
"""
expected_reporter = {
"typedChoice": "A_1",
"classChoice": "A_2",
"callableChoice": "THIS",
}
result = schema.execute(
reporter_fragment
+ """
query {
reporter { ...reporter }
}
"""
)
assert not result.errors
assert result.data["reporter"] == expected_reporter
result = schema.execute(
reporter_fragment
+ """
mutation {
createReporter {
reporter { ...reporter }
}
}
"""
)
assert not result.errors
assert result.data["createReporter"]["reporter"] == expected_reporter

View File

@ -12,23 +12,17 @@ from .models import (
Article as ArticleModel, Article as ArticleModel,
Film as FilmModel, Film as FilmModel,
FilmDetails as FilmDetailsModel, FilmDetails as FilmDetailsModel,
Person as PersonModel,
Reporter as ReporterModel, Reporter as ReporterModel,
) )
class TestDjangoListField: class TestDjangoListField:
def test_only_django_object_types(self): def test_only_django_object_types(self):
class Query(ObjectType): class TestType(ObjectType):
something = DjangoListField(String) foo = String()
with pytest.raises(TypeError) as excinfo: with pytest.raises(AssertionError):
Schema(query=Query) DjangoListField(TestType)
assert (
"Query fields cannot be resolved. DjangoListField only accepts DjangoObjectType types as underlying type"
in str(excinfo.value)
)
def test_only_import_paths(self): def test_only_import_paths(self):
list_field = DjangoListField("graphene_django.tests.schema.Human") list_field = DjangoListField("graphene_django.tests.schema.Human")
@ -268,69 +262,6 @@ class TestDjangoListField:
] ]
} }
def test_same_type_nested_list_field(self):
class Person(DjangoObjectType):
class Meta:
model = PersonModel
fields = ("name", "parent")
children = DjangoListField(lambda: Person)
class Query(ObjectType):
persons = DjangoListField(Person)
schema = Schema(query=Query)
query = """
query {
persons {
name
children {
name
}
}
}
"""
p1 = PersonModel.objects.create(name="Tara")
PersonModel.objects.create(name="Debra")
PersonModel.objects.create(
name="Toto",
parent=p1,
)
PersonModel.objects.create(
name="Tata",
parent=p1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {
"persons": [
{
"name": "Tara",
"children": [
{"name": "Toto"},
{"name": "Tata"},
],
},
{
"name": "Debra",
"children": [],
},
{
"name": "Toto",
"children": [],
},
{
"name": "Tata",
"children": [],
},
]
}
def test_get_queryset_filter(self): def test_get_queryset_filter(self):
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:

View File

@ -26,7 +26,6 @@ class TestShouldCallGetQuerySetOnForeignKey:
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -37,7 +36,6 @@ class TestShouldCallGetQuerySetOnForeignKey:
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -202,7 +200,6 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = "__all__"
interfaces = (Node,) interfaces = (Node,)
@classmethod @classmethod
@ -214,7 +211,6 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
fields = "__all__"
interfaces = (Node,) interfaces = (Node,)
@classmethod @classmethod
@ -374,7 +370,6 @@ class TestShouldCallGetQuerySetOnOneToOne:
class FilmDetailsType(DjangoObjectType): class FilmDetailsType(DjangoObjectType):
class Meta: class Meta:
model = FilmDetails model = FilmDetails
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -385,7 +380,6 @@ class TestShouldCallGetQuerySetOnOneToOne:
class FilmType(DjangoObjectType): class FilmType(DjangoObjectType):
class Meta: class Meta:
model = Film model = Film
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):

View File

@ -1,6 +1,5 @@
import base64 import base64
import datetime import datetime
from unittest.mock import ANY, Mock
import pytest import pytest
from django.db import models from django.db import models
@ -2001,62 +2000,14 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
assert result.data == expected assert result.data == expected
def test_connection_should_call_resolver_function():
resolver_mock = Mock(
name="resolver",
return_value=[
Reporter(first_name="Some", last_name="One"),
Reporter(first_name="John", last_name="Doe"),
],
)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
fields = "__all__"
interfaces = [Node]
class Query(graphene.ObjectType):
reporters = DjangoConnectionField(ReporterType, resolver=resolver_mock)
schema = graphene.Schema(query=Query)
result = schema.execute(
"""
query {
reporters {
edges {
node {
firstName
lastName
}
}
}
}
"""
)
resolver_mock.assert_called_once_with(None, ANY)
assert not result.errors
assert result.data == {
"reporters": {
"edges": [
{"node": {"firstName": "Some", "lastName": "One"}},
{"node": {"firstName": "John", "lastName": "Doe"}},
],
},
}
def test_should_query_nullable_foreign_key(): def test_should_query_nullable_foreign_key():
class PetType(DjangoObjectType): class PetType(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
fields = "__all__"
class PersonType(DjangoObjectType): class PersonType(DjangoObjectType):
class Meta: class Meta:
model = Person model = Person
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
pet = graphene.Field(PetType, name=graphene.String(required=True)) pet = graphene.Field(PetType, name=graphene.String(required=True))
@ -2071,8 +2022,10 @@ def test_should_query_nullable_foreign_key():
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
person = Person.objects.create(name="Jane") person = Person.objects.create(name="Jane")
Pet.objects.create(name="Stray dog", age=1) [
Pet.objects.create(name="Jane's dog", owner=person, age=1) Pet.objects.create(name="Stray dog", age=1),
Pet.objects.create(name="Jane's dog", owner=person, age=1),
]
query_pet = """ query_pet = """
query getPet($name: String!) { query getPet($name: String!) {
@ -2115,7 +2068,6 @@ def test_should_query_nullable_one_to_one_relation_with_custom_resolver():
class FilmType(DjangoObjectType): class FilmType(DjangoObjectType):
class Meta: class Meta:
model = Film model = Film
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -2124,7 +2076,6 @@ def test_should_query_nullable_one_to_one_relation_with_custom_resolver():
class FilmDetailsType(DjangoObjectType): class FilmDetailsType(DjangoObjectType):
class Meta: class Meta:
model = FilmDetails model = FilmDetails
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):

View File

@ -40,9 +40,6 @@ def test_should_map_fields_correctly():
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"typed_choice",
"class_choice",
"callable_choice",
"fans", "fans",
"reporter_type", "reporter_type",
] ]

View File

@ -1,4 +1,3 @@
import warnings
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from textwrap import dedent from textwrap import dedent
from unittest.mock import patch from unittest.mock import patch
@ -77,9 +76,6 @@ def test_django_objecttype_map_correct_fields():
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"typed_choice",
"class_choice",
"callable_choice",
"fans", "fans",
"reporter_type", "reporter_type",
] ]
@ -189,9 +185,6 @@ def test_schema_representation():
email: String! email: String!
pets: [Reporter!]! pets: [Reporter!]!
aChoice: TestsReporterAChoiceChoices aChoice: TestsReporterAChoiceChoices
typedChoice: TestsReporterTypedChoiceChoices
classChoice: TestsReporterClassChoiceChoices
callableChoice: TestsReporterCallableChoiceChoices
reporterType: TestsReporterReporterTypeChoices reporterType: TestsReporterReporterTypeChoices
articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection! articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection!
} }
@ -205,33 +198,6 @@ def test_schema_representation():
A_2 A_2
} }
\"""An enumeration.\"""
enum TestsReporterTypedChoiceChoices {
\"""Choice This\"""
A_1
\"""Choice That\"""
A_2
}
\"""An enumeration.\"""
enum TestsReporterClassChoiceChoices {
\"""Choice This\"""
A_1
\"""Choice That\"""
A_2
}
\"""An enumeration.\"""
enum TestsReporterCallableChoiceChoices {
\"""Choice This\"""
THIS
\"""Choice That\"""
THAT
}
\"""An enumeration.\""" \"""An enumeration.\"""
enum TestsReporterReporterTypeChoices { enum TestsReporterReporterTypeChoices {
\"""Regular\""" \"""Regular\"""
@ -433,7 +399,7 @@ def test_django_objecttype_fields_exist_on_model():
with pytest.warns( with pytest.warns(
UserWarning, UserWarning,
match=r"Field name .* matches an attribute on Django model .* but it's not a model field", match=r"Field name .* matches an attribute on Django model .* but it's not a model field",
): ) as record:
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
@ -441,8 +407,7 @@ def test_django_objecttype_fields_exist_on_model():
fields = ["first_name", "some_method", "email"] fields = ["first_name", "some_method", "email"]
# Don't warn if selecting a custom field # Don't warn if selecting a custom field
with warnings.catch_warnings(): with pytest.warns(None) as record:
warnings.simplefilter("error")
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
custom_field = String() custom_field = String()
@ -451,6 +416,8 @@ def test_django_objecttype_fields_exist_on_model():
model = ReporterModel model = ReporterModel
fields = ["first_name", "custom_field", "email"] fields = ["first_name", "custom_field", "email"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_exclude_fields_exist_on_model(): def test_django_objecttype_exclude_fields_exist_on_model():
@ -478,14 +445,15 @@ def test_django_objecttype_exclude_fields_exist_on_model():
exclude = ["custom_field"] exclude = ["custom_field"]
# Don't warn on exclude fields # Don't warn on exclude fields
with warnings.catch_warnings(): with pytest.warns(None) as record:
warnings.simplefilter("error")
class Reporter4(DjangoObjectType): class Reporter4(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email", "first_name"] exclude = ["email", "first_name"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_neither_fields_nor_exclude(): def test_django_objecttype_neither_fields_nor_exclude():
@ -499,22 +467,24 @@ def test_django_objecttype_neither_fields_nor_exclude():
class Meta: class Meta:
model = ReporterModel model = ReporterModel
with warnings.catch_warnings(): with pytest.warns(None) as record:
warnings.simplefilter("error")
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
fields = ["email"] fields = ["email"]
with warnings.catch_warnings(): assert len(record) == 0
warnings.simplefilter("error")
with pytest.warns(None) as record:
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email"] exclude = ["email"]
assert len(record) == 0
def custom_enum_name(field): def custom_enum_name(field):
return f"CustomEnum{field.name.title()}" return f"CustomEnum{field.name.title()}"
@ -691,122 +661,6 @@ class TestDjangoObjectType:
}""" }"""
) )
def test_django_objecttype_convert_choices_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: String!
cuteness: Int!
}"""
)
def test_django_objecttype_convert_choices_true_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
fields = "__all__"
convert_choices_to_enum = True
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: TestsPetModelKindChoices!
cuteness: TestsPetModelCutenessChoices!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}
\"""An enumeration.\"""
enum TestsPetModelCutenessChoices {
\"""Kind of cute\"""
A_1
\"""Pretty cute\"""
A_2
\"""OMG SO CUTE!!!\"""
A_3
}"""
)
def test_django_objecttype_convert_choices_enum_list_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
convert_choices_to_enum = ["kind"]
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: TestsPetModelKindChoices!
cuteness: Int!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}"""
)
@with_local_registry @with_local_registry
def test_django_objecttype_name_connection_propagation(): def test_django_objecttype_name_connection_propagation():

View File

@ -1,5 +1,4 @@
import json import json
from http import HTTPStatus
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -38,7 +37,7 @@ def jl(**kwargs):
def test_graphiql_is_enabled(client): def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT="text/html") response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html" assert response["Content-Type"].split(";")[0] == "text/html"
@ -47,7 +46,7 @@ def test_qfactor_graphiql(client):
url_string(query="{test}"), url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9", HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "text/html" assert response["Content-Type"].split(";")[0] == "text/html"
@ -56,7 +55,7 @@ def test_qfactor_json(client):
url_string(query="{test}"), url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9", HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response["Content-Type"].split(";")[0] == "application/json" assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -64,7 +63,7 @@ def test_qfactor_json(client):
def test_allows_get_with_query_param(client): def test_allows_get_with_query_param(client):
response = client.get(url_string(query="{test}")) response = client.get(url_string(query="{test}"))
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -76,7 +75,7 @@ def test_allows_get_with_variable_values(client):
) )
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -95,7 +94,7 @@ def test_allows_get_with_operation_name(client):
) )
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -104,7 +103,7 @@ def test_allows_get_with_operation_name(client):
def test_reports_validation_errors(client): def test_reports_validation_errors(client):
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }")) response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -129,7 +128,7 @@ def test_errors_when_missing_operation_name(client):
) )
) )
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -147,7 +146,7 @@ def test_errors_when_sending_a_mutation_via_get(client):
""" """
) )
) )
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{"message": "Can only perform a mutation operation from a POST request."} {"message": "Can only perform a mutation operation from a POST request."}
@ -166,7 +165,7 @@ def test_errors_when_selecting_a_mutation_within_a_get(client):
) )
) )
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED assert response.status_code == 405
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{"message": "Can only perform a mutation operation from a POST request."} {"message": "Can only perform a mutation operation from a POST request."}
@ -185,14 +184,14 @@ def test_allows_mutation_to_exist_within_a_get(client):
) )
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_post_with_json_encoding(client): def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query="{test}"), "application/json") response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -201,7 +200,7 @@ def test_batch_allows_post_with_json_encoding(client):
batch_url_string(), jl(id=1, query="{test}"), "application/json" batch_url_string(), jl(id=1, query="{test}"), "application/json"
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello World"}, "status": 200} {"id": 1, "data": {"test": "Hello World"}, "status": 200}
] ]
@ -210,7 +209,7 @@ def test_batch_allows_post_with_json_encoding(client):
def test_batch_fails_if_is_empty(client): def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), "[]", "application/json") response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Received an empty list in the batch request."}] "errors": [{"message": "Received an empty list in the batch request."}]
} }
@ -223,7 +222,7 @@ def test_allows_sending_a_mutation_via_post(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}} assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
@ -234,7 +233,7 @@ def test_allows_post_with_url_encoding(client):
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -248,7 +247,7 @@ def test_supports_post_json_query_with_string_variables(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -263,7 +262,7 @@ def test_batch_supports_post_json_query_with_string_variables(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
] ]
@ -279,7 +278,7 @@ def test_supports_post_json_query_with_json_variables(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -294,7 +293,7 @@ def test_batch_supports_post_json_query_with_json_variables(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
] ]
@ -312,7 +311,7 @@ def test_supports_post_url_encoded_query_with_string_variables(client):
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -323,7 +322,7 @@ def test_supports_post_json_quey_with_get_variable_values(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -334,7 +333,7 @@ def test_post_url_encoded_query_with_get_variable_values(client):
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -345,7 +344,7 @@ def test_supports_post_raw_text_query_with_get_variable_values(client):
"application/graphql", "application/graphql",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -366,7 +365,7 @@ def test_allows_post_with_operation_name(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -390,7 +389,7 @@ def test_batch_allows_post_with_operation_name(client):
"application/json", "application/json",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == [ assert response_json(response) == [
{ {
"id": 1, "id": 1,
@ -414,7 +413,7 @@ def test_allows_post_with_get_operation_name(client):
"application/graphql", "application/graphql",
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -431,7 +430,7 @@ def test_inherited_class_with_attributes_works(client):
# Check graphiql works # Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html") response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
@pytest.mark.urls("graphene_django.tests.urls_pretty") @pytest.mark.urls("graphene_django.tests.urls_pretty")
@ -453,7 +452,7 @@ def test_supports_pretty_printing_by_request(client):
def test_handles_field_errors_caught_by_graphql(client): def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query="{thrower}")) response = client.get(url_string(query="{thrower}"))
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == { assert response_json(response) == {
"data": None, "data": None,
"errors": [ "errors": [
@ -468,7 +467,7 @@ def test_handles_field_errors_caught_by_graphql(client):
def test_handles_syntax_errors_caught_by_graphql(client): def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query="syntaxerror")) response = client.get(url_string(query="syntaxerror"))
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -482,7 +481,7 @@ def test_handles_syntax_errors_caught_by_graphql(client):
def test_handles_errors_caused_by_a_lack_of_query(client): def test_handles_errors_caused_by_a_lack_of_query(client):
response = client.get(url_string()) response = client.get(url_string())
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Must provide query string."}] "errors": [{"message": "Must provide query string."}]
} }
@ -491,7 +490,7 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
def test_handles_not_expected_json_bodies(client): def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), "[]", "application/json") response = client.post(url_string(), "[]", "application/json")
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "The received data is not a valid JSON query."}] "errors": [{"message": "The received data is not a valid JSON query."}]
} }
@ -500,7 +499,7 @@ def test_handles_not_expected_json_bodies(client):
def test_handles_invalid_json_bodies(client): def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), "[oh}", "application/json") response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -515,14 +514,14 @@ def test_handles_django_request_error(client, monkeypatch):
valid_json = json.dumps({"foo": "bar"}) valid_json = json.dumps({"foo": "bar"})
response = client.post(url_string(), valid_json, "application/json") response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == {"errors": [{"message": "foo-bar"}]} assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
def test_handles_incomplete_json_bodies(client): def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', "application/json") response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -534,7 +533,7 @@ def test_handles_plain_post_text(client):
"query helloWho($who: String){ test(who: $who) }", "query helloWho($who: String){ test(who: $who) }",
"text/plain", "text/plain",
) )
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Must provide query string."}] "errors": [{"message": "Must provide query string."}]
} }
@ -546,7 +545,7 @@ def test_handles_poorly_formed_variables(client):
query="query helloWho($who: String){ test(who: $who) }", variables="who:You" query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
) )
) )
assert response.status_code == HTTPStatus.BAD_REQUEST assert response.status_code == 400
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Variables are invalid JSON."}] "errors": [{"message": "Variables are invalid JSON."}]
} }
@ -554,7 +553,7 @@ def test_handles_poorly_formed_variables(client):
def test_handles_unsupported_http_methods(client): def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query="{test}")) response = client.put(url_string(query="{test}"))
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED assert response.status_code == 405
assert response["Allow"] == "GET, POST" assert response["Allow"] == "GET, POST"
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "GraphQL only supports GET and POST requests."}] "errors": [{"message": "GraphQL only supports GET and POST requests."}]
@ -564,7 +563,7 @@ def test_handles_unsupported_http_methods(client):
def test_passes_request_into_context_request(client): def test_passes_request_into_context_request(client):
response = client.get(url_string(query="{request}", q="testing")) response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == HTTPStatus.OK assert response.status_code == 200
assert response_json(response) == {"data": {"request": "testing"}} assert response_json(response) == {"data": {"request": "testing"}}
@ -828,97 +827,3 @@ def test_query_errors_atomic_request(set_rollback_mock, client):
def test_query_errors_non_atomic(set_rollback_mock, client): def test_query_errors_non_atomic(set_rollback_mock, client):
client.get(url_string(query="force error")) client.get(url_string(query="force error"))
set_rollback_mock.assert_not_called() set_rollback_mock.assert_not_called()
VALIDATION_URLS = [
"/graphql/validation/",
"/graphql/validation/alternative/",
"/graphql/validation/inherited/",
]
QUERY_WITH_TWO_INTROSPECTIONS = """
query Instrospection {
queryType: __schema {
queryType {name}
}
mutationType: __schema {
mutationType {name}
}
}
"""
N_INTROSPECTIONS = 2
INTROSPECTION_DISALLOWED_ERROR_MESSAGE = "introspection is disabled"
MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE = "too many validation errors"
@pytest.mark.urls("graphene_django.tests.urls_validation")
def test_allow_introspection(client):
response = client.post(
url_string("/graphql/", query="{__schema {queryType {name}}}")
)
assert response.status_code == HTTPStatus.OK
assert response_json(response) == {
"data": {"__schema": {"queryType": {"name": "QueryRoot"}}}
}
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
def test_validation_disallow_introspection(client, url):
response = client.post(url_string(url, query="{__schema {queryType {name}}}"))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
assert len(json_response["errors"]) == 1
error_message = json_response["errors"][0]["message"]
assert INTROSPECTION_DISALLOWED_ERROR_MESSAGE in error_message
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
@patch(
"graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", N_INTROSPECTIONS
)
def test_within_max_validation_errors(client, url):
response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
assert len(json_response["errors"]) == N_INTROSPECTIONS
error_messages = [error["message"].lower() for error in json_response["errors"]]
n_introspection_error_messages = sum(
INTROSPECTION_DISALLOWED_ERROR_MESSAGE in msg for msg in error_messages
)
assert n_introspection_error_messages == N_INTROSPECTIONS
assert all(
MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE not in msg for msg in error_messages
)
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
@patch("graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", 1)
def test_exceeds_max_validation_errors(client, url):
response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
error_messages = (error["message"].lower() for error in json_response["errors"])
assert any(MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE in msg for msg in error_messages)

View File

@ -1,26 +0,0 @@
from django.urls import path
from graphene.validation import DisableIntrospection
from ..views import GraphQLView
from .schema_view import schema
class View(GraphQLView):
schema = schema
class NoIntrospectionView(View):
validation_rules = (DisableIntrospection,)
class NoIntrospectionViewInherited(NoIntrospectionView):
pass
urlpatterns = [
path("graphql/", View.as_view()),
path("graphql/validation/", View.as_view(validation_rules=(DisableIntrospection,))),
path("graphql/validation/alternative/", NoIntrospectionView.as_view()),
path("graphql/validation/inherited/", NoIntrospectionViewInherited.as_view()),
]

View File

@ -23,7 +23,7 @@ ALL_FIELDS = "__all__"
def construct_fields( def construct_fields(
model, registry, only_fields, exclude_fields, convert_choices_to_enum=None model, registry, only_fields, exclude_fields, convert_choices_to_enum
): ):
_model_fields = get_model_fields(model) _model_fields = get_model_fields(model)
@ -47,7 +47,7 @@ def construct_fields(
continue continue
_convert_choices_to_enum = convert_choices_to_enum _convert_choices_to_enum = convert_choices_to_enum
if isinstance(_convert_choices_to_enum, list): if not isinstance(_convert_choices_to_enum, bool):
# then `convert_choices_to_enum` is a list of field names to convert # then `convert_choices_to_enum` is a list of field names to convert
if name in _convert_choices_to_enum: if name in _convert_choices_to_enum:
_convert_choices_to_enum = True _convert_choices_to_enum = True
@ -102,8 +102,10 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
if name in all_field_names: if name in all_field_names:
# Field is a custom field # Field is a custom field
warnings.warn( warnings.warn(
f'Excluding the custom field "{name}" on DjangoObjectType "{type_}" has no effect. ' (
'Either remove the custom field or remove the field from the "exclude" list.' 'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.'
).format(field_name=name, type_=type_)
) )
else: else:
if not hasattr(model, name): if not hasattr(model, name):
@ -146,7 +148,7 @@ class DjangoObjectType(ObjectType):
connection_class=None, connection_class=None,
use_connection=None, use_connection=None,
interfaces=(), interfaces=(),
convert_choices_to_enum=None, convert_choices_to_enum=True,
_meta=None, _meta=None,
**options, **options,
): ):

View File

@ -4,7 +4,6 @@ import warnings
from django.test import Client, TestCase, TransactionTestCase from django.test import Client, TestCase, TransactionTestCase
from graphene_django.settings import graphene_settings from graphene_django.settings import graphene_settings
from graphene_django.utils.utils import _DJANGO_VERSION_AT_LEAST_4_2
DEFAULT_GRAPHQL_URL = "/graphql" DEFAULT_GRAPHQL_URL = "/graphql"
@ -56,14 +55,8 @@ def graphql_query(
else: else:
body["variables"] = {"input": input_data} body["variables"] = {"input": input_data}
if headers: if headers:
header_params = (
{"headers": headers} if _DJANGO_VERSION_AT_LEAST_4_2 else headers
)
resp = client.post( resp = client.post(
graphql_url, graphql_url, json.dumps(body), content_type="application/json", **headers
json.dumps(body),
content_type="application/json",
**header_params,
) )
else: else:
resp = client.post( resp = client.post(

View File

@ -1,6 +1,5 @@
import inspect import inspect
import django
from django.db import connection, models, transaction from django.db import connection, models, transaction
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -111,7 +110,24 @@ def is_valid_django_model(model):
def import_single_dispatch(): def import_single_dispatch():
from functools import singledispatch try:
from functools import singledispatch
except ImportError:
singledispatch = None
if not singledispatch:
try:
from singledispatch import singledispatch
except ImportError:
pass
if not singledispatch:
raise Exception(
"It seems your python version does not include "
"functools.singledispatch. Please install the 'singledispatch' "
"package. More information here: "
"https://pypi.python.org/pypi/singledispatch"
)
return singledispatch return singledispatch
@ -129,8 +145,3 @@ def bypass_get_queryset(resolver):
""" """
resolver._bypass_get_queryset = True resolver._bypass_get_queryset = True
return resolver return resolver
_DJANGO_VERSION_AT_LEAST_4_2 = django.VERSION[0] > 4 or (
django.VERSION[0] >= 4 and django.VERSION[1] >= 2
)

View File

@ -9,17 +9,10 @@ from django.shortcuts import render
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View from django.views.generic import View
from graphql import ( from graphql import OperationType, get_operation_ast, parse
ExecutionResult,
OperationType,
execute,
get_operation_ast,
parse,
validate_schema,
)
from graphql.error import GraphQLError from graphql.error import GraphQLError
from graphql.execution import ExecutionResult
from graphql.execution.middleware import MiddlewareManager from graphql.execution.middleware import MiddlewareManager
from graphql.validation import validate
from graphene import Schema from graphene import Schema
from graphene_django.constants import MUTATION_ERRORS_FLAG from graphene_django.constants import MUTATION_ERRORS_FLAG
@ -96,7 +89,6 @@ class GraphQLView(View):
batch = False batch = False
subscription_path = None subscription_path = None
execution_context_class = None execution_context_class = None
validation_rules = None
def __init__( def __init__(
self, self,
@ -108,7 +100,6 @@ class GraphQLView(View):
batch=False, batch=False,
subscription_path=None, subscription_path=None,
execution_context_class=None, execution_context_class=None,
validation_rules=None,
): ):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
@ -137,8 +128,6 @@ class GraphQLView(View):
), "A Schema is required to be provided to GraphQLView." ), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing" assert not all((graphiql, batch)), "Use either graphiql or batch processing"
self.validation_rules = validation_rules or self.validation_rules
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def get_root_value(self, request): def get_root_value(self, request):
return self.root_value return self.root_value
@ -178,13 +167,11 @@ class GraphQLView(View):
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version, graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version,
graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri, graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri,
graphiql_plugin_explorer_css_sri=self.graphiql_plugin_explorer_css_sri,
# The SUBSCRIPTION_PATH setting. # The SUBSCRIPTION_PATH setting.
subscription_path=self.subscription_path, subscription_path=self.subscription_path,
# GraphiQL headers tab, # GraphiQL headers tab,
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
graphiql_input_value_deprecation=graphene_settings.GRAPHIQL_INPUT_VALUE_DEPRECATION,
) )
if self.batch: if self.batch:
@ -306,61 +293,43 @@ class GraphQLView(View):
return None return None
raise HttpError(HttpResponseBadRequest("Must provide query string.")) raise HttpError(HttpResponseBadRequest("Must provide query string."))
schema = self.schema.graphql_schema
schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
return ExecutionResult(data=None, errors=schema_validation_errors)
try: try:
document = parse(query) document = parse(query)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e]) return ExecutionResult(errors=[e])
operation_ast = get_operation_ast(document, operation_name) if request.method.lower() == "get":
operation_ast = get_operation_ast(document, operation_name)
if operation_ast and operation_ast.operation != OperationType.QUERY:
if show_graphiql:
return None
if ( raise HttpError(
request.method.lower() == "get" HttpResponseNotAllowed(
and operation_ast is not None ["POST"],
and operation_ast.operation != OperationType.QUERY "Can only perform a {} operation from a POST request.".format(
): operation_ast.operation.value
if show_graphiql: ),
return None )
raise HttpError(
HttpResponseNotAllowed(
["POST"],
"Can only perform a {} operation from a POST request.".format(
operation_ast.operation.value
),
) )
)
validation_errors = validate(
schema,
document,
self.validation_rules,
graphene_settings.MAX_VALIDATION_ERRORS,
)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
try: try:
execute_options = { extra_options = {}
if self.execution_context_class:
extra_options["execution_context_class"] = self.execution_context_class
options = {
"source": query,
"root_value": self.get_root_value(request), "root_value": self.get_root_value(request),
"context_value": self.get_context(request),
"variable_values": variables, "variable_values": variables,
"operation_name": operation_name, "operation_name": operation_name,
"context_value": self.get_context(request),
"middleware": self.get_middleware(request), "middleware": self.get_middleware(request),
} }
if self.execution_context_class: options.update(extra_options)
execute_options[
"execution_context_class"
] = self.execution_context_class
operation_ast = get_operation_ast(document, operation_name)
if ( if (
operation_ast is not None operation_ast
and operation_ast.operation == OperationType.MUTATION and operation_ast.operation == OperationType.MUTATION
and ( and (
graphene_settings.ATOMIC_MUTATIONS is True graphene_settings.ATOMIC_MUTATIONS is True
@ -368,12 +337,12 @@ class GraphQLView(View):
) )
): ):
with transaction.atomic(): with transaction.atomic():
result = execute(schema, document, **execute_options) result = self.schema.execute(**options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True: if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
transaction.set_rollback(True) transaction.set_rollback(True)
return result return result
return execute(schema, document, **execute_options) return self.schema.execute(**options)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e]) return ExecutionResult(errors=[e])

View File

@ -10,7 +10,3 @@ omit = */tests/*
[tool:pytest] [tool:pytest]
DJANGO_SETTINGS_MODULE = examples.django_test_settings DJANGO_SETTINGS_MODULE = examples.django_test_settings
addopts = --random-order addopts = --random-order
filterwarnings =
error
# we can't do anything about the DeprecationWarning about typing.ByteString in graphql
default:'typing\.ByteString' is deprecated:DeprecationWarning:graphql\.pyutils\.is_iterable

View File

@ -26,7 +26,8 @@ tests_require = [
dev_requires = [ dev_requires = [
"ruff==0.1.2", "black==23.7.0",
"ruff==0.0.283",
"pre-commit", "pre-commit",
] + tests_require ] + tests_require
@ -49,14 +50,11 @@ setup(
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Django", "Framework :: Django",
"Framework :: Django :: 3.2", "Framework :: Django :: 3.2",
"Framework :: Django :: 4.1", "Framework :: Django :: 4.1",
"Framework :: Django :: 4.2", "Framework :: Django :: 4.2",
"Framework :: Django :: 5.1",
"Framework :: Django :: 5.2",
], ],
keywords="api graphql protocol rest relay graphene", keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests", "examples", "examples.*"]), packages=find_packages(exclude=["tests", "examples", "examples.*"]),

13
tox.ini
View File

@ -1,8 +1,8 @@
[tox] [tox]
envlist = envlist =
py{38,39,310}-django32 py{38,39,310}-django32
py{38,39}-django42 py{38,39}-django{41,42}
py{310,311,312}-django{42,50,51,main} py{310,311}-django{41,42,main}
pre-commit pre-commit
[gh-actions] [gh-actions]
@ -11,15 +11,12 @@ python =
3.9: py39 3.9: py39
3.10: py310 3.10: py310
3.11: py311 3.11: py311
3.12: py312
[gh-actions:env] [gh-actions:env]
DJANGO = DJANGO =
3.2: django32 3.2: django32
4.1: django41
4.2: django42 4.2: django42
5.0: django50
5.1: django51
5.2: django52
main: djangomain main: djangomain
[testenv] [testenv]
@ -32,10 +29,8 @@ deps =
-e.[test] -e.[test]
psycopg2-binary psycopg2-binary
django32: Django>=3.2,<4.0 django32: Django>=3.2,<4.0
django41: Django>=4.1,<4.2
django42: Django>=4.2,<4.3 django42: Django>=4.2,<4.3
django50: Django>=5.0,<5.1
django51: Django>=5.1,<5.2
django52: Django>=5.2,<6.0
djangomain: https://github.com/django/django/archive/main.zip djangomain: https://github.com/django/django/archive/main.zip
commands = {posargs:pytest --cov=graphene_django graphene_django examples} commands = {posargs:pytest --cov=graphene_django graphene_django examples}