Compare commits

...

59 Commits
v3.1.1 ... main

Author SHA1 Message Date
dependabot[bot]
f02ea337a2
Bump django from 3.2.25 to 4.2.18 in /examples/cookbook (#1543)
Bumps [django](https://github.com/django/django) from 3.2.25 to 4.2.18.
- [Commits](https://github.com/django/django/compare/3.2.25...4.2.18)

---
updated-dependencies:
- dependency-name: django
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-23 07:01:33 -07:00
Jeongseok Kang
ad26bfa2f6
ci: Upgrade GitHub Actions versions (#1546)
* ci: Upgrade actions/checkout

* ci: Upgrade actions/setup-python
2025-06-23 07:00:48 -07:00
Jeongseok Kang
788a20490a
chore: Add support for Django 5.2 (#1544)
* chore: Add support for Django 5.2

* chore: Update setup.py
2025-06-23 06:59:21 -07:00
Firas Kafri
c52cf2b045
Bump version to 3.2.3 2025-03-13 11:29:45 +03:00
Florian Zimmermann
e69e4a0399
Bugfix: call resolver function in DjangoConnectionField as documented (#1529)
* treat warnings as errors when running the tests

* silence warnings

* bugfix: let DjangoConnectionField call its resolver function

that is, the one specified using DjangoConnectionField(..., resolver=some_func)

* ignore the DeprecationWarning about typing.ByteString in graphql
2025-03-13 11:25:48 +03:00
Sergey Fursov
97deb761e9
fix typed choices, make working with different Django 5x choices options (#1539)
* fix typed choices, make working with different Django 5x choices options

* remove `graphene_django/compat.py` from ruff exclusions
2025-03-13 11:23:51 +03:00
Sergey Fursov
8d4a64a40d
add official Django 5.1 support (#1540) 2024-12-27 13:46:47 +08:00
Alexandre Detiste
269225085d
remove dead code: singledispatch has been in the standard library ... (#1534)
* remove dead code: singledispatch has been in the stard library for many years

(BTW this function does not seems to be used anywhere anymore)

* lint
2024-09-15 21:50:15 +07:00
Markus Richter
28c71c58f7 Bump to 3.2.2 2024-06-12 10:52:45 +08:00
Kien Dang
6f21dc7a94
Not require explicitly set ordering in DjangoConnectionField (#1518)
* Revert "feat!: check django model has a default ordering when used in a relay connection (#1495)"

This reverts commit 96c09ac439.

* Fix assert no warning for pytest>=8
2024-04-18 12:00:31 +08:00
Ülgen Sarıkavak
ea45de02ad
Make use of http.HTTPStatus for response status code checks (#1487) 2024-04-09 03:43:34 +03:00
dependabot[bot]
eac113e136
Bump django from 3.2.24 to 3.2.25 in /examples/cookbook (#1508)
Bumps [django](https://github.com/django/django) from 3.2.24 to 3.2.25.
- [Commits](https://github.com/django/django/compare/3.2.24...3.2.25)

---
updated-dependencies:
- dependency-name: django
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-09 03:39:21 +03:00
Kien Dang
d69c90550f
Bump to 3.2.1 (#1512) 2024-04-09 03:37:32 +03:00
Pablo Alexis Domínguez Grau
3f813d4679
Fix ReadTheDocs builds (#1509)
* Add RTD config file

* Doc fixes to reference main branch instead of master
2024-03-29 12:11:56 +08:00
Alisson Patricio
45c2aa09b5
Allows field's choices to be a callable (#1497)
* Allows field's choices to be a callable

Starting in Django 5 field's choices can also be a callable

* test if field with callable choices converts into enum

---------

Co-authored-by: Kien Dang <mail@kien.ai>
2024-03-21 00:48:51 +08:00
Diogo Silva
ac09cd2967
fix: Fix broke 'get_choices' with restframework 3.15.0 (#1506) 2024-03-18 09:58:47 +08:00
dependabot[bot]
54372b41d5
Bump django from 3.1.14 to 3.2.24 in /examples/cookbook (#1498)
Bumps [django](https://github.com/django/django) from 3.1.14 to 3.2.24.
- [Commits](https://github.com/django/django/compare/3.1.14...3.2.24)

---
updated-dependencies:
- dependency-name: django
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-02-08 10:50:13 +08:00
Thomas Leonard
96c09ac439
feat!: check django model has a default ordering when used in a relay connection (#1495)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
2024-01-30 12:09:18 +03:00
Laurent
b85177cebf
fix: same type list (#1492)
* fix: same type list

* chore: improve test
2024-01-20 16:36:00 +08:00
Firas Kafri
4d0484f312
Bump version 2023-12-20 13:22:33 +03:00
Noxx
c416a2b0f5
Provide setting to enable/disable converting choices to enums globally (#1477)
Co-authored-by: Firas Kafri <3097061+firaskafri@users.noreply.github.com>
Co-authored-by: Kien Dang <mail@kien.ai>
2023-12-20 17:55:15 +08:00
Kien Dang
feb7252b8a
Add support for validation rules (#1475)
* Add support for validation rules

* Enable customizing validate max_errors through settings

* Add tests for validation rules

* Add examples for validation rules

* Allow setting validation_rules in class def

* Add tests for validation_rules inherited from parent class

* Make tests for validation rules stricter
2023-12-20 12:48:45 +03:00
Firas Kafri
3a64994e52
Bump version (#1486) 2023-12-20 12:44:40 +03:00
Kien Dang
db2d40ec94
Remove Django 4.1 (EOL) and add Django 5.0 to CI (#1483) 2023-12-14 11:20:54 +03:00
Kien Dang
62126dd467
Add Python 3.12 to CI (#1481) 2023-12-05 22:11:00 +03:00
danthewildcat
e735f5dbdb
Optimize views (#1439)
* Optimize execute_graphql_request

* Require operation_ast to be found by view handler

* Remove unused show_graphiql kwarg

* Old style if syntax

* Revert "Remove unused show_graphiql kwarg"

This reverts commit 33b3426092.

* Add missing schema validation step

* Pass args directly to improve clarity

* Remove duplicated operation_ast not None check

---------

Co-authored-by: Firas Kafri <3097061+firaskafri@users.noreply.github.com>
Co-authored-by: Kien Dang <mail@kien.ai>
2023-10-29 23:42:27 +08:00
Kien Dang
36cf100e8b
Use ruff format to replace black (#1473)
* Use ruff format to replace black

* Adjust ruff config to be compatible with ruff-format

https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules

* Format

* Replace black with ruff format in Makefile
2023-10-25 11:33:00 +03:00
Kien Dang
e8f36b018d
Fix test Client headers for Django 4.2 (#1465)
* Fix test Client headers for Django 4.2

* Lazy import pkg_resources

since it could be quite heavy

* Remove use of pkg_resources altogether
2023-09-18 23:23:53 +08:00
mnasiri
83d3d27f14
Fix graphiql explorer styles by sending graphiql_plugin_explorer_css_sri param to render_graphiql function of the GraphQlView (#1418) (#1460) 2023-09-14 00:26:18 +08:00
Romain Létendart
ee7560f629
Support displaying deprecated input fields in GraphiQL docs (#1458)
* Update GraphiQL docs URL in docs/settings

And deduplicate link declaration.

* Support displaying deprecated input fields in GraphiQL docs
2023-09-13 09:49:01 +03:00
lilac-supernova-2
67def2e074
Typo fixes (#1459)
* Fix Star Wars spaceship name

* Fix some typos in comments

* Typo fixes

* More typo fixes
2023-09-06 10:29:58 +03:00
mahmoudmostafa0
e49a01c189
adding optional_field in Serializermutation to enfore some fields to be optional (#1455)
* adding optional_fields to enforce fields to be optional

* adding support for all

* adding unit tests

* Update graphene_django/rest_framework/mutation.py

Co-authored-by: Kien Dang <kiend@pm.me>

* linting

* linting

* add missing import

---------

Co-authored-by: Kien Dang <kiend@pm.me>
2023-08-28 00:15:35 +03:00
Thomas Leonard
0473f1a9a3
fix: empty list is not an empty value for list filters even when a custom filtering method is provided (#1450)
Co-authored-by: Thomas Leonard <thomas@loftorbital.com>
2023-08-11 23:24:58 +08:00
Kien Dang
720db1f987
Only release on pypi after tests pass (#1452) 2023-08-11 09:51:59 +03:00
Firas Kafri
4ac3f3f42d
Update __init__.py 2023-08-10 01:12:15 +03:00
Firas Kafri
ee7598e71a
Remove typo 2023-08-09 23:41:57 +03:00
Firas Kafri
05d7fb5396
Bump version 2023-08-09 20:49:51 +03:00
Kien Dang
79b4a23ae0
Miscellaneous CI fixes (#1447)
* Update Makefile

* django master requires at least python 3.10 now

* Allow customizing options passed to tox -e pre-commit

* py.test -> pytest

* Update ruff

* Fix E721

Do not compare types, use `isinstance()`

* Add back black to dev dependencies

* Pin black and ruff versions
2023-08-09 20:48:42 +03:00
Laurent
db34d2e815
fix: foreign key nullable and custom resolver (#1446)
* fix: nullable one to one relation

* fix: makefile
2023-08-09 20:28:26 +03:00
Kien Dang
9a773b9d7b
Use ruff in pre-commit (#1441)
* Use ruff in pre-commit

* Add pyupgrade

* Add isort

* Add bugbear

* Fix B015 Pointless comparison

* Fix B026

* B018 false positive

* Remove flake8 and isort config from setup.cfg

* Remove black and flake8 from dev dependencies

* Update black

* Show list of fixes applied with autofix on

* Fix typo

* Add C4 flake8-comprehensions

* Add ruff to dev dependencies

* Fix up
2023-08-06 01:47:00 +03:00
Kien Dang
45a732f1db
Prevent duplicate CI runs, also work with PRs from forks (#1443)
* Prevent duplicate CI runs

* Trigger CI on pull requests from forks
2023-08-06 01:45:10 +03:00
Kien Dang
5eb5fe294a
Remove Python 3.7 (EOL since EOL since 2023-06-27) from CI (#1440)
* Remove Python 3.7 (EOL since EOL since 2023-06-27) from CI

* Remove unused context

* Use pyupgrade --py38-plus in pre-commit
2023-08-04 11:15:23 +03:00
James
5d7a04fce9
Update mutation.py to serialize Enum objects into input values (#1431)
* Fix for issue #1385: Update mutation.py to serialize Enum objects into input values for ChoiceFields

* Update graphene_django/rest_framework/mutation.py

Co-authored-by: Steven DeMartini <1647130+sjdemartini@users.noreply.github.com>

---------

Co-authored-by: Steven DeMartini <1647130+sjdemartini@users.noreply.github.com>
2023-07-27 02:41:40 +03:00
Firas Kafri
3172710d12
exclude 'fans' from ReporterForm tests (#1434) 2023-07-18 20:35:51 +03:00
Tom Dror
b1abebdb97
Support base class relations and reverse for proxy models (#1380)
* support reverse relationship for proxy models

* support multi table inheritence

* update query test for multi table inheritance

* remove debugger

* support local many to many in model inheritance

* format and lint

---------

Co-authored-by: Firas K <3097061+firaskafri@users.noreply.github.com>
2023-07-18 20:17:45 +03:00
Laurent
0de35ca3b0
fix: fk resolver permissions leak (#1411)
* fix: fk resolver permissions leak

* fix: only one query for 1o1 relation

* tests: added queries count check

* fix: docstring

* fix: typo

* docs: added warning to authorization

* feat: added bypass_get_queryset decorator
2023-07-18 15:16:52 +03:00
Firas Kafri
2fafa881a8
Bump version 2023-07-18 15:13:58 +03:00
Steven DeMartini
cd43022283
Maintain JSONField in graphene_django.compat module (#1429)
Fixes https://github.com/graphql-python/graphene-django/issues/1428

This should improve backwards compatibility, fixing issues in downstream
packages (notably graphene-django-cud
https://github.com/tOgg1/graphene-django-cud/issues/109, and also
graphene-django-extras, both of which depended on
`graphene_django.compat.JSONField`).

Co-authored-by: Steven DeMartini <sjdemartini@users.noreply.github.com>
2023-07-18 15:11:30 +03:00
Jeongseok Kang
3f061a0c50
docs: Update location of GraphQL Relay Specification (#1432) 2023-07-18 15:10:22 +03:00
Firas Kafri
e950164c8e
Bump version to 3.1.2 2023-06-17 09:29:18 +03:00
Steven DeMartini
2358bd30a4
Update compat.py MissingType results after PGJSONField removal (#1423)
As mentioned in https://github.com/graphql-python/graphene-django/pull/1421/files#r1221711648
2023-06-07 20:06:37 +03:00
Dulmandakh
3e7a16af73
CI: remove Django 4.0 (#1422)
* CI: remove Django 4.0

* fix tags
2023-06-07 17:36:51 +03:00
Dulmandakh
8fa8aea3c0
remove JSONField compat (#1421)
* remove JSONFIeld compat

* fix black
2023-06-07 17:36:29 +03:00
Dulmandakh
c925a32dc3
CI: add Django 4.2 (#1420)
* CI: add Django 4.2

* fix tox
2023-06-07 16:52:40 +03:00
Sezgin ACER
8934393909
Add check for serializers.HiddenField on fields_for_serializer function (#1419)
* Add check for `serializers.HiddenField` on fields_for_serializer function

* Add pre-commit changes
2023-06-06 09:20:32 +03:00
Steven DeMartini
520ddeabf6
Fix graphiql explorer styles by including official CSS, and update local example app for testing (#1418)
* Add venv and .venv to gitignore since common venv paths

* Update cookbook-plain app requirements and local-dev notes

This also adds the DEFAULT_AUTO_FIELD to the app's Django settings to
resolve this warning when running `migrate`:

```
ingredients.Category: (models.W042) Auto-created primary key used when not defining a primary key type, by default 'django.db.models.AutoField'.
	HINT: Configure the DEFAULT_AUTO_FIELD setting or the IngredientsConfig.default_auto_field attribute to point to a subclass of AutoField, e.g. 'django.db.models.BigAutoField'.
```

* Fix #1417 graphiql explorer styles by including official CSS

Like in the official graphiql-plugin-explorer example here
6198646919/packages/graphiql-plugin-explorer/examples/index.html (L26-L29)

Resolves https://github.com/graphql-python/graphene-django/issues/1417

* Update GraphiQL version

---------

Co-authored-by: Steven DeMartini <sjdemartini@users.noreply.github.com>
Co-authored-by: Kien Dang <mail@kien.ai>
2023-06-02 11:48:53 +03:00
Kien Dang
38709d8396
Correct schema write test (#1416)
<Mock>.called_once() just returns a Mock, so assert <Mock>.called_once()
always passes. We want <Mock>.assert_called_once().
2023-05-27 16:53:22 +03:00
Kien Dang
63fd98393f
Set pypi GH action to latest v1 (#1415) 2023-05-27 16:26:52 +03:00
Firas Kafri
4e5acd4702
Fix linting issues (#1412) 2023-05-24 16:13:23 +03:00
99 changed files with 2748 additions and 674 deletions

View File

@ -6,13 +6,18 @@ on:
- 'v*' - 'v*'
jobs: jobs:
build: lint:
uses: ./.github/workflows/lint.yml
tests:
uses: ./.github/workflows/tests.yml
release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [lint, tests]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python 3.11 - name: Set up Python 3.11
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: '3.11' python-version: '3.11'
- name: Build wheel and source tarball - name: Build wheel and source tarball
@ -20,7 +25,7 @@ jobs:
pip install wheel pip install wheel
python setup.py sdist bdist_wheel python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI - name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@v1.8.6 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
user: __token__ user: __token__
password: ${{ secrets.pypi_password }} password: ${{ secrets.pypi_password }}

View File

@ -1,15 +1,19 @@
name: Lint name: Lint
on: [push, pull_request] on:
push:
branches: ["main"]
pull_request:
workflow_call:
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python 3.11 - name: Set up Python 3.11
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: '3.11' python-version: '3.11'
- name: Install dependencies - name: Install dependencies

View File

@ -1,6 +1,10 @@
name: Tests name: Tests
on: [push, pull_request] on:
push:
branches: ["main"]
pull_request:
workflow_call:
jobs: jobs:
build: build:
@ -8,17 +12,29 @@ jobs:
strategy: strategy:
max-parallel: 4 max-parallel: 4
matrix: matrix:
django: ["3.2", "4.0", "4.1"] django: ["3.2", "4.2", "5.0", "5.1", "5.2"]
python-version: ["3.8", "3.9", "3.10"] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
include: exclude:
- django: "3.2" - django: "3.2"
python-version: "3.7"
- django: "4.1"
python-version: "3.11" python-version: "3.11"
- django: "3.2"
python-version: "3.12"
- django: "5.0"
python-version: "3.8"
- django: "5.0"
python-version: "3.9"
- django: "5.1"
python-version: "3.8"
- django: "5.1"
python-version: "3.9"
- django: "5.2"
python-version: "3.8"
- django: "5.2"
python-version: "3.9"
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
@ -29,4 +45,3 @@ jobs:
run: tox run: tox
env: env:
DJANGO: ${{ matrix.django }} DJANGO: ${{ matrix.django }}
TOXENV: ${{ matrix.toxenv }}

3
.gitignore vendored
View File

@ -11,6 +11,9 @@ __pycache__/
# Distribution / packaging # Distribution / packaging
.Python .Python
env/ env/
.env/
venv/
.venv/
build/ build/
develop-eggs/ develop-eggs/
dist/ dist/

View File

@ -2,7 +2,7 @@ default_language_version:
python: python3.11 python: python3.11
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 rev: v4.5.0
hooks: hooks:
- id: check-merge-conflict - id: check-merge-conflict
- id: check-json - id: check-json
@ -15,16 +15,9 @@ repos:
- --autofix - --autofix
- id: trailing-whitespace - id: trailing-whitespace
exclude: README.md exclude: README.md
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v3.3.2 rev: v0.1.2
hooks: hooks:
- id: pyupgrade - id: ruff
args: [--py37-plus] args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- repo: https://github.com/psf/black - id: ruff-format
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8

18
.readthedocs.yaml Normal file
View File

@ -0,0 +1,18 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: docs/conf.py
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt

32
.ruff.toml Normal file
View File

@ -0,0 +1,32 @@
select = [
"E", # pycodestyle
"W", # pycodestyle
"F", # pyflake
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
ignore = [
"E501", # line-too-long
"B017", # pytest.raises(Exception) should be considered evil
"B028", # warnings.warn called without an explicit stacklevel keyword argument
"B904", # check for raise statements in exception handlers that lack a from clause
"W191", # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
]
exclude = [
"**/docs",
]
target-version = "py38"
[per-file-ignores]
# Ignore unused imports (F401) in these files
"__init__.py" = ["F401"]
[isort]
known-first-party = ["graphene", "graphene-django"]
known-local-folder = ["cookbook"]
combine-as-imports = true

View File

@ -33,7 +33,7 @@ make tests
## Opening Pull Requests ## Opening Pull Requests
Please fork the project and open a pull request against the master branch. Please fork the project and open a pull request against the `main` branch.
This will trigger a series of test and lint checks. This will trigger a series of test and lint checks.

View File

@ -10,15 +10,15 @@ dev-setup:
.PHONY: tests ## Run unit tests .PHONY: tests ## Run unit tests
tests: tests:
py.test graphene_django --cov=graphene_django -vv PYTHONPATH=. pytest graphene_django --cov=graphene_django -vv
.PHONY: format ## Format code .PHONY: format ## Format code
format: format:
black graphene_django examples setup.py ruff format graphene_django examples setup.py
.PHONY: lint ## Lint code .PHONY: lint ## Lint code
lint: lint:
flake8 graphene_django examples ruff graphene_django examples
.PHONY: docs ## Generate docs .PHONY: docs ## Generate docs
docs: dev-setup docs: dev-setup

View File

@ -30,7 +30,7 @@ Graphene-Django is an open-source library that provides seamless integration bet
To install Graphene-Django, run the following command: To install Graphene-Django, run the following command:
``` ```sh
pip install graphene-django pip install graphene-django
``` ```
@ -114,11 +114,11 @@ class MyModelAPITestCase(GraphQLTestCase):
## Contributing ## Contributing
Contributions to Graphene-Django are always welcome! To get started, check the repository's [issue tracker](https://github.com/graphql-python/graphene-django/issues) and [contribution guidelines](https://github.com/graphql-python/graphene-django/blob/master/CONTRIBUTING.md). Contributions to Graphene-Django are always welcome! To get started, check the repository's [issue tracker](https://github.com/graphql-python/graphene-django/issues) and [contribution guidelines](https://github.com/graphql-python/graphene-django/blob/main/CONTRIBUTING.md).
## License ## License
Graphene-Django is released under the [MIT License](https://github.com/graphql-python/graphene-django/blob/master/LICENSE). Graphene-Django is released under the [MIT License](https://github.com/graphql-python/graphene-django/blob/main/LICENSE).
## Resources ## Resources

View File

@ -144,6 +144,21 @@ If you are using ``DjangoObjectType`` you can define a custom `get_queryset`.
return queryset.filter(published=True) return queryset.filter(published=True)
return queryset return queryset
.. warning::
Defining a custom ``get_queryset`` gives the guaranteed it will be called
when resolving the ``DjangoObjectType``, even through related objects.
Note that because of this, benefits from using ``select_related``
in objects that define a relation to this ``DjangoObjectType`` will be canceled out.
In the case of ``prefetch_related``, the benefits of the optimization will be lost only
if the custom ``get_queryset`` modifies the queryset. For more information about this, refers
to Django documentation about ``prefetch_related``: https://docs.djangoproject.com/en/4.2/ref/models/querysets/#prefetch-related.
If you want to explicitly disable the execution of the custom ``get_queryset`` when resolving,
you can decorate the resolver with `@graphene_django.bypass_get_queryset`. Note that this
can lead to authorization leaks if you are performing authorization checks in the custom
``get_queryset``.
Filtering ID-based Node Access Filtering ID-based Node Access
------------------------------ ------------------------------

View File

@ -33,5 +33,6 @@ For more advanced use, check out the Relay tutorial.
authorization authorization
debug debug
introspection introspection
validation
testing testing
settings settings

View File

@ -6,7 +6,7 @@ Graphene-Django can be customised using settings. This page explains each settin
Usage Usage
----- -----
Add settings to your Django project by creating a Dictonary with name ``GRAPHENE`` in the project's ``settings.py``: Add settings to your Django project by creating a Dictionary with name ``GRAPHENE`` in the project's ``settings.py``:
.. code:: python .. code:: python
@ -142,6 +142,15 @@ Default: ``False``
# ] # ]
``DJANGO_CHOICE_FIELD_ENUM_CONVERT``
--------------------------------------
When set to ``True`` Django choice fields are automatically converted into Enum types.
Can be disabled globally by setting it to ``False``.
Default: ``True``
``DJANGO_CHOICE_FIELD_ENUM_V2_NAMING`` ``DJANGO_CHOICE_FIELD_ENUM_V2_NAMING``
-------------------------------------- --------------------------------------
@ -197,9 +206,6 @@ Set to ``False`` if you want to disable GraphiQL headers editor tab for some rea
This setting is passed to ``headerEditorEnabled`` GraphiQL options, for details refer to GraphiQLDocs_. This setting is passed to ``headerEditorEnabled`` GraphiQL options, for details refer to GraphiQLDocs_.
.. _GraphiQLDocs: https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
Default: ``True`` Default: ``True``
.. code:: python .. code:: python
@ -230,8 +236,6 @@ Set to ``True`` if you want to persist GraphiQL headers after refreshing the pag
This setting is passed to ``shouldPersistHeaders`` GraphiQL options, for details refer to GraphiQLDocs_. This setting is passed to ``shouldPersistHeaders`` GraphiQL options, for details refer to GraphiQLDocs_.
.. _GraphiQLDocs: https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
Default: ``False`` Default: ``False``
@ -240,3 +244,48 @@ Default: ``False``
GRAPHENE = { GRAPHENE = {
'GRAPHIQL_SHOULD_PERSIST_HEADERS': False, 'GRAPHIQL_SHOULD_PERSIST_HEADERS': False,
} }
``GRAPHIQL_INPUT_VALUE_DEPRECATION``
------------------------------------
Set to ``True`` if you want GraphiQL to show any deprecated fields on input object types' docs.
For example, having this schema:
.. code:: python
class MyMutationInputType(graphene.InputObjectType):
old_field = graphene.String(deprecation_reason="You should now use 'newField' instead.")
new_field = graphene.String()
class MyMutation(graphene.Mutation):
class Arguments:
input = types.MyMutationInputType()
GraphiQL will add a ``Show Deprecated Fields`` button to toggle information display on ``oldField`` and its deprecation
reason. Otherwise, you would get neither a button nor any information at all on ``oldField``.
This setting is passed to ``inputValueDeprecation`` GraphiQL options, for details refer to GraphiQLDocs_.
Default: ``False``
.. code:: python
GRAPHENE = {
'GRAPHIQL_INPUT_VALUE_DEPRECATION': False,
}
.. _GraphiQLDocs: https://graphiql-test.netlify.app/typedoc/modules/graphiql_react#graphiqlprovider-2
``MAX_VALIDATION_ERRORS``
------------------------------------
In case ``validation_rules`` are provided to ``GraphQLView``, if this is set to a non-negative ``int`` value,
``graphql.validation.validate`` will stop validation after this number of errors has been reached.
If not set or set to ``None``, the maximum number of errors will follow ``graphql.validation.validate`` default
*i.e.* 100.
Default: ``None``

View File

@ -104,7 +104,7 @@ Load some test data
Now is a good time to load up some test data. The easiest option will be Now is a good time to load up some test data. The easiest option will be
to `download the to `download the
ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/master/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__ ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/main/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__
fixture and place it in fixture and place it in
``cookbook/ingredients/fixtures/ingredients.json``. You can then run the ``cookbook/ingredients/fixtures/ingredients.json``. You can then run the
following: following:

View File

@ -7,12 +7,12 @@ Graphene has a number of additional features that are designed to make
working with Django *really simple*. working with Django *really simple*.
Note: The code in this quickstart is pulled from the `cookbook example Note: The code in this quickstart is pulled from the `cookbook example
app <https://github.com/graphql-python/graphene-django/tree/master/examples/cookbook>`__. app <https://github.com/graphql-python/graphene-django/tree/main/examples/cookbook>`__.
A good idea is to check the following things first: A good idea is to check the following things first:
* `Graphene Relay documentation <http://docs.graphene-python.org/en/latest/relay/>`__ * `Graphene Relay documentation <http://docs.graphene-python.org/en/latest/relay/>`__
* `GraphQL Relay Specification <https://facebook.github.io/relay/docs/en/graphql-server-specification.html>`__ * `GraphQL Relay Specification <https://relay.dev/docs/guides/graphql-server-specification/>`__
Setup the Django project Setup the Django project
------------------------ ------------------------
@ -87,7 +87,7 @@ Load some test data
Now is a good time to load up some test data. The easiest option will be Now is a good time to load up some test data. The easiest option will be
to `download the to `download the
ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/master/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__ ingredients.json <https://raw.githubusercontent.com/graphql-python/graphene-django/main/examples/cookbook/cookbook/ingredients/fixtures/ingredients.json>`__
fixture and place it in fixture and place it in
``cookbook/ingredients/fixtures/ingredients.json``. You can then run the ``cookbook/ingredients/fixtures/ingredients.json``. You can then run the
following: following:

29
docs/validation.rst Normal file
View File

@ -0,0 +1,29 @@
Query Validation
================
Graphene-Django supports query validation by allowing passing a list of validation rules (subclasses of `ValidationRule <https://github.com/graphql-python/graphql-core/blob/v3.2.3/src/graphql/validation/rules/__init__.py>`_ from graphql-core) to the ``validation_rules`` option in ``GraphQLView``.
.. code:: python
from django.urls import path
from graphene.validation import DisableIntrospection
from graphene_django.views import GraphQLView
urlpatterns = [
path("graphql", GraphQLView.as_view(validation_rules=(DisableIntrospection,))),
]
or
.. code:: python
from django.urls import path
from graphene.validation import DisableIntrospection
from graphene_django.views import GraphQLView
class View(GraphQLView):
validation_rules = (DisableIntrospection,)
urlpatterns = [
path("graphql", View.as_view()),
]

View File

@ -62,3 +62,12 @@ Now head on over to
and run some queries! and run some queries!
(See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema) (See the [Graphene-Django Tutorial](http://docs.graphene-python.org/projects/django/en/latest/tutorial-plain/#testing-our-graphql-schema)
for some example queries) for some example queries)
Testing local graphene-django changes
-------------------------------------
In `requirements.txt`, replace the entire `graphene-django=...` line with the following (so that we install the local version instead of the one from PyPI):
```
../../ # graphene-django
```

View File

@ -1,8 +1,8 @@
import graphene
from graphene_django.debug import DjangoDebug
import cookbook.ingredients.schema import cookbook.ingredients.schema
import cookbook.recipes.schema import cookbook.recipes.schema
import graphene
from graphene_django.debug import DjangoDebug
class Query( class Query(

View File

@ -5,10 +5,10 @@ Django settings for cookbook project.
Generated by 'django-admin startproject' using Django 1.9. Generated by 'django-admin startproject' using Django 1.9.
For more information on this file, see For more information on this file, see
https://docs.djangoproject.com/en/1.9/topics/settings/ https://docs.djangoproject.com/en/3.2/topics/settings/
For the full list of settings and their values, see For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.9/ref/settings/ https://docs.djangoproject.com/en/3.2/ref/settings/
""" """
import os import os
@ -18,7 +18,7 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production # Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/ # See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret! # SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = "_$=$%eqxk$8ss4n7mtgarw^5$8^d5+c83!vwatr@i_81myb=e4" SECRET_KEY = "_$=$%eqxk$8ss4n7mtgarw^5$8^d5+c83!vwatr@i_81myb=e4"
@ -81,7 +81,7 @@ WSGI_APPLICATION = "cookbook.wsgi.application"
# Database # Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases # https://docs.djangoproject.com/en/3.2/ref/settings/#databases
DATABASES = { DATABASES = {
"default": { "default": {
@ -90,9 +90,11 @@ DATABASES = {
} }
} }
# https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
# Password validation # Password validation
# https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators # https://docs.djangoproject.com/en/3.2/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ {
@ -105,7 +107,7 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/1.9/topics/i18n/ # https://docs.djangoproject.com/en/3.2/topics/i18n/
LANGUAGE_CODE = "en-us" LANGUAGE_CODE = "en-us"
@ -119,6 +121,6 @@ USE_TZ = True
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.9/howto/static-files/ # https://docs.djangoproject.com/en/3.2/howto/static-files/
STATIC_URL = "/static/" STATIC_URL = "/static/"

View File

@ -1,9 +1,8 @@
from django.urls import path
from django.contrib import admin from django.contrib import admin
from django.urls import path
from graphene_django.views import GraphQLView from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
path("admin/", admin.site.urls), path("admin/", admin.site.urls),
path("graphql/", GraphQLView.as_view(graphiql=True)), path("graphql/", GraphQLView.as_view(graphiql=True)),

View File

@ -1,4 +1,3 @@
graphene>=2.1,<3 django~=3.2
graphene-django>=2.1,<3 graphene
graphql-core>=2.1,<3 graphene-django>=3.1
django==3.1.14

View File

@ -1,8 +1,9 @@
from cookbook.ingredients.models import Category, Ingredient
from graphene import Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
from cookbook.ingredients.models import Category, Ingredient
# Graphene will automatically map the Category model's fields onto the CategoryNode. # Graphene will automatically map the Category model's fields onto the CategoryNode.
# This is configured in the CategoryNode's Meta class (as you can see below) # This is configured in the CategoryNode's Meta class (as you can see below)

View File

@ -6,7 +6,9 @@ from cookbook.ingredients.models import Ingredient
class Recipe(models.Model): class Recipe(models.Model):
title = models.CharField(max_length=100) title = models.CharField(max_length=100)
instructions = models.TextField() instructions = models.TextField()
__unicode__ = lambda self: self.title
def __unicode__(self):
return self.title
class RecipeIngredient(models.Model): class RecipeIngredient(models.Model):

View File

@ -1,8 +1,9 @@
from cookbook.recipes.models import Recipe, RecipeIngredient
from graphene import Node from graphene import Node
from graphene_django.filter import DjangoFilterConnectionField from graphene_django.filter import DjangoFilterConnectionField
from graphene_django.types import DjangoObjectType from graphene_django.types import DjangoObjectType
from cookbook.recipes.models import Recipe, RecipeIngredient
class RecipeNode(DjangoObjectType): class RecipeNode(DjangoObjectType):
class Meta: class Meta:

View File

@ -1,8 +1,8 @@
import graphene
from graphene_django.debug import DjangoDebug
import cookbook.ingredients.schema import cookbook.ingredients.schema
import cookbook.recipes.schema import cookbook.recipes.schema
import graphene
from graphene_django.debug import DjangoDebug
class Query( class Query(

View File

@ -3,7 +3,6 @@ from django.contrib import admin
from graphene_django.views import GraphQLView from graphene_django.views import GraphQLView
urlpatterns = [ urlpatterns = [
url(r"^admin/", admin.site.urls), url(r"^admin/", admin.site.urls),
url(r"^graphql$", GraphQLView.as_view(graphiql=True)), url(r"^graphql$", GraphQLView.as_view(graphiql=True)),

View File

@ -231,7 +231,7 @@
"fields": { "fields": {
"category": 3, "category": 3,
"name": "Newt", "name": "Newt",
"notes": "Braised and Confuesd" "notes": "Braised and Confused"
}, },
"model": "ingredients.ingredient", "model": "ingredients.ingredient",
"pk": 5 "pk": 5

View File

@ -1,5 +1,5 @@
graphene>=2.1,<3 graphene>=2.1,<3
graphene-django>=2.1,<3 graphene-django>=2.1,<3
graphql-core>=2.1,<3 graphql-core>=2.1,<3
django==3.1.14 django==4.2.18
django-filter>=2 django-filter>=2

View File

@ -1,5 +1,5 @@
import sys
import os import os
import sys
ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_PATH + "/examples/") sys.path.insert(0, ROOT_PATH + "/examples/")
@ -28,3 +28,5 @@ TEMPLATES = [
GRAPHENE = {"SCHEMA": "graphene_django.tests.schema_view.schema"} GRAPHENE = {"SCHEMA": "graphene_django.tests.schema_view.schema"}
ROOT_URLCONF = "graphene_django.tests.urls" ROOT_URLCONF = "graphene_django.tests.urls"
USE_TZ = True

View File

@ -28,7 +28,7 @@ def initialize():
# Yeah, technically it's Corellian. But it flew in the service of the rebels, # Yeah, technically it's Corellian. But it flew in the service of the rebels,
# so for the purposes of this demo it's a rebel ship. # so for the purposes of this demo it's a rebel ship.
falcon = Ship(id="4", name="Millenium Falcon", faction=rebels) falcon = Ship(id="4", name="Millennium Falcon", faction=rebels)
falcon.save() falcon.save()
homeOne = Ship(id="5", name="Home One", faction=rebels) homeOne = Ship(id="5", name="Home One", faction=rebels)

View File

@ -1,11 +1,13 @@
import graphene import graphene
from graphene import Schema, relay, resolve_only_args from graphene import Schema, relay
from graphene_django import DjangoConnectionField, DjangoObjectType from graphene_django import DjangoConnectionField, DjangoObjectType
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship, get_ships from .data import create_ship, get_empire, get_faction, get_rebels, get_ship, get_ships
from .models import Character as CharacterModel from .models import (
from .models import Faction as FactionModel Character as CharacterModel,
from .models import Ship as ShipModel Faction as FactionModel,
Ship as ShipModel,
)
class Ship(DjangoObjectType): class Ship(DjangoObjectType):
@ -60,16 +62,13 @@ class Query(graphene.ObjectType):
node = relay.Node.Field() node = relay.Node.Field()
ships = DjangoConnectionField(Ship, description="All the ships.") ships = DjangoConnectionField(Ship, description="All the ships.")
@resolve_only_args def resolve_ships(self, info):
def resolve_ships(self):
return get_ships() return get_ships()
@resolve_only_args def resolve_rebels(self, info):
def resolve_rebels(self):
return get_rebels() return get_rebels()
@resolve_only_args def resolve_empire(self, info):
def resolve_empire(self):
return get_empire() return get_empire()

View File

@ -40,7 +40,7 @@ def test_mutations():
{"node": {"id": "U2hpcDox", "name": "X-Wing"}}, {"node": {"id": "U2hpcDox", "name": "X-Wing"}},
{"node": {"id": "U2hpcDoy", "name": "Y-Wing"}}, {"node": {"id": "U2hpcDoy", "name": "Y-Wing"}},
{"node": {"id": "U2hpcDoz", "name": "A-Wing"}}, {"node": {"id": "U2hpcDoz", "name": "A-Wing"}},
{"node": {"id": "U2hpcDo0", "name": "Millenium Falcon"}}, {"node": {"id": "U2hpcDo0", "name": "Millennium Falcon"}},
{"node": {"id": "U2hpcDo1", "name": "Home One"}}, {"node": {"id": "U2hpcDo1", "name": "Home One"}},
{"node": {"id": "U2hpcDo5", "name": "Peter"}}, {"node": {"id": "U2hpcDo5", "name": "Peter"}},
] ]

View File

@ -1,11 +1,13 @@
from .fields import DjangoConnectionField, DjangoListField from .fields import DjangoConnectionField, DjangoListField
from .types import DjangoObjectType from .types import DjangoObjectType
from .utils import bypass_get_queryset
__version__ = "3.1.1" __version__ = "3.2.3"
__all__ = [ __all__ = [
"__version__", "__version__",
"DjangoObjectType", "DjangoObjectType",
"DjangoListField", "DjangoListField",
"DjangoConnectionField", "DjangoConnectionField",
"bypass_get_queryset",
] ]

View File

@ -1,3 +1,13 @@
import sys
from collections.abc import Callable
from pathlib import PurePath
# For backwards compatibility, we import JSONField to have it available for import via
# this compat module (https://github.com/graphql-python/graphene-django/issues/1428).
# Django's JSONField is available in Django 3.2+ (the minimum version we support)
from django.db.models import Choices, JSONField
class MissingType: class MissingType:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@ -7,19 +17,49 @@ try:
# Postgres fields are only available in Django with psycopg2 installed # Postgres fields are only available in Django with psycopg2 installed
# and we cannot have psycopg2 on PyPy # and we cannot have psycopg2 on PyPy
from django.contrib.postgres.fields import ( from django.contrib.postgres.fields import (
IntegerRangeField,
ArrayField, ArrayField,
HStoreField, HStoreField,
JSONField as PGJSONField, IntegerRangeField,
RangeField, RangeField,
) )
except ImportError: except ImportError:
IntegerRangeField, ArrayField, HStoreField, PGJSONField, RangeField = ( IntegerRangeField, HStoreField, RangeField = (MissingType,) * 3
MissingType,
) * 5 # For unit tests we fake ArrayField using JSONFields
if any(
PurePath(sys.argv[0]).match(p)
for p in [
"**/pytest",
"**/py.test",
"**/pytest/__main__.py",
]
):
class ArrayField(JSONField):
def __init__(self, *args, **kwargs):
if len(args) > 0:
self.base_field = args[0]
super().__init__(**kwargs)
else:
ArrayField = MissingType
try: try:
# JSONField is only available from Django 3.1 from django.utils.choices import normalize_choices
from django.db.models import JSONField
except ImportError: except ImportError:
JSONField = MissingType
def normalize_choices(choices):
if isinstance(choices, type) and issubclass(choices, Choices):
choices = choices.choices
if isinstance(choices, Callable):
choices = choices()
# In restframework==3.15.0, choices are not passed
# as OrderedDict anymore, so it's safer to check
# for a dict
if isinstance(choices, dict):
choices = choices.items()
return choices

View File

@ -1,10 +1,11 @@
from collections import OrderedDict import inspect
from functools import singledispatch, wraps from functools import partial, singledispatch, wraps
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.functional import Promise from django.utils.functional import Promise
from django.utils.module_loading import import_string from django.utils.module_loading import import_string
from graphql import GraphQLError
from graphene import ( from graphene import (
ID, ID,
@ -12,6 +13,7 @@ from graphene import (
Boolean, Boolean,
Date, Date,
DateTime, DateTime,
Decimal,
Dynamic, Dynamic,
Enum, Enum,
Field, Field,
@ -21,12 +23,11 @@ from graphene import (
NonNull, NonNull,
String, String,
Time, Time,
Decimal,
) )
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.resolver import get_default_resolver
from graphene.types.scalars import BigInt from graphene.types.scalars import BigInt
from graphene.utils.str_converters import to_camel_case from graphene.utils.str_converters import to_camel_case
from graphql import GraphQLError
try: try:
from graphql import assert_name from graphql import assert_name
@ -35,8 +36,8 @@ except ImportError:
from graphql import assert_valid_name as assert_name from graphql import assert_valid_name as assert_name
from graphql.pyutils import register_description from graphql.pyutils import register_description
from .compat import ArrayField, HStoreField, JSONField, PGJSONField, RangeField from .compat import ArrayField, HStoreField, RangeField, normalize_choices
from .fields import DjangoListField, DjangoConnectionField from .fields import DjangoConnectionField, DjangoListField
from .settings import graphene_settings from .settings import graphene_settings
from .utils.str_converters import to_const from .utils.str_converters import to_const
@ -59,6 +60,24 @@ class BlankValueField(Field):
return blank_field_wrapper(resolver) return blank_field_wrapper(resolver)
class EnumValueField(BlankValueField):
def wrap_resolve(self, parent_resolver):
resolver = super().wrap_resolve(parent_resolver)
# create custom resolver
def enum_field_wrapper(func):
@wraps(func)
def wrapped_resolver(*args, **kwargs):
return_value = func(*args, **kwargs)
if isinstance(return_value, models.Choices):
return_value = return_value.value
return return_value
return wrapped_resolver
return enum_field_wrapper(resolver)
def convert_choice_name(name): def convert_choice_name(name):
name = to_const(force_str(name)) name = to_const(force_str(name))
try: try:
@ -70,8 +89,7 @@ def convert_choice_name(name):
def get_choices(choices): def get_choices(choices):
converted_names = [] converted_names = []
if isinstance(choices, OrderedDict): choices = normalize_choices(choices)
choices = choices.items()
for value, help_text in choices: for value, help_text in choices:
if isinstance(help_text, (tuple, list)): if isinstance(help_text, (tuple, list)):
yield from get_choices(help_text) yield from get_choices(help_text)
@ -131,20 +149,24 @@ def convert_choice_field_to_enum(field, name=None):
def convert_django_field_with_choices( def convert_django_field_with_choices(
field, registry=None, convert_choices_to_enum=True field, registry=None, convert_choices_to_enum=None
): ):
if registry is not None: if registry is not None:
converted = registry.get_converted_field(field) converted = registry.get_converted_field(field)
if converted: if converted:
return converted return converted
choices = getattr(field, "choices", None) choices = getattr(field, "choices", None)
if convert_choices_to_enum is None:
convert_choices_to_enum = bool(
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT
)
if choices and convert_choices_to_enum: if choices and convert_choices_to_enum:
EnumCls = convert_choice_field_to_enum(field) EnumCls = convert_choice_field_to_enum(field)
required = not (field.blank or field.null) required = not (field.blank or field.null)
converted = EnumCls( converted = EnumCls(
description=get_django_field_description(field), required=required description=get_django_field_description(field), required=required
).mount_as(BlankValueField) ).mount_as(EnumValueField)
else: else:
converted = convert_django_field(field, registry) converted = convert_django_field(field, registry)
if registry is not None: if registry is not None:
@ -159,9 +181,7 @@ def get_django_field_description(field):
@singledispatch @singledispatch
def convert_django_field(field, registry=None): def convert_django_field(field, registry=None):
raise Exception( raise Exception(
"Don't know how to convert the Django field {} ({})".format( f"Don't know how to convert the Django field {field} ({field.__class__})"
field, field.__class__
)
) )
@ -179,19 +199,13 @@ def convert_field_to_string(field, registry=None):
) )
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.AutoField) @convert_django_field.register(models.AutoField)
@convert_django_field.register(models.BigAutoField)
@convert_django_field.register(models.SmallAutoField)
def convert_field_to_id(field, registry=None): def convert_field_to_id(field, registry=None):
return ID(description=get_django_field_description(field), required=not field.null) return ID(description=get_django_field_description(field), required=not field.null)
if hasattr(models, "SmallAutoField"):
@convert_django_field.register(models.SmallAutoField)
def convert_field_small_to_id(field, registry=None):
return convert_field_to_id(field, registry)
@convert_django_field.register(models.UUIDField) @convert_django_field.register(models.UUIDField)
def convert_field_to_uuid(field, registry=None): def convert_field_to_uuid(field, registry=None):
return UUID( return UUID(
@ -258,6 +272,10 @@ def convert_time_to_string(field, registry=None):
@convert_django_field.register(models.OneToOneRel) @convert_django_field.register(models.OneToOneRel)
def convert_onetoone_field_to_djangomodel(field, registry=None): def convert_onetoone_field_to_djangomodel(field, registry=None):
from graphene.utils.str_converters import to_snake_case
from .types import DjangoObjectType
model = field.related_model model = field.related_model
def dynamic_type(): def dynamic_type():
@ -265,7 +283,55 @@ def convert_onetoone_field_to_djangomodel(field, registry=None):
if not _type: if not _type:
return return
return Field(_type, required=not field.null) class CustomField(Field):
def wrap_resolve(self, parent_resolver):
"""
Implements a custom resolver which goes through the `get_node` method to ensure that
it goes through the `get_queryset` method of the DjangoObjectType.
"""
resolver = super().wrap_resolve(parent_resolver)
# If `get_queryset` was not overridden in the DjangoObjectType
# or if we explicitly bypass the `get_queryset` method,
# we can just return the default resolver.
if (
_type.get_queryset.__func__
is DjangoObjectType.get_queryset.__func__
or getattr(resolver, "_bypass_get_queryset", False)
):
return resolver
def custom_resolver(root, info, **args):
# Note: this function is used to resolve 1:1 relation fields
is_resolver_awaitable = inspect.iscoroutinefunction(resolver)
if is_resolver_awaitable:
fk_obj = resolver(root, info, **args)
# In case the resolver is a custom awaitable resolver that overwrites
# the default Django resolver
return fk_obj
field_name = to_snake_case(info.field_name)
reversed_field_name = root.__class__._meta.get_field(
field_name
).remote_field.name
try:
return _type.get_queryset(
_type._meta.model.objects.filter(
**{reversed_field_name: root.pk}
),
info,
).get()
except _type._meta.model.DoesNotExist:
return None
return custom_resolver
return CustomField(
_type,
required=not field.null,
)
return Dynamic(dynamic_type) return Dynamic(dynamic_type)
@ -313,6 +379,10 @@ def convert_field_to_list_or_connection(field, registry=None):
@convert_django_field.register(models.OneToOneField) @convert_django_field.register(models.OneToOneField)
@convert_django_field.register(models.ForeignKey) @convert_django_field.register(models.ForeignKey)
def convert_field_to_djangomodel(field, registry=None): def convert_field_to_djangomodel(field, registry=None):
from graphene.utils.str_converters import to_snake_case
from .types import DjangoObjectType
model = field.related_model model = field.related_model
def dynamic_type(): def dynamic_type():
@ -320,7 +390,79 @@ def convert_field_to_djangomodel(field, registry=None):
if not _type: if not _type:
return return
return Field( class CustomField(Field):
def wrap_resolve(self, parent_resolver):
"""
Implements a custom resolver which goes through the `get_node` method to ensure that
it goes through the `get_queryset` method of the DjangoObjectType.
"""
resolver = super().wrap_resolve(parent_resolver)
# If `get_queryset` was not overridden in the DjangoObjectType
# or if we explicitly bypass the `get_queryset` method,
# we can just return the default resolver.
if (
_type.get_queryset.__func__
is DjangoObjectType.get_queryset.__func__
or getattr(resolver, "_bypass_get_queryset", False)
):
return resolver
def custom_resolver(root, info, **args):
# Note: this function is used to resolve FK or 1:1 fields
# it does not differentiate between custom-resolved fields
# and default resolved fields.
# because this is a django foreign key or one-to-one field, the primary-key for
# this node can be accessed from the root node.
# ex: article.reporter_id
# get the name of the id field from the root's model
field_name = to_snake_case(info.field_name)
db_field_key = root.__class__._meta.get_field(field_name).attname
if hasattr(root, db_field_key):
# get the object's primary-key from root
object_pk = getattr(root, db_field_key)
else:
return None
is_resolver_awaitable = inspect.iscoroutinefunction(resolver)
if is_resolver_awaitable:
fk_obj = resolver(root, info, **args)
# In case the resolver is a custom awaitable resolver that overwrites
# the default Django resolver
return fk_obj
instance_from_get_node = _type.get_node(info, object_pk)
if instance_from_get_node is None:
# no instance to return
return
elif (
isinstance(resolver, partial)
and resolver.func is get_default_resolver()
):
return instance_from_get_node
elif resolver is not get_default_resolver():
# Default resolver is overridden
# For optimization, add the instance to the resolver
setattr(root, field_name, instance_from_get_node)
# Explanation:
# previously, _type.get_node` is called which results in at least one hit to the database.
# But, if we did not pass the instance to the root, calling the resolver will result in
# another call to get the instance which results in at least two database queries in total
# to resolve this node only.
# That's why the value of the object is set in the root so when the object is accessed
# in the resolver (root.field_name) it does not access the database unless queried explicitly.
fk_obj = resolver(root, info, **args)
return fk_obj
else:
return instance_from_get_node
return custom_resolver
return CustomField(
_type, _type,
description=get_django_field_description(field), description=get_django_field_description(field),
required=not field.null, required=not field.null,
@ -346,9 +488,8 @@ def convert_postgres_array_to_list(field, registry=None):
@convert_django_field.register(HStoreField) @convert_django_field.register(HStoreField)
@convert_django_field.register(PGJSONField) @convert_django_field.register(models.JSONField)
@convert_django_field.register(JSONField) def convert_json_field_to_string(field, registry=None):
def convert_pg_and_json_field_to_string(field, registry=None):
return JSONString( return JSONString(
description=get_django_field_description(field), required=not field.null description=get_django_field_description(field), required=not field.null
) )

View File

@ -1,9 +1,7 @@
from django.db import connections from django.db import connections
from promise import Promise
from .sql.tracking import unwrap_cursor, wrap_cursor
from .exception.formating import wrap_exception from .exception.formating import wrap_exception
from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug from .types import DjangoDebug

View File

@ -1,5 +1,6 @@
import graphene
import pytest import pytest
import graphene
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoConnectionField, DjangoObjectType from graphene_django import DjangoConnectionField, DjangoObjectType

View File

@ -1,7 +1,7 @@
from graphene import List, ObjectType from graphene import List, ObjectType
from .sql.types import DjangoDebugSQL
from .exception.types import DjangoDebugException from .exception.types import DjangoDebugException
from .sql.types import DjangoDebugSQL
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):

View File

@ -1,14 +1,12 @@
from functools import partial from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from graphql_relay import ( from graphql_relay import (
connection_from_array_slice, connection_from_array_slice,
cursor_to_offset, cursor_to_offset,
get_offset_with_default, get_offset_with_default,
offset_to_cursor, offset_to_cursor,
) )
from promise import Promise from promise import Promise
from graphene import Int, NonNull from graphene import Int, NonNull
@ -22,17 +20,20 @@ from .utils import maybe_queryset
class DjangoListField(Field): class DjangoListField(Field):
def __init__(self, _type, *args, **kwargs): def __init__(self, _type, *args, **kwargs):
from .types import DjangoObjectType
if isinstance(_type, NonNull): if isinstance(_type, NonNull):
_type = _type.of_type _type = _type.of_type
# Django would never return a Set of None vvvvvvv # Django would never return a Set of None vvvvvvv
super().__init__(List(NonNull(_type)), *args, **kwargs) super().__init__(List(NonNull(_type)), *args, **kwargs)
@property
def type(self):
from .types import DjangoObjectType
assert issubclass( assert issubclass(
self._underlying_type, DjangoObjectType self._underlying_type, DjangoObjectType
), "DjangoListField only accepts DjangoObjectType types" ), "DjangoListField only accepts DjangoObjectType types as underlying type"
return super().type
@property @property
def _underlying_type(self): def _underlying_type(self):
@ -196,7 +197,7 @@ class DjangoConnectionField(ConnectionField):
enforce_first_or_last, enforce_first_or_last,
root, root,
info, info,
**args **args,
): ):
first = args.get("first") first = args.get("first")
last = args.get("last") last = args.get("last")
@ -246,7 +247,7 @@ class DjangoConnectionField(ConnectionField):
def wrap_resolve(self, parent_resolver): def wrap_resolve(self, parent_resolver):
return partial( return partial(
self.connection_resolver, self.connection_resolver,
parent_resolver, self.resolver or parent_resolver,
self.connection_type, self.connection_type,
self.get_manager(), self.get_manager(),
self.get_queryset_resolver(), self.get_queryset_resolver(),

View File

@ -1,4 +1,5 @@
import warnings import warnings
from ..utils import DJANGO_FILTER_INSTALLED from ..utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED: if not DJANGO_FILTER_INSTALLED:

View File

@ -3,8 +3,8 @@ from functools import partial
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from graphene.types.enum import EnumType
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from graphene.types.enum import EnumType
from graphene.utils.str_converters import to_snake_case from graphene.utils.str_converters import to_snake_case
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
@ -36,7 +36,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
extra_filter_meta=None, extra_filter_meta=None,
filterset_class=None, filterset_class=None,
*args, *args,
**kwargs **kwargs,
): ):
self._fields = fields self._fields = fields
self._provided_filterset_class = filterset_class self._provided_filterset_class = filterset_class
@ -58,7 +58,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
def filterset_class(self): def filterset_class(self):
if not self._filterset_class: if not self._filterset_class:
fields = self._fields or self.node_type._meta.filter_fields fields = self._fields or self.node_type._meta.filter_fields
meta = dict(model=self.model, fields=fields) meta = {"model": self.model, "fields": fields}
if self._extra_filter_meta: if self._extra_filter_meta:
meta.update(self._extra_filter_meta) meta.update(self._extra_filter_meta)

View File

@ -1,4 +1,5 @@
import warnings import warnings
from ...utils import DJANGO_FILTER_INSTALLED from ...utils import DJANGO_FILTER_INSTALLED
if not DJANGO_FILTER_INSTALLED: if not DJANGO_FILTER_INSTALLED:

View File

@ -1,13 +1,36 @@
from django_filters.constants import EMPTY_VALUES from django_filters.constants import EMPTY_VALUES
from django_filters.filters import FilterMethod
from .typed_filter import TypedFilter from .typed_filter import TypedFilter
class ArrayFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)
class ArrayFilter(TypedFilter): class ArrayFilter(TypedFilter):
""" """
Filter made for PostgreSQL ArrayField. Filter made for PostgreSQL ArrayField.
""" """
@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ArrayFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ArrayFilterMethod(self)
def filter(self, qs, value): def filter(self, qs, value):
""" """
Override the default filter class to check first whether the list is Override the default filter class to check first whether the list is

View File

@ -1,5 +1,4 @@
from django_filters import Filter, MultipleChoiceFilter from django_filters import Filter, MultipleChoiceFilter
from graphql_relay.node.node import from_global_id from graphql_relay.node.node import from_global_id
from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField from ...forms import GlobalIDFormField, GlobalIDMultipleChoiceField

View File

@ -1,12 +1,36 @@
from django_filters.filters import FilterMethod
from .typed_filter import TypedFilter from .typed_filter import TypedFilter
class ListFilterMethod(FilterMethod):
def __call__(self, qs, value):
if value is None:
return qs
return self.method(qs, self.f.field_name, value)
class ListFilter(TypedFilter): class ListFilter(TypedFilter):
""" """
Filter that takes a list of value as input. Filter that takes a list of value as input.
It is for example used for `__in` filters. It is for example used for `__in` filters.
""" """
@TypedFilter.method.setter
def method(self, value):
"""
Override method setter so that in case a custom `method` is provided
(see documentation https://django-filter.readthedocs.io/en/stable/ref/filters.html#method),
it doesn't fall back to checking if the value is in `EMPTY_VALUES` (from the `__call__` method
of the `FilterMethod` class) and instead use our ListFilterMethod that consider empty lists as values.
Indeed when providing a `method` the `filter` method below is overridden and replaced by `FilterMethod(self)`
which means that the validation of the empty value is made by the `FilterMethod.__call__` method instead.
"""
TypedFilter.method.fset(self, value)
if value is not None:
self.filter = ListFilterMethod(self)
def filter(self, qs, value): def filter(self, qs, value):
""" """
Override the default filter class to check first whether the list is Override the default filter class to check first whether the list is

View File

@ -1,12 +1,14 @@
import itertools import itertools
from django.db import models from django.db import models
from django_filters.filterset import BaseFilterSet, FilterSet from django_filters.filterset import (
from django_filters.filterset import FILTER_FOR_DBFIELD_DEFAULTS FILTER_FOR_DBFIELD_DEFAULTS,
BaseFilterSet,
FilterSet,
)
from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter from .filters import GlobalIDFilter, GlobalIDMultipleChoiceFilter
GRAPHENE_FILTER_SET_OVERRIDES = { GRAPHENE_FILTER_SET_OVERRIDES = {
models.AutoField: {"filter_class": GlobalIDFilter}, models.AutoField: {"filter_class": GlobalIDFilter},
models.OneToOneField: {"filter_class": GlobalIDFilter}, models.OneToOneField: {"filter_class": GlobalIDFilter},

View File

@ -1,15 +1,15 @@
from unittest.mock import MagicMock from functools import reduce
import pytest
import pytest
from django.db import models from django.db import models
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django_filters import filters
from django_filters import FilterSet from django_filters import FilterSet
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.filter import ArrayFilter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
from graphene_django.filter import ArrayFilter, ListFilter
from ...compat import ArrayField from ...compat import ArrayField
@ -25,15 +25,15 @@ else:
) )
STORE = {"events": []}
class Event(models.Model): class Event(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
tags = ArrayField(models.CharField(max_length=50)) tags = ArrayField(models.CharField(max_length=50))
tag_ids = ArrayField(models.IntegerField()) tag_ids = ArrayField(models.IntegerField())
random_field = ArrayField(models.BooleanField()) random_field = ArrayField(models.BooleanField())
def __repr__(self):
return f"Event [{self.name}]"
@pytest.fixture @pytest.fixture
def EventFilterSet(): def EventFilterSet():
@ -44,10 +44,18 @@ def EventFilterSet():
"name": ["exact", "contains"], "name": ["exact", "contains"],
} }
# Those are actually usable with our Query fixture bellow # Those are actually usable with our Query fixture below
tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains") tags__contains = ArrayFilter(field_name="tags", lookup_expr="contains")
tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap") tags__overlap = ArrayFilter(field_name="tags", lookup_expr="overlap")
tags = ArrayFilter(field_name="tags", lookup_expr="exact") tags = ArrayFilter(field_name="tags", lookup_expr="exact")
tags__len = ArrayFilter(
field_name="tags", lookup_expr="len", input_type=graphene.Int
)
tags__len__in = ArrayFilter(
field_name="tags",
method="tags__len__in_filter",
input_type=graphene.List(graphene.Int),
)
# Those are actually not usable and only to check type declarations # Those are actually not usable and only to check type declarations
tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains") tags_ids__contains = ArrayFilter(field_name="tag_ids", lookup_expr="contains")
@ -61,6 +69,14 @@ def EventFilterSet():
) )
random_field = ArrayFilter(field_name="random_field", lookup_expr="exact") random_field = ArrayFilter(field_name="random_field", lookup_expr="exact")
def tags__len__in_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return reduce(
lambda q1, q2: q1.union(q2),
[queryset.filter(tags__len=v) for v in value],
).distinct()
return EventFilterSet return EventFilterSet
@ -83,10 +99,6 @@ def Query(EventType):
we are running unit tests in sqlite which does not have ArrayFields. we are running unit tests in sqlite which does not have ArrayFields.
""" """
class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
def resolve_events(self, info, **kwargs):
events = [ events = [
Event(name="Live Show", tags=["concert", "music", "rock"]), Event(name="Live Show", tags=["concert", "music", "rock"]),
Event(name="Musical", tags=["movie", "music"]), Event(name="Musical", tags=["movie", "music"]),
@ -94,57 +106,87 @@ def Query(EventType):
Event(name="Speech", tags=[]), Event(name="Speech", tags=[]),
] ]
STORE["events"] = events class Query(graphene.ObjectType):
events = DjangoFilterConnectionField(EventType)
m_queryset = MagicMock(spec=QuerySet) def resolve_events(self, info, **kwargs):
m_queryset.model = Event class FakeQuerySet(QuerySet):
def __init__(self, model=None):
self.model = Event
self.__store = list(events)
def filter_events(**kwargs): def all(self):
return self
def filter(self, **kwargs):
queryset = FakeQuerySet()
queryset.__store = list(self.__store)
if "tags__contains" in kwargs: if "tags__contains" in kwargs:
STORE["events"] = list( queryset.__store = list(
filter( filter(
lambda e: set(kwargs["tags__contains"]).issubset( lambda e: set(kwargs["tags__contains"]).issubset(
set(e.tags) set(e.tags)
), ),
STORE["events"], queryset.__store,
) )
) )
if "tags__overlap" in kwargs: if "tags__overlap" in kwargs:
STORE["events"] = list( queryset.__store = list(
filter( filter(
lambda e: not set(kwargs["tags__overlap"]).isdisjoint( lambda e: not set(kwargs["tags__overlap"]).isdisjoint(
set(e.tags) set(e.tags)
), ),
STORE["events"], queryset.__store,
) )
) )
if "tags__exact" in kwargs: if "tags__exact" in kwargs:
STORE["events"] = list( queryset.__store = list(
filter( filter(
lambda e: set(kwargs["tags__exact"]) == set(e.tags), lambda e: set(kwargs["tags__exact"]) == set(e.tags),
STORE["events"], queryset.__store,
) )
) )
if "tags__len" in kwargs:
queryset.__store = list(
filter(
lambda e: len(e.tags) == kwargs["tags__len"],
queryset.__store,
)
)
return queryset
def mock_queryset_filter(*args, **kwargs): def union(self, *args):
filter_events(**kwargs) queryset = FakeQuerySet()
return m_queryset queryset.__store = self.__store
for arg in args:
queryset.__store += arg.__store
return queryset
def mock_queryset_none(*args, **kwargs): def none(self):
STORE["events"] = [] queryset = FakeQuerySet()
return m_queryset queryset.__store = []
return queryset
def mock_queryset_count(*args, **kwargs): def count(self):
return len(STORE["events"]) return len(self.__store)
m_queryset.all.return_value = m_queryset def distinct(self):
m_queryset.filter.side_effect = mock_queryset_filter queryset = FakeQuerySet()
m_queryset.none.side_effect = mock_queryset_none queryset.__store = []
m_queryset.count.side_effect = mock_queryset_count for event in self.__store:
m_queryset.__getitem__.side_effect = lambda index: STORE[ if event not in queryset.__store:
"events" queryset.__store.append(event)
].__getitem__(index) queryset.__store = sorted(queryset.__store, key=lambda e: e.name)
return queryset
return m_queryset def __getitem__(self, index):
return self.__store[index]
return FakeQuerySet()
return Query return Query
@pytest.fixture
def schema(Query):
return graphene.Schema(query=Query)

View File

@ -1,18 +1,14 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_multiple(Query): def test_array_field_contains_multiple(schema):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: ["concert", "music"]) { events (tags_Contains: ["concert", "music"]) {
@ -32,13 +28,11 @@ def test_array_field_contains_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_one(Query): def test_array_field_contains_one(schema):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: ["music"]) { events (tags_Contains: ["music"]) {
@ -59,13 +53,11 @@ def test_array_field_contains_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_contains_empty_list(Query): def test_array_field_contains_empty_list(schema):
""" """
Test contains filter on a array field of string. Test contains filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Contains: []) { events (tags_Contains: []) {

View File

@ -0,0 +1,186 @@
import pytest
from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_len_filter(schema):
query = """
query {
events (tags_Len: 2) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Musical"}},
{"node": {"name": "Ballet"}},
]
query = """
query {
events (tags_Len: 0) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Speech"}},
]
query = """
query {
events (tags_Len: 10) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len: "2") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == 'Int cannot represent non-integer value: "2"'
query = """
query {
events (tags_Len: True) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == "Int cannot represent non-integer value: True"
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_custom_filter(schema):
query = """
query {
events (tags_Len_In: 2) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Ballet"}},
{"node": {"name": "Musical"}},
]
query = """
query {
events (tags_Len_In: [0, 2]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == [
{"node": {"name": "Ballet"}},
{"node": {"name": "Musical"}},
{"node": {"name": "Speech"}},
]
query = """
query {
events (tags_Len_In: [10]) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len_In: []) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert not result.errors
assert result.data["events"]["edges"] == []
query = """
query {
events (tags_Len_In: "12") {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == 'Int cannot represent non-integer value: "12"'
query = """
query {
events (tags_Len_In: True) {
edges {
node {
name
}
}
}
}
"""
result = schema.execute(query)
assert len(result.errors) == 1
assert result.errors[0].message == "Int cannot represent non-integer value: True"

View File

@ -1,18 +1,14 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_no_match(Query): def test_array_field_exact_no_match(schema):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: ["concert", "music"]) { events (tags: ["concert", "music"]) {
@ -30,13 +26,11 @@ def test_array_field_exact_no_match(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_match(Query): def test_array_field_exact_match(schema):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: ["movie", "music"]) { events (tags: ["movie", "music"]) {
@ -56,13 +50,11 @@ def test_array_field_exact_match(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_exact_empty_list(Query): def test_array_field_exact_empty_list(schema):
""" """
Test exact filter on a array field of string. Test exact filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags: []) { events (tags: []) {
@ -82,11 +74,10 @@ def test_array_field_exact_empty_list(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_filter_schema_type(Query): def test_array_field_filter_schema_type(schema):
""" """
Check that the type in the filter is an array field like on the object type. Check that the type in the filter is an array field like on the object type.
""" """
schema = Schema(query=Query)
schema_str = str(schema) schema_str = str(schema)
assert ( assert (
@ -112,6 +103,8 @@ def test_array_field_filter_schema_type(Query):
"tags_Contains": "[String!]", "tags_Contains": "[String!]",
"tags_Overlap": "[String!]", "tags_Overlap": "[String!]",
"tags": "[String!]", "tags": "[String!]",
"tags_Len": "Int",
"tags_Len_In": "[Int]",
"tagsIds_Contains": "[Int!]", "tagsIds_Contains": "[Int!]",
"tagsIds_Overlap": "[Int!]", "tagsIds_Overlap": "[Int!]",
"tagsIds": "[Int!]", "tagsIds": "[Int!]",

View File

@ -1,18 +1,14 @@
import pytest import pytest
from graphene import Schema
from ...compat import ArrayField, MissingType from ...compat import ArrayField, MissingType
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_multiple(Query): def test_array_field_overlap_multiple(schema):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: ["concert", "music"]) { events (tags_Overlap: ["concert", "music"]) {
@ -34,13 +30,11 @@ def test_array_field_overlap_multiple(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_one(Query): def test_array_field_overlap_one(schema):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: ["music"]) { events (tags_Overlap: ["music"]) {
@ -61,13 +55,11 @@ def test_array_field_overlap_one(Query):
@pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist") @pytest.mark.skipif(ArrayField is MissingType, reason="ArrayField should exist")
def test_array_field_overlap_empty_list(Query): def test_array_field_overlap_empty_list(schema):
""" """
Test overlap filter on a array field of string. Test overlap filter on a array field of string.
""" """
schema = Schema(query=Query)
query = """ query = """
query { query {
events (tags_Overlap: []) { events (tags_Overlap: []) {

View File

@ -2,8 +2,7 @@ import pytest
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoConnectionField, DjangoObjectType
from graphene_django import DjangoObjectType, DjangoConnectionField
from graphene_django.tests.models import Article, Reporter from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED

View File

@ -19,8 +19,8 @@ if DJANGO_FILTER_INSTALLED:
from django_filters import FilterSet, NumberFilter, OrderingFilter from django_filters import FilterSet, NumberFilter, OrderingFilter
from graphene_django.filter import ( from graphene_django.filter import (
GlobalIDFilter,
DjangoFilterConnectionField, DjangoFilterConnectionField,
GlobalIDFilter,
GlobalIDMultipleChoiceFilter, GlobalIDMultipleChoiceFilter,
) )
from graphene_django.filter.tests.filters import ( from graphene_django.filter.tests.filters import (
@ -222,7 +222,7 @@ def test_filter_filterset_information_on_meta_related():
reporter = Field(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
schema = Schema(query=Query) Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields["articles"].get_type() articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, "headline", "reporter") assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field) assert_not_orderable(articles_field)
@ -294,7 +294,7 @@ def test_filter_filterset_class_information_on_meta_related():
reporter = Field(ReporterFilterNode) reporter = Field(ReporterFilterNode)
article = Field(ArticleFilterNode) article = Field(ArticleFilterNode)
schema = Schema(query=Query) Schema(query=Query)
articles_field = ReporterFilterNode._meta.fields["articles"].get_type() articles_field = ReporterFilterNode._meta.fields["articles"].get_type()
assert_arguments(articles_field, "headline", "reporter") assert_arguments(articles_field, "headline", "reporter")
assert_not_orderable(articles_field) assert_not_orderable(articles_field)
@ -789,7 +789,7 @@ def test_order_by():
query = """ query = """
query NodeFilteringQuery { query NodeFilteringQuery {
allReporters(orderBy: "-firtsnaMe") { allReporters(orderBy: "-firstname") {
edges { edges {
node { node {
firstName firstName
@ -802,7 +802,7 @@ def test_order_by():
assert result.errors assert result.errors
def test_order_by_is_perserved(): def test_order_by_is_preserved():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -1186,7 +1186,7 @@ def test_filter_filterset_based_on_mixin():
first_name="Adam", last_name="Doe", email="adam@doe.com" first_name="Adam", last_name="Doe", email="adam@doe.com"
) )
article_2 = Article.objects.create( Article.objects.create(
headline="Good Bye", headline="Good Bye",
reporter=reporter_2, reporter=reporter_2,
editor=reporter_2, editor=reporter_2,

View File

@ -1,14 +1,16 @@
from datetime import datetime from datetime import datetime
import pytest import pytest
from django_filters import (
FilterSet,
rest_framework as filters,
)
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema from graphene import ObjectType, Schema
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.tests.models import Pet, Person, Reporter, Article, Film
from graphene_django.filter.tests.filters import ArticleFilter from graphene_django.filter.tests.filters import ArticleFilter
from graphene_django.tests.models import Article, Film, Person, Pet, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
pytestmark = [] pytestmark = []
@ -348,9 +350,9 @@ def test_fk_id_in_filter(query):
schema = Schema(query=query) schema = Schema(query=query)
query = """ query = f"""
query {{ query {{
articles (reporter_In: [{}, {}]) {{ articles (reporter_In: [{john_doe.id}, {jean_bon.id}]) {{
edges {{ edges {{
node {{ node {{
headline headline
@ -361,10 +363,7 @@ def test_fk_id_in_filter(query):
}} }}
}} }}
}} }}
""".format( """
john_doe.id,
jean_bon.id,
)
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data["articles"]["edges"] == [ assert result.data["articles"]["edges"] == [

View File

@ -1,8 +1,7 @@
import json import json
import pytest import pytest
from django_filters import FilterSet
from django_filters import rest_framework as filters
from graphene import ObjectType, Schema from graphene import ObjectType, Schema
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType

View File

@ -1,10 +1,12 @@
import pytest import operator
from functools import reduce
import pytest
from django.db.models import Q
from django_filters import FilterSet from django_filters import FilterSet
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from graphene_django.tests.models import Article, Reporter from graphene_django.tests.models import Article, Reporter
from graphene_django.utils import DJANGO_FILTER_INSTALLED from graphene_django.utils import DJANGO_FILTER_INSTALLED
@ -14,8 +16,8 @@ pytestmark = []
if DJANGO_FILTER_INSTALLED: if DJANGO_FILTER_INSTALLED:
from graphene_django.filter import ( from graphene_django.filter import (
DjangoFilterConnectionField, DjangoFilterConnectionField,
TypedFilter,
ListFilter, ListFilter,
TypedFilter,
) )
else: else:
pytestmark.append( pytestmark.append(
@ -46,6 +48,10 @@ def schema():
only_first = TypedFilter( only_first = TypedFilter(
input_type=graphene.Boolean, method="only_first_filter" input_type=graphene.Boolean, method="only_first_filter"
) )
headline_search = ListFilter(
method="headline_search_filter",
input_type=graphene.List(graphene.String),
)
def first_n_filter(self, queryset, _name, value): def first_n_filter(self, queryset, _name, value):
return queryset[:value] return queryset[:value]
@ -56,6 +62,13 @@ def schema():
else: else:
return queryset return queryset
def headline_search_filter(self, queryset, _name, value):
if not value:
return queryset.none()
return queryset.filter(
reduce(operator.or_, [Q(headline__icontains=v) for v in value])
)
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
@ -89,6 +102,7 @@ def test_typed_filter_schema(schema):
"lang_InStr": "[String]", "lang_InStr": "[String]",
"firstN": "Int", "firstN": "Int",
"onlyFirst": "Boolean", "onlyFirst": "Boolean",
"headlineSearch": "[String]",
} }
all_articles_filters = ( all_articles_filters = (
@ -106,24 +120,7 @@ def test_typed_filters_work(schema):
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es") Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en") Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }' query = 'query { articles (lang_Contains: "n") { edges { node { headline } } } }'
@ -139,7 +136,7 @@ def test_typed_filters_work(schema):
assert not result.errors assert not result.errors
assert result.data["articles"]["edges"] == [ assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}}, {"node": {"headline": "A"}},
{"node": {"headline": "B"}}, {"node": {"headline": "AB"}},
] ]
query = "query { articles (onlyFirst: true) { edges { node { headline } } } }" query = "query { articles (onlyFirst: true) { edges { node { headline } } } }"
@ -149,3 +146,86 @@ def test_typed_filters_work(schema):
assert result.data["articles"]["edges"] == [ assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}}, {"node": {"headline": "A"}},
] ]
def test_list_filters_work(schema):
reporter = Reporter.objects.create(first_name="John", last_name="Doe", email="")
Article.objects.create(headline="A", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="B", reporter=reporter, editor=reporter, lang="es")
Article.objects.create(headline="C", reporter=reporter, editor=reporter, lang="en")
Article.objects.create(headline="AB", reporter=reporter, editor=reporter, lang="es")
query = "query { articles (lang_In: [ES]) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = 'query { articles (lang_InStr: ["es"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = "query { articles (lang_InStr: []) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == []
query = "query { articles (lang_InStr: null) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]
query = 'query { articles (headlineSearch: ["a", "B"]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
]
query = "query { articles (headlineSearch: []) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == []
query = "query { articles (headlineSearch: null) { edges { node { headline } } } }"
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]
query = 'query { articles (headlineSearch: [""]) { edges { node { headline } } } }'
result = schema.execute(query)
assert not result.errors
assert result.data["articles"]["edges"] == [
{"node": {"headline": "A"}},
{"node": {"headline": "AB"}},
{"node": {"headline": "B"}},
{"node": {"headline": "C"}},
]

View File

@ -1,10 +1,11 @@
import graphene
from django import forms from django import forms
from django_filters.utils import get_model_field, get_field_parts from django_filters.utils import get_model_field
from django_filters.filters import Filter, BaseCSVFilter
from .filters import ArrayFilter, ListFilter, RangeFilter, TypedFilter import graphene
from .filterset import custom_filterset_factory, setup_filterset
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
from .filters import ListFilter, RangeFilter, TypedFilter
from .filterset import custom_filterset_factory, setup_filterset
def get_field_type(registry, model, field_name): def get_field_type(registry, model, field_name):
@ -42,7 +43,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
isinstance(filter_field, TypedFilter) isinstance(filter_field, TypedFilter)
and filter_field.input_type is not None and filter_field.input_type is not None
): ):
# First check if the filter input type has been explicitely given # First check if the filter input type has been explicitly given
field_type = filter_field.input_type field_type = filter_field.input_type
else: else:
if name not in filterset_class.declared_filters or isinstance( if name not in filterset_class.declared_filters or isinstance(
@ -50,7 +51,7 @@ def get_filtering_args_from_filterset(filterset_class, type):
): ):
# Get the filter field for filters that are no explicitly declared. # Get the filter field for filters that are no explicitly declared.
if filter_type == "isnull": if filter_type == "isnull":
field = graphene.Boolean(required=required) field_type = graphene.Boolean
else: else:
model_field = get_model_field(model, filter_field.field_name) model_field = get_model_field(model, filter_field.field_name)
@ -144,7 +145,7 @@ def replace_csv_filters(filterset_class):
label=filter_field.label, label=filter_field.label,
method=filter_field.method, method=filter_field.method,
exclude=filter_field.exclude, exclude=filter_field.exclude,
**filter_field.extra **filter_field.extra,
) )
elif filter_type == "range": elif filter_type == "range":
filterset_class.base_filters[name] = RangeFilter( filterset_class.base_filters[name] = RangeFilter(
@ -153,5 +154,5 @@ def replace_csv_filters(filterset_class):
label=filter_field.label, label=filter_field.label,
method=filter_field.method, method=filter_field.method,
exclude=filter_field.exclude, exclude=filter_field.exclude,
**filter_field.extra **filter_field.extra,
) )

View File

@ -5,15 +5,15 @@ from django.core.exceptions import ImproperlyConfigured
from graphene import ( from graphene import (
ID, ID,
UUID,
Boolean, Boolean,
Date,
DateTime,
Decimal, Decimal,
Float, Float,
Int, Int,
List, List,
String, String,
UUID,
Date,
DateTime,
Time, Time,
) )
@ -27,8 +27,8 @@ def get_form_field_description(field):
@singledispatch @singledispatch
def convert_form_field(field): def convert_form_field(field):
raise ImproperlyConfigured( raise ImproperlyConfigured(
"Don't know how to convert the Django form field %s (%s) " f"Don't know how to convert the Django form field {field} ({field.__class__}) "
"to Graphene type" % (field, field.__class__) "to Graphene type"
) )

View File

@ -3,7 +3,6 @@ import binascii
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.forms import CharField, Field, MultipleChoiceField from django.forms import CharField, Field, MultipleChoiceField
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from graphql_relay import from_global_id from graphql_relay import from_global_id

View File

@ -23,8 +23,7 @@ def fields_for_form(form, only_fields, exclude_fields):
for name, field in form.fields.items(): for name, field in form.fields.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_excluded = ( is_excluded = (
name name in exclude_fields # or
in exclude_fields # or
# name in already_created_fields # name in already_created_fields
) )

View File

@ -1,31 +1,34 @@
from django import forms from django import VERSION as DJANGO_VERSION, forms
from pytest import raises from pytest import raises
import graphene
from graphene import ( from graphene import (
String,
Int,
Boolean,
Decimal,
Float,
ID, ID,
UUID, UUID,
Boolean,
Date,
DateTime,
Decimal,
Float,
Int,
List, List,
NonNull, NonNull,
DateTime, String,
Date,
Time, Time,
) )
from ..converter import convert_form_field from ..converter import convert_form_field
def assert_conversion(django_field, graphene_field, *args): def assert_conversion(django_field, graphene_field, *args, **kwargs):
field = django_field(*args, help_text="Custom Help Text") # Arrange
help_text = kwargs.setdefault("help_text", "Custom Help Text")
field = django_field(*args, **kwargs)
# Act
graphene_type = convert_form_field(field) graphene_type = convert_form_field(field)
# Assert
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
assert field.description == "Custom Help Text" assert field.description == help_text
return field return field
@ -60,7 +63,12 @@ def test_should_slug_convert_string():
def test_should_url_convert_string(): def test_should_url_convert_string():
assert_conversion(forms.URLField, String) kwargs = {}
if DJANGO_VERSION >= (5, 0):
# silence RemovedInDjango60Warning
kwargs["assume_scheme"] = "https"
assert_conversion(forms.URLField, String, **kwargs)
def test_should_choice_convert_string(): def test_should_choice_convert_string():
@ -76,7 +84,6 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(forms, "UUIDField"):
assert_conversion(forms.UUIDField, UUID) assert_conversion(forms.UUIDField, UUID)

View File

@ -1,11 +1,11 @@
import graphene
from django import forms from django import forms
from pytest import raises from pytest import raises
import graphene
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from ...tests.models import CHOICES, Film, Reporter
from ..types import DjangoFormInputObjectType from ..types import DjangoFormInputObjectType
from ...tests.models import Reporter, Film, CHOICES
# Reporter a_choice CHOICES = ((1, "this"), (2, _("that"))) # Reporter a_choice CHOICES = ((1, "this"), (2, _("that")))
THIS = CHOICES[0][0] THIS = CHOICES[0][0]
@ -31,7 +31,7 @@ class ReporterType(DjangoObjectType):
class ReporterForm(forms.ModelForm): class ReporterForm(forms.ModelForm):
class Meta: class Meta:
model = Reporter model = Reporter
exclude = ("pets", "email") exclude = ("pets", "email", "fans")
class MyForm(forms.Form): class MyForm(forms.Form):

View File

@ -1,4 +1,3 @@
import pytest
from django import forms from django import forms
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from pytest import raises from pytest import raises
@ -280,7 +279,7 @@ def test_model_form_mutation_mutate_invalid_form():
result = PetMutation.mutate_and_get_payload(None, None) result = PetMutation.mutate_and_get_payload(None, None)
# A pet was not created # A pet was not created
Pet.objects.count() == 0 assert Pet.objects.count() == 0
fields_w_error = [e.field for e in result.errors] fields_w_error = [e.field for e in result.errors]
assert len(result.errors) == 2 assert len(result.errors) == 2

View File

@ -1,12 +1,11 @@
import graphene import graphene
from graphene import ID from graphene import ID
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from graphene.utils.str_converters import to_camel_case from graphene.utils.str_converters import to_camel_case
from ..converter import EnumValueField
from ..types import ErrorType # noqa Import ErrorType for backwards compatibility
from .mutation import fields_for_form from .mutation import fields_for_form
from ..types import ErrorType # noqa Import ErrorType for backwards compatability
from ..converter import BlankValueField
class DjangoFormInputObjectType(InputObjectType): class DjangoFormInputObjectType(InputObjectType):
@ -58,11 +57,10 @@ class DjangoFormInputObjectType(InputObjectType):
if ( if (
object_type object_type
and name in object_type._meta.fields and name in object_type._meta.fields
and isinstance(object_type._meta.fields[name], BlankValueField) and isinstance(object_type._meta.fields[name], EnumValueField)
): ):
# Field type BlankValueField here means that field # Field type EnumValueField here means that field
# with choises have been converted to enum # with choices have been converted to enum
# (BlankValueField is using only for that task ?)
setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type)) setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type))
elif ( elif (
object_type object_type

View File

@ -1,12 +1,12 @@
import os import functools
import importlib import importlib
import json import json
import functools import os
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.utils import autoreload from django.utils import autoreload
from graphql import print_schema from graphql import print_schema
from graphene_django.settings import graphene_settings from graphene_django.settings import graphene_settings
@ -83,7 +83,7 @@ class Command(CommandArguments):
def handle(self, *args, **options): def handle(self, *args, **options):
options_schema = options.get("schema") options_schema = options.get("schema")
if options_schema and type(options_schema) is str: if options_schema and isinstance(options_schema, str):
module_str, schema_name = options_schema.rsplit(".", 1) module_str, schema_name = options_schema.rsplit(".", 1)
mod = importlib.import_module(module_str) mod = importlib.import_module(module_str)
schema = getattr(mod, schema_name) schema = getattr(mod, schema_name)

View File

@ -8,9 +8,7 @@ class Registry:
assert issubclass( assert issubclass(
cls, DjangoObjectType cls, DjangoObjectType
), 'Only DjangoObjectTypes can be registered, received "{}"'.format( ), f'Only DjangoObjectTypes can be registered, received "{cls.__name__}"'
cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match." assert cls._meta.registry == self, "Registry for a Model have to match."
# assert self.get_type_for_model(cls._meta.model) == cls, ( # assert self.get_type_for_model(cls._meta.model) == cls, (
# 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model) # 'Multiple DjangoObjectTypes registered for "{}"'.format(cls._meta.model)

View File

@ -14,3 +14,14 @@ class MyFakeModelWithPassword(models.Model):
class MyFakeModelWithDate(models.Model): class MyFakeModelWithDate(models.Model):
cool_name = models.CharField(max_length=50) cool_name = models.CharField(max_length=50)
last_edited = models.DateField() last_edited = models.DateField()
class MyFakeModelWithChoiceField(models.Model):
class ChoiceType(models.Choices):
ASDF = "asdf"
HI = "hi"
choice_type = models.CharField(
max_length=4,
default=ChoiceType.HI.name,
)

View File

@ -1,4 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from enum import Enum
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from rest_framework import serializers from rest_framework import serializers
@ -18,6 +19,7 @@ class SerializerMutationOptions(MutationOptions):
model_class = None model_class = None
model_operations = ["create", "update"] model_operations = ["create", "update"]
serializer_class = None serializer_class = None
optional_fields = ()
def fields_for_serializer( def fields_for_serializer(
@ -27,6 +29,7 @@ def fields_for_serializer(
is_input=False, is_input=False,
convert_choices_to_enum=True, convert_choices_to_enum=True,
lookup_field=None, lookup_field=None,
optional_fields=(),
): ):
fields = OrderedDict() fields = OrderedDict()
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
@ -39,14 +42,21 @@ def fields_for_serializer(
field.read_only field.read_only
and is_input and is_input
and lookup_field != name, # don't show read_only fields in Input and lookup_field != name, # don't show read_only fields in Input
isinstance(
field, serializers.HiddenField
), # don't show hidden fields in Input
] ]
) )
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
continue continue
is_optional = name in optional_fields or "__all__" in optional_fields
fields[name] = convert_serializer_field( fields[name] = convert_serializer_field(
field, is_input=is_input, convert_choices_to_enum=convert_choices_to_enum field,
is_input=is_input,
convert_choices_to_enum=convert_choices_to_enum,
force_optional=is_optional,
) )
return fields return fields
@ -70,7 +80,8 @@ class SerializerMutation(ClientIDMutation):
exclude_fields=(), exclude_fields=(),
convert_choices_to_enum=True, convert_choices_to_enum=True,
_meta=None, _meta=None,
**options optional_fields=(),
**options,
): ):
if not serializer_class: if not serializer_class:
raise Exception("serializer_class is required for the SerializerMutation") raise Exception("serializer_class is required for the SerializerMutation")
@ -94,6 +105,7 @@ class SerializerMutation(ClientIDMutation):
is_input=True, is_input=True,
convert_choices_to_enum=convert_choices_to_enum, convert_choices_to_enum=convert_choices_to_enum,
lookup_field=lookup_field, lookup_field=lookup_field,
optional_fields=optional_fields,
) )
output_fields = fields_for_serializer( output_fields = fields_for_serializer(
serializer, serializer,
@ -121,8 +133,10 @@ class SerializerMutation(ClientIDMutation):
def get_serializer_kwargs(cls, root, info, **input): def get_serializer_kwargs(cls, root, info, **input):
lookup_field = cls._meta.lookup_field lookup_field = cls._meta.lookup_field
model_class = cls._meta.model_class model_class = cls._meta.model_class
if model_class: if model_class:
for input_dict_key, maybe_enum in input.items():
if isinstance(maybe_enum, Enum):
input[input_dict_key] = maybe_enum.value
if "update" in cls._meta.model_operations and lookup_field in input: if "update" in cls._meta.model_operations and lookup_field in input:
instance = get_object_or_404( instance = get_object_or_404(
model_class, **{lookup_field: input[lookup_field]} model_class, **{lookup_field: input[lookup_field]}

View File

@ -5,20 +5,22 @@ from rest_framework import serializers
import graphene import graphene
from ..registry import get_global_registry
from ..converter import convert_choices_to_named_enum_with_descriptions from ..converter import convert_choices_to_named_enum_with_descriptions
from ..registry import get_global_registry
from .types import DictType from .types import DictType
@singledispatch @singledispatch
def get_graphene_type_from_serializer_field(field): def get_graphene_type_from_serializer_field(field):
raise ImproperlyConfigured( raise ImproperlyConfigured(
"Don't know how to convert the serializer field %s (%s) " f"Don't know how to convert the serializer field {field} ({field.__class__}) "
"to Graphene type" % (field, field.__class__) "to Graphene type"
) )
def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True): def convert_serializer_field(
field, is_input=True, convert_choices_to_enum=True, force_optional=False
):
""" """
Converts a django rest frameworks field to a graphql field Converts a django rest frameworks field to a graphql field
and marks the field as required if we are creating an input type and marks the field as required if we are creating an input type
@ -31,7 +33,10 @@ def convert_serializer_field(field, is_input=True, convert_choices_to_enum=True)
graphql_type = get_graphene_type_from_serializer_field(field) graphql_type = get_graphene_type_from_serializer_field(field)
args = [] args = []
kwargs = {"description": field.help_text, "required": is_input and field.required} kwargs = {
"description": field.help_text,
"required": is_input and field.required and not force_optional,
}
# if it is a tuple or a list it means that we are returning # if it is a tuple or a list it means that we are returning
# the graphql type and the child type # the graphql type and the child type

View File

@ -1,11 +1,11 @@
import copy import copy
import graphene
from django.db import models from django.db import models
from graphene import InputObjectType
from pytest import raises from pytest import raises
from rest_framework import serializers from rest_framework import serializers
import graphene
from ..serializer_converter import convert_serializer_field from ..serializer_converter import convert_serializer_field
from ..types import DictType from ..types import DictType
@ -96,7 +96,6 @@ def test_should_regex_convert_string():
def test_should_uuid_convert_string(): def test_should_uuid_convert_string():
if hasattr(serializers, "UUIDField"):
assert_conversion(serializers.UUIDField, graphene.String) assert_conversion(serializers.UUIDField, graphene.String)

View File

@ -3,11 +3,16 @@ import datetime
from pytest import raises from pytest import raises
from rest_framework import serializers from rest_framework import serializers
from graphene import Field, ResolveInfo from graphene import Field, ResolveInfo, String
from graphene.types.inputobjecttype import InputObjectType from graphene.types.inputobjecttype import InputObjectType
from ...types import DjangoObjectType from ...types import DjangoObjectType
from ..models import MyFakeModel, MyFakeModelWithDate, MyFakeModelWithPassword from ..models import (
MyFakeModel,
MyFakeModelWithChoiceField,
MyFakeModelWithDate,
MyFakeModelWithPassword,
)
from ..mutation import SerializerMutation from ..mutation import SerializerMutation
@ -100,6 +105,16 @@ def test_exclude_fields():
assert "created" not in MyMutation.Input._meta.fields assert "created" not in MyMutation.Input._meta.fields
def test_model_serializer_optional_fields():
class MyMutation(SerializerMutation):
class Meta:
serializer_class = MyModelSerializer
optional_fields = ("cool_name",)
assert "cool_name" in MyMutation.Input._meta.fields
assert MyMutation.Input._meta.fields["cool_name"].type == String
def test_write_only_field(): def test_write_only_field():
class WriteOnlyFieldModelSerializer(serializers.ModelSerializer): class WriteOnlyFieldModelSerializer(serializers.ModelSerializer):
password = serializers.CharField(write_only=True) password = serializers.CharField(write_only=True)
@ -164,6 +179,21 @@ def test_read_only_fields():
), "'cool_name' is read_only field and shouldn't be on arguments" ), "'cool_name' is read_only field and shouldn't be on arguments"
def test_hidden_fields():
class SerializerWithHiddenField(serializers.Serializer):
cool_name = serializers.CharField()
user = serializers.HiddenField(default=serializers.CurrentUserDefault())
class MyMutation(SerializerMutation):
class Meta:
serializer_class = SerializerWithHiddenField
assert "cool_name" in MyMutation.Input._meta.fields
assert (
"user" not in MyMutation.Input._meta.fields
), "'user' is hidden field and shouldn't be on arguments"
def test_nested_model(): def test_nested_model():
class MyFakeModelGrapheneType(DjangoObjectType): class MyFakeModelGrapheneType(DjangoObjectType):
class Meta: class Meta:
@ -230,7 +260,7 @@ def test_model_invalid_update_mutate_and_get_payload_success():
model_operations = ["update"] model_operations = ["update"]
with raises(Exception) as exc: with raises(Exception) as exc:
result = InvalidModelMutation.mutate_and_get_payload( InvalidModelMutation.mutate_and_get_payload(
None, mock_info(), **{"cool_name": "Narf"} None, mock_info(), **{"cool_name": "Narf"}
) )
@ -245,7 +275,7 @@ def test_perform_mutate_success():
result = MyMethodMutation.mutate_and_get_payload( result = MyMethodMutation.mutate_and_get_payload(
None, None,
mock_info(), mock_info(),
**{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)} **{"cool_name": "Narf", "last_edited": datetime.date(2020, 1, 4)},
) )
assert result.errors is None assert result.errors is None
@ -253,6 +283,39 @@ def test_perform_mutate_success():
assert result.days_since_last_edit == 4 assert result.days_since_last_edit == 4
def test_perform_mutate_success_with_enum_choice_field():
class ListViewChoiceFieldSerializer(serializers.ModelSerializer):
choice_type = serializers.ChoiceField(
choices=[(x.name, x.value) for x in MyFakeModelWithChoiceField.ChoiceType],
required=False,
)
class Meta:
model = MyFakeModelWithChoiceField
fields = "__all__"
class SomeCreateSerializerMutation(SerializerMutation):
class Meta:
serializer_class = ListViewChoiceFieldSerializer
choice_type = {
"choice_type": SomeCreateSerializerMutation.Input.choice_type.type.get("ASDF")
}
name = MyFakeModelWithChoiceField.ChoiceType.ASDF.name
result = SomeCreateSerializerMutation.mutate_and_get_payload(
None, mock_info(), **choice_type
)
assert result.errors is None
assert result.choice_type == name
kwargs = SomeCreateSerializerMutation.get_serializer_kwargs(
None, mock_info(), **choice_type
)
assert kwargs["data"]["choice_type"] == name
assert 1 == MyFakeModelWithChoiceField.objects.count()
item = MyFakeModelWithChoiceField.objects.first()
assert item.choice_type == name
def test_mutate_and_get_payload_error(): def test_mutate_and_get_payload_error():
class MyMutation(SerializerMutation): class MyMutation(SerializerMutation):
class Meta: class Meta:

View File

@ -12,11 +12,10 @@ Graphene settings, checking for user settings first, then falling
back to the defaults. back to the defaults.
""" """
from django.conf import settings
from django.test.signals import setting_changed
import importlib # Available in Python 3.1+ import importlib # Available in Python 3.1+
from django.conf import settings
from django.test.signals import setting_changed
# Copied shamelessly from Django REST Framework # Copied shamelessly from Django REST Framework
@ -31,6 +30,8 @@ DEFAULTS = {
# Max items returned in ConnectionFields / FilterConnectionFields # Max items returned in ConnectionFields / FilterConnectionFields
"RELAY_CONNECTION_MAX_LIMIT": 100, "RELAY_CONNECTION_MAX_LIMIT": 100,
"CAMELCASE_ERRORS": True, "CAMELCASE_ERRORS": True,
# Automatically convert Choice fields of Django into Enum fields
"DJANGO_CHOICE_FIELD_ENUM_CONVERT": True,
# Set to True to enable v2 naming convention for choice field Enum's # Set to True to enable v2 naming convention for choice field Enum's
"DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False, "DJANGO_CHOICE_FIELD_ENUM_V2_NAMING": False,
"DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None, "DJANGO_CHOICE_FIELD_ENUM_CUSTOM_NAME": None,
@ -41,8 +42,10 @@ DEFAULTS = {
# https://github.com/graphql/graphiql/tree/main/packages/graphiql#options # https://github.com/graphql/graphiql/tree/main/packages/graphiql#options
"GRAPHIQL_HEADER_EDITOR_ENABLED": True, "GRAPHIQL_HEADER_EDITOR_ENABLED": True,
"GRAPHIQL_SHOULD_PERSIST_HEADERS": False, "GRAPHIQL_SHOULD_PERSIST_HEADERS": False,
"GRAPHIQL_INPUT_VALUE_DEPRECATION": False,
"ATOMIC_MUTATIONS": False, "ATOMIC_MUTATIONS": False,
"TESTING_ENDPOINT": "/graphql", "TESTING_ENDPOINT": "/graphql",
"MAX_VALIDATION_ERRORS": None,
} }
if settings.DEBUG: if settings.DEBUG:

View File

@ -122,6 +122,7 @@
onEditOperationName: onEditOperationName, onEditOperationName: onEditOperationName,
isHeadersEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled, isHeadersEditorEnabled: GRAPHENE_SETTINGS.graphiqlHeaderEditorEnabled,
shouldPersistHeaders: GRAPHENE_SETTINGS.graphiqlShouldPersistHeaders, shouldPersistHeaders: GRAPHENE_SETTINGS.graphiqlShouldPersistHeaders,
inputValueDeprecation: GRAPHENE_SETTINGS.graphiqlInputValueDeprecation,
query: query, query: query,
}; };
if (parameters.variables) { if (parameters.variables) {

View File

@ -21,6 +21,10 @@ add "&raw" to the end of the URL within a browser.
integrity="{{graphiql_css_sri}}" integrity="{{graphiql_css_sri}}"
rel="stylesheet" rel="stylesheet"
crossorigin="anonymous" /> crossorigin="anonymous" />
<link href="https://cdn.jsdelivr.net/npm/@graphiql/plugin-explorer@{{graphiql_plugin_explorer_version}}/dist/style.css"
integrity="{{graphiql_plugin_explorer_css_sri}}"
rel="stylesheet"
crossorigin="anonymous" />
<script src="https://cdn.jsdelivr.net/npm/whatwg-fetch@{{whatwg_fetch_version}}/dist/fetch.umd.js" <script src="https://cdn.jsdelivr.net/npm/whatwg-fetch@{{whatwg_fetch_version}}/dist/fetch.umd.js"
integrity="{{whatwg_fetch_sri}}" integrity="{{whatwg_fetch_sri}}"
crossorigin="anonymous"></script> crossorigin="anonymous"></script>
@ -50,6 +54,7 @@ add "&raw" to the end of the URL within a browser.
{% endif %} {% endif %}
graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }}, graphiqlHeaderEditorEnabled: {{ graphiql_header_editor_enabled|yesno:"true,false" }},
graphiqlShouldPersistHeaders: {{ graphiql_should_persist_headers|yesno:"true,false" }}, graphiqlShouldPersistHeaders: {{ graphiql_should_persist_headers|yesno:"true,false" }},
graphiqlInputValueDeprecation: {{ graphiql_input_value_deprecation|yesno:"true,false" }},
}; };
</script> </script>
<script src="{% static 'graphene_django/graphiql.js' %}"></script> <script src="{% static 'graphene_django/graphiql.js' %}"></script>

View File

@ -1,21 +1,14 @@
# https://github.com/graphql-python/graphene-django/issues/520 # https://github.com/graphql-python/graphene-django/issues/520
import datetime
from django import forms from django import forms
from rest_framework import serializers
import graphene import graphene
from graphene import Field, ResolveInfo from ...forms.mutation import DjangoFormMutation
from graphene.types.inputobjecttype import InputObjectType
from pytest import raises
from pytest import mark
from rest_framework import serializers
from ...types import DjangoObjectType
from ...rest_framework.models import MyFakeModel from ...rest_framework.models import MyFakeModel
from ...rest_framework.mutation import SerializerMutation from ...rest_framework.mutation import SerializerMutation
from ...forms.mutation import DjangoFormMutation
class MyModelSerializer(serializers.ModelSerializer): class MyModelSerializer(serializers.ModelSerializer):

View File

@ -1,11 +1,43 @@
import django
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
CHOICES = ((1, "this"), (2, _("that"))) CHOICES = ((1, "this"), (2, _("that")))
def get_choices_as_class(choices_class):
if django.VERSION >= (5, 0):
return choices_class
else:
return choices_class.choices
def get_choices_as_callable(choices_class):
if django.VERSION >= (5, 0):
def choices():
return choices_class.choices
return choices
else:
return choices_class.choices
class TypedIntChoice(models.IntegerChoices):
CHOICE_THIS = 1
CHOICE_THAT = 2
class TypedStrChoice(models.TextChoices):
CHOICE_THIS = "this"
CHOICE_THAT = "that"
class Person(models.Model): class Person(models.Model):
name = models.CharField(max_length=30) name = models.CharField(max_length=30)
parent = models.ForeignKey(
"self", on_delete=models.CASCADE, null=True, blank=True, related_name="children"
)
class Pet(models.Model): class Pet(models.Model):
@ -19,7 +51,11 @@ class Pet(models.Model):
class FilmDetails(models.Model): class FilmDetails(models.Model):
location = models.CharField(max_length=30) location = models.CharField(max_length=30)
film = models.OneToOneField( film = models.OneToOneField(
"Film", on_delete=models.CASCADE, related_name="details" "Film",
on_delete=models.CASCADE,
related_name="details",
null=True,
blank=True,
) )
@ -44,8 +80,24 @@ class Reporter(models.Model):
email = models.EmailField() email = models.EmailField()
pets = models.ManyToManyField("self") pets = models.ManyToManyField("self")
a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True)
typed_choice = models.IntegerField(
choices=TypedIntChoice.choices,
null=True,
blank=True,
)
class_choice = models.IntegerField(
choices=get_choices_as_class(TypedIntChoice),
null=True,
blank=True,
)
callable_choice = models.IntegerField(
choices=get_choices_as_callable(TypedStrChoice),
null=True,
blank=True,
)
objects = models.Manager() objects = models.Manager()
doe_objects = DoeReporterManager() doe_objects = DoeReporterManager()
fans = models.ManyToManyField(Person)
reporter_type = models.IntegerField( reporter_type = models.IntegerField(
"Reporter Type", "Reporter Type",
@ -90,6 +142,16 @@ class CNNReporter(Reporter):
objects = CNNReporterManager() objects = CNNReporterManager()
class APNewsReporter(Reporter):
"""
This class only inherits from Reporter for testing multi table inheritance
similar to what you'd see in django-polymorphic
"""
alias = models.CharField(max_length=30)
objects = models.Manager()
class Article(models.Model): class Article(models.Model):
headline = models.CharField(max_length=100) headline = models.CharField(max_length=100)
pub_date = models.DateField(auto_now_add=True) pub_date = models.DateField(auto_now_add=True)

View File

@ -1,5 +1,4 @@
from graphene import Field from graphene import Field
from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation from graphene_django.forms.mutation import DjangoFormMutation, DjangoModelFormMutation
from .forms import PetForm from .forms import PetForm

View File

@ -1,8 +1,8 @@
from io import StringIO
from textwrap import dedent from textwrap import dedent
from unittest.mock import mock_open, patch
from django.core import management from django.core import management
from io import StringIO
from unittest.mock import mock_open, patch
from graphene import ObjectType, Schema, String from graphene import ObjectType, Schema, String
@ -46,7 +46,7 @@ def test_generate_graphql_file_on_call_graphql_schema():
open_mock.assert_called_once() open_mock.assert_called_once()
handle = open_mock() handle = open_mock()
assert handle.write.called_once() handle.write.assert_called_once()
schema_output = handle.write.call_args[0][0] schema_output = handle.write.call_args[0][0]
assert schema_output == dedent( assert schema_output == dedent(

View File

@ -15,8 +15,6 @@ from graphene.types.scalars import BigInt
from ..compat import ( from ..compat import (
ArrayField, ArrayField,
HStoreField, HStoreField,
JSONField,
PGJSONField,
MissingType, MissingType,
RangeField, RangeField,
) )
@ -27,16 +25,16 @@ from ..converter import (
) )
from ..registry import Registry from ..registry import Registry
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Article, Film, FilmDetails, Reporter from .models import Article, Film, FilmDetails, Reporter, TypedIntChoice, TypedStrChoice
# from graphene.core.types.custom_scalars import DateTime, Time, JSONString # from graphene.core.types.custom_scalars import DateTime, Time, JSONString
def assert_conversion(django_field, graphene_field, *args, **kwargs): def assert_conversion(django_field, graphene_field, *args, **kwargs):
_kwargs = kwargs.copy() _kwargs = {**kwargs, "help_text": "Custom Help Text"}
if "null" not in kwargs: if "null" not in kwargs:
_kwargs["null"] = True _kwargs["null"] = True
field = django_field(help_text="Custom Help Text", *args, **_kwargs) field = django_field(*args, **_kwargs)
graphene_type = convert_django_field(field) graphene_type = convert_django_field(field)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field() field = graphene_type.Field()
@ -55,9 +53,8 @@ def assert_conversion(django_field, graphene_field, *args, **kwargs):
def test_should_unknown_django_field_raise_exception(): def test_should_unknown_django_field_raise_exception():
with raises(Exception) as excinfo: with raises(Exception, match="Don't know how to convert the Django field"):
convert_django_field(None) convert_django_field(None)
assert "Don't know how to convert the Django field" in str(excinfo.value)
def test_should_date_time_convert_string(): def test_should_date_time_convert_string():
@ -117,7 +114,6 @@ def test_should_big_auto_convert_id():
def test_should_small_auto_convert_id(): def test_should_small_auto_convert_id():
if hasattr(models, "SmallAutoField"):
assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True) assert_conversion(models.SmallAutoField, graphene.ID, primary_key=True)
@ -168,14 +164,34 @@ def test_field_with_choices_convert_enum():
help_text="Language", choices=(("es", "Spanish"), ("en", "English")) help_text="Language", choices=(("es", "Spanish"), ("en", "English"))
) )
class TranslatedModel(models.Model): class ChoicesModel(models.Model):
language = field language = field
class Meta: class Meta:
app_label = "test" app_label = "test"
graphene_type = convert_django_field_with_choices(field).type.of_type graphene_type = convert_django_field_with_choices(field).type.of_type
assert graphene_type._meta.name == "TestTranslatedModelLanguageChoices" assert graphene_type._meta.name == "TestChoicesModelLanguageChoices"
assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en"
assert graphene_type._meta.enum.__members__["EN"].description == "English"
def test_field_with_callable_choices_convert_enum():
def get_choices():
return ("es", "Spanish"), ("en", "English")
field = models.CharField(help_text="Language", choices=get_choices)
class CallableChoicesModel(models.Model):
language = field
class Meta:
app_label = "test"
graphene_type = convert_django_field_with_choices(field).type.of_type
assert graphene_type._meta.name == "TestCallableChoicesModelLanguageChoices"
assert graphene_type._meta.enum.__members__["ES"].value == "es" assert graphene_type._meta.enum.__members__["ES"].value == "es"
assert graphene_type._meta.enum.__members__["ES"].description == "Spanish" assert graphene_type._meta.enum.__members__["ES"].description == "Spanish"
assert graphene_type._meta.enum.__members__["EN"].value == "en" assert graphene_type._meta.enum.__members__["EN"].value == "en"
@ -372,16 +388,6 @@ def test_should_postgres_hstore_convert_string():
assert_conversion(HStoreField, JSONString) assert_conversion(HStoreField, JSONString)
@pytest.mark.skipif(PGJSONField is MissingType, reason="PGJSONField should exist")
def test_should_postgres_json_convert_string():
assert_conversion(PGJSONField, JSONString)
@pytest.mark.skipif(JSONField is MissingType, reason="JSONField should exist")
def test_should_json_convert_string():
assert_conversion(JSONField, JSONString)
@pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist") @pytest.mark.skipif(RangeField is MissingType, reason="RangeField should exist")
def test_should_postgres_range_convert_list(): def test_should_postgres_range_convert_list():
from django.contrib.postgres.fields import IntegerRangeField from django.contrib.postgres.fields import IntegerRangeField
@ -435,35 +441,102 @@ def test_choice_enum_blank_value():
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = ( fields = ("callable_choice",)
"first_name",
"a_choice",
)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(root, info): def resolve_reporter(root, info):
return Reporter.objects.first() # return a model instance with blank choice field value
return Reporter(callable_choice="")
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
# Create model with empty choice option
Reporter.objects.create(
first_name="Bridget", last_name="Jones", email="bridget@example.com"
)
result = schema.execute( result = schema.execute(
""" """
query { query {
reporter { reporter {
firstName callableChoice
aChoice
} }
} }
""" """
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"reporter": {"firstName": "Bridget", "aChoice": None}, "reporter": {"callableChoice": None},
} }
def test_typed_choice_value():
"""Test that typed choices fields are resolved correctly to the enum values"""
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
fields = ("typed_choice", "class_choice", "callable_choice")
class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
def resolve_reporter(root, info):
# assign choice values to the fields instead of their str or int values
return Reporter(
typed_choice=TypedIntChoice.CHOICE_THIS,
class_choice=TypedIntChoice.CHOICE_THAT,
callable_choice=TypedStrChoice.CHOICE_THIS,
)
class CreateReporter(graphene.Mutation):
reporter = graphene.Field(ReporterType)
def mutate(root, info, **kwargs):
return CreateReporter(
reporter=Reporter(
typed_choice=TypedIntChoice.CHOICE_THIS,
class_choice=TypedIntChoice.CHOICE_THAT,
callable_choice=TypedStrChoice.CHOICE_THIS,
),
)
class Mutation(graphene.ObjectType):
create_reporter = CreateReporter.Field()
schema = graphene.Schema(query=Query, mutation=Mutation)
reporter_fragment = """
fragment reporter on ReporterType {
typedChoice
classChoice
callableChoice
}
"""
expected_reporter = {
"typedChoice": "A_1",
"classChoice": "A_2",
"callableChoice": "THIS",
}
result = schema.execute(
reporter_fragment
+ """
query {
reporter { ...reporter }
}
"""
)
assert not result.errors
assert result.data["reporter"] == expected_reporter
result = schema.execute(
reporter_fragment
+ """
mutation {
createReporter {
reporter { ...reporter }
}
}
"""
)
assert not result.errors
assert result.data["createReporter"]["reporter"] == expected_reporter

View File

@ -1,8 +1,8 @@
import datetime import datetime
import re import re
from django.db.models import Count, Prefetch
import pytest import pytest
from django.db.models import Count, Prefetch
from graphene import List, NonNull, ObjectType, Schema, String from graphene import List, NonNull, ObjectType, Schema, String
@ -12,17 +12,23 @@ from .models import (
Article as ArticleModel, Article as ArticleModel,
Film as FilmModel, Film as FilmModel,
FilmDetails as FilmDetailsModel, FilmDetails as FilmDetailsModel,
Person as PersonModel,
Reporter as ReporterModel, Reporter as ReporterModel,
) )
class TestDjangoListField: class TestDjangoListField:
def test_only_django_object_types(self): def test_only_django_object_types(self):
class TestType(ObjectType): class Query(ObjectType):
foo = String() something = DjangoListField(String)
with pytest.raises(AssertionError): with pytest.raises(TypeError) as excinfo:
list_field = DjangoListField(TestType) Schema(query=Query)
assert (
"Query fields cannot be resolved. DjangoListField only accepts DjangoObjectType types as underlying type"
in str(excinfo.value)
)
def test_only_import_paths(self): def test_only_import_paths(self):
list_field = DjangoListField("graphene_django.tests.schema.Human") list_field = DjangoListField("graphene_django.tests.schema.Human")
@ -262,6 +268,69 @@ class TestDjangoListField:
] ]
} }
def test_same_type_nested_list_field(self):
class Person(DjangoObjectType):
class Meta:
model = PersonModel
fields = ("name", "parent")
children = DjangoListField(lambda: Person)
class Query(ObjectType):
persons = DjangoListField(Person)
schema = Schema(query=Query)
query = """
query {
persons {
name
children {
name
}
}
}
"""
p1 = PersonModel.objects.create(name="Tara")
PersonModel.objects.create(name="Debra")
PersonModel.objects.create(
name="Toto",
parent=p1,
)
PersonModel.objects.create(
name="Tata",
parent=p1,
)
result = schema.execute(query)
assert not result.errors
assert result.data == {
"persons": [
{
"name": "Tara",
"children": [
{"name": "Toto"},
{"name": "Tata"},
],
},
{
"name": "Debra",
"children": [],
},
{
"name": "Toto",
"children": [],
},
{
"name": "Tata",
"children": [],
},
]
}
def test_get_queryset_filter(self): def test_get_queryset_filter(self):
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
class Meta: class Meta:

View File

@ -3,7 +3,6 @@ from pytest import raises
from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField from ..forms import GlobalIDFormField, GlobalIDMultipleChoiceField
# 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc' # 'TXlUeXBlOmFiYw==' -> 'MyType', 'abc'

View File

@ -1,14 +1,11 @@
import pytest import pytest
from graphql_relay import to_global_id
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node
from graphql_relay import to_global_id
from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from .models import Article, Film, FilmDetails, Reporter
from .models import Article, Reporter
class TestShouldCallGetQuerySetOnForeignKey: class TestShouldCallGetQuerySetOnForeignKey:
@ -29,6 +26,7 @@ class TestShouldCallGetQuerySetOnForeignKey:
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -39,6 +37,7 @@ class TestShouldCallGetQuerySetOnForeignKey:
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
fields = "__all__"
@classmethod @classmethod
def get_queryset(cls, queryset, info): def get_queryset(cls, queryset, info):
@ -127,6 +126,69 @@ class TestShouldCallGetQuerySetOnForeignKey:
assert not result.errors assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}} assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(query, variables={"id": self.articles[0].id})
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.articles[0].id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
headline
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": self.reporter.id},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": [{"headline": "A fantastic article"}],
}
class TestShouldCallGetQuerySetOnForeignKeyNode: class TestShouldCallGetQuerySetOnForeignKeyNode:
""" """
@ -140,6 +202,7 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
fields = "__all__"
interfaces = (Node,) interfaces = (Node,)
@classmethod @classmethod
@ -151,6 +214,7 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
class ArticleType(DjangoObjectType): class ArticleType(DjangoObjectType):
class Meta: class Meta:
model = Article model = Article
fields = "__all__"
interfaces = (Node,) interfaces = (Node,)
@classmethod @classmethod
@ -233,3 +297,274 @@ class TestShouldCallGetQuerySetOnForeignKeyNode:
) )
assert not result.errors assert not result.errors
assert result.data == {"reporter": {"firstName": "Jane"}} assert result.data == {"reporter": {"firstName": "Jane"}}
def test_get_queryset_called_on_foreignkey(self):
# If a user tries to access a reporter through an article they should get our authorization error
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query, variables={"id": to_global_id("ArticleType", self.articles[0].id)}
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access reporters."
# An admin user should be able to get reporters through an article
query = """
query getArticle($id: ID!) {
article(id: $id) {
headline
reporter {
firstName
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ArticleType", self.articles[0].id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["article"] == {
"headline": "A fantastic article",
"reporter": {"firstName": "Jane"},
}
# An admin user should not be able to access draft article through a reporter
query = """
query getReporter($id: ID!) {
reporter(id: $id) {
firstName
articles {
edges {
node {
headline
}
}
}
}
}
"""
result = self.schema.execute(
query,
variables={"id": to_global_id("ReporterType", self.reporter.id)},
context_value={"admin": True},
)
assert not result.errors
assert result.data["reporter"] == {
"firstName": "Jane",
"articles": {"edges": [{"node": {"headline": "A fantastic article"}}]},
}
class TestShouldCallGetQuerySetOnOneToOne:
@pytest.fixture(autouse=True)
def setup_schema(self):
class FilmDetailsType(DjangoObjectType):
class Meta:
model = FilmDetails
fields = "__all__"
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("permission_get_film_details"):
return queryset
raise Exception("Not authorized to access film details.")
class FilmType(DjangoObjectType):
class Meta:
model = Film
fields = "__all__"
@classmethod
def get_queryset(cls, queryset, info):
if info.context and info.context.get("permission_get_film"):
return queryset
raise Exception("Not authorized to access film.")
class Query(graphene.ObjectType):
film_details = graphene.Field(
FilmDetailsType, id=graphene.ID(required=True)
)
film = graphene.Field(FilmType, id=graphene.ID(required=True))
def resolve_film_details(self, info, id):
return (
FilmDetailsType.get_queryset(FilmDetails.objects, info)
.filter(id=id)
.last()
)
def resolve_film(self, info, id):
return FilmType.get_queryset(Film.objects, info).filter(id=id).last()
self.schema = graphene.Schema(query=Query)
self.films = [
Film.objects.create(
genre="do",
),
Film.objects.create(
genre="ac",
),
]
self.film_details = [
FilmDetails.objects.create(
film=self.films[0],
),
FilmDetails.objects.create(
film=self.films[1],
),
]
def test_get_queryset_called_on_field(self):
# A user tries to access a film
query = """
query getFilm($id: ID!) {
film(id: $id) {
genre
}
}
"""
# With `permission_get_film`
result = self.schema.execute(
query,
variables={"id": self.films[0].id},
context_value={"permission_get_film": True},
)
assert not result.errors
assert result.data["film"] == {
"genre": "DO",
}
# Without `permission_get_film`
result = self.schema.execute(
query,
variables={"id": self.films[1].id},
context_value={"permission_get_film": False},
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access film."
# A user tries to access a film details
query = """
query getFilmDetails($id: ID!) {
filmDetails(id: $id) {
location
}
}
"""
# With `permission_get_film`
result = self.schema.execute(
query,
variables={"id": self.film_details[0].id},
context_value={"permission_get_film_details": True},
)
assert not result.errors
assert result.data == {"filmDetails": {"location": ""}}
# Without `permission_get_film`
result = self.schema.execute(
query,
variables={"id": self.film_details[0].id},
context_value={"permission_get_film_details": False},
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access film details."
def test_get_queryset_called_on_foreignkey(self, django_assert_num_queries):
# A user tries to access a film details through a film
query = """
query getFilm($id: ID!) {
film(id: $id) {
genre
details {
location
}
}
}
"""
# With `permission_get_film_details`
with django_assert_num_queries(2):
result = self.schema.execute(
query,
variables={"id": self.films[0].id},
context_value={
"permission_get_film": True,
"permission_get_film_details": True,
},
)
assert not result.errors
assert result.data["film"] == {
"genre": "DO",
"details": {"location": ""},
}
# Without `permission_get_film_details`
with django_assert_num_queries(1):
result = self.schema.execute(
query,
variables={"id": self.films[0].id},
context_value={
"permission_get_film": True,
"permission_get_film_details": False,
},
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access film details."
# A user tries to access a film through a film details
query = """
query getFilmDetails($id: ID!) {
filmDetails(id: $id) {
location
film {
genre
}
}
}
"""
# With `permission_get_film`
with django_assert_num_queries(2):
result = self.schema.execute(
query,
variables={"id": self.film_details[0].id},
context_value={
"permission_get_film": True,
"permission_get_film_details": True,
},
)
assert not result.errors
assert result.data["filmDetails"] == {
"location": "",
"film": {"genre": "DO"},
}
# Without `permission_get_film`
with django_assert_num_queries(1):
result = self.schema.execute(
query,
variables={"id": self.film_details[1].id},
context_value={
"permission_get_film": False,
"permission_get_film_details": True,
},
)
assert len(result.errors) == 1
assert result.errors[0].message == "Not authorized to access film."

View File

@ -1,5 +1,6 @@
import datetime
import base64 import base64
import datetime
from unittest.mock import ANY, Mock
import pytest import pytest
from django.db import models from django.db import models
@ -15,7 +16,16 @@ from ..compat import IntegerRangeField, MissingType
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from ..types import DjangoObjectType from ..types import DjangoObjectType
from ..utils import DJANGO_FILTER_INSTALLED from ..utils import DJANGO_FILTER_INSTALLED
from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter from .models import (
APNewsReporter,
Article,
CNNReporter,
Film,
FilmDetails,
Person,
Pet,
Reporter,
)
def test_should_query_only_fields(): def test_should_query_only_fields():
@ -117,15 +127,14 @@ def test_should_query_well():
@pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist") @pytest.mark.skipif(IntegerRangeField is MissingType, reason="RangeField should exist")
def test_should_query_postgres_fields(): def test_should_query_postgres_fields():
from django.contrib.postgres.fields import ( from django.contrib.postgres.fields import (
IntegerRangeField,
ArrayField, ArrayField,
JSONField,
HStoreField, HStoreField,
IntegerRangeField,
) )
class Event(models.Model): class Event(models.Model):
ages = IntegerRangeField(help_text="The age ranges") ages = IntegerRangeField(help_text="The age ranges")
data = JSONField(help_text="Data") data = models.JSONField(help_text="Data")
store = HStoreField() store = HStoreField()
tags = ArrayField(models.CharField(max_length=50)) tags = ArrayField(models.CharField(max_length=50))
@ -347,7 +356,7 @@ def test_should_query_connectionfields():
def test_should_keep_annotations(): def test_should_keep_annotations():
from django.db.models import Count, Avg from django.db.models import Avg, Count
class ReporterType(DjangoObjectType): class ReporterType(DjangoObjectType):
class Meta: class Meta:
@ -509,7 +518,7 @@ def test_should_query_node_filtering_with_distinct_queryset():
).distinct() ).distinct()
f = Film.objects.create() f = Film.objects.create()
fd = FilmDetails.objects.create(location="Berlin", film=f) FilmDetails.objects.create(location="Berlin", film=f)
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
query = """ query = """
@ -632,7 +641,7 @@ def test_should_enforce_first_or_last(graphene_settings):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -674,7 +683,7 @@ def test_should_error_if_first_is_greater_than_max(graphene_settings):
assert Query.all_reporters.max_limit == 100 assert Query.all_reporters.max_limit == 100
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -716,7 +725,7 @@ def test_should_error_if_last_is_greater_than_max(graphene_settings):
assert Query.all_reporters.max_limit == 100 assert Query.all_reporters.max_limit == 100
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -780,7 +789,7 @@ def test_should_query_promise_connectionfields():
def test_should_query_connectionfields_with_last(): def test_should_query_connectionfields_with_last():
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
@ -817,11 +826,11 @@ def test_should_query_connectionfields_with_last():
def test_should_query_connectionfields_with_manager(): def test_should_query_connectionfields_with_manager():
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
r = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="NotDoe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="NotDoe", email="johndoe@example.com", a_choice=1
) )
@ -1065,11 +1074,306 @@ def test_proxy_model_support():
assert result.data == expected assert result.data == expected
def test_should_resolve_get_queryset_connectionfields(): def test_model_inheritance_support_reverse_relationships():
reporter_1 = Reporter.objects.create( """
This test asserts that we can query reverse relationships for all Reporters and proxied Reporters and multi table Reporters.
"""
class FilmType(DjangoObjectType):
class Meta:
model = Film
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class APNewsReporterType(DjangoObjectType):
class Meta:
model = APNewsReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
film = Film.objects.create(genre="do")
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
reporter_2 = CNNReporter.objects.create(
cnn_reporter = CNNReporter.objects.create(
first_name="Some",
last_name="Guy",
email="someguy@cnn.com",
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
ap_news_reporter = APNewsReporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
film.reporters.add(cnn_reporter, ap_news_reporter)
film.save()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
cnn_reporters = DjangoConnectionField(CNNReporterType)
ap_news_reporters = DjangoConnectionField(APNewsReporterType)
schema = graphene.Schema(query=Query)
query = """
query ProxyModelQuery {
allReporters {
edges {
node {
id
films {
id
}
}
}
}
cnnReporters {
edges {
node {
id
films {
id
}
}
}
}
apNewsReporters {
edges {
node {
id
films {
id
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": to_global_id("ReporterType", reporter.id),
"films": [],
},
},
{
"node": {
"id": to_global_id("ReporterType", cnn_reporter.id),
"films": [{"id": f"{film.id}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", ap_news_reporter.id),
"films": [{"id": f"{film.id}"}],
},
},
]
},
"cnnReporters": {
"edges": [
{
"node": {
"id": to_global_id("CNNReporterType", cnn_reporter.id),
"films": [{"id": f"{film.id}"}],
}
}
]
},
"apNewsReporters": {
"edges": [
{
"node": {
"id": to_global_id("APNewsReporterType", ap_news_reporter.id),
"films": [{"id": f"{film.id}"}],
}
}
]
},
}
result = schema.execute(query)
assert result.data == expected
def test_model_inheritance_support_local_relationships():
"""
This test asserts that we can query local relationships for all Reporters and proxied Reporters and multi table Reporters.
"""
class PersonType(DjangoObjectType):
class Meta:
model = Person
fields = "__all__"
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class CNNReporterType(DjangoObjectType):
class Meta:
model = CNNReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
class APNewsReporterType(DjangoObjectType):
class Meta:
model = APNewsReporter
interfaces = (Node,)
use_connection = True
fields = "__all__"
film = Film.objects.create(genre="do")
reporter = Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
reporter_fan = Person.objects.create(name="Reporter Fan")
reporter.fans.add(reporter_fan)
reporter.save()
cnn_reporter = CNNReporter.objects.create(
first_name="Some",
last_name="Guy",
email="someguy@cnn.com",
a_choice=1,
reporter_type=2, # set this guy to be CNN
)
cnn_fan = Person.objects.create(name="CNN Fan")
cnn_reporter.fans.add(cnn_fan)
cnn_reporter.save()
ap_news_reporter = APNewsReporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
ap_news_fan = Person.objects.create(name="AP News Fan")
ap_news_reporter.fans.add(ap_news_fan)
ap_news_reporter.save()
film.reporters.add(cnn_reporter, ap_news_reporter)
film.save()
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
cnn_reporters = DjangoConnectionField(CNNReporterType)
ap_news_reporters = DjangoConnectionField(APNewsReporterType)
schema = graphene.Schema(query=Query)
query = """
query ProxyModelQuery {
allReporters {
edges {
node {
id
fans {
name
}
}
}
}
cnnReporters {
edges {
node {
id
fans {
name
}
}
}
}
apNewsReporters {
edges {
node {
id
fans {
name
}
}
}
}
}
"""
expected = {
"allReporters": {
"edges": [
{
"node": {
"id": to_global_id("ReporterType", reporter.id),
"fans": [{"name": f"{reporter_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
},
},
{
"node": {
"id": to_global_id("ReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
},
},
]
},
"cnnReporters": {
"edges": [
{
"node": {
"id": to_global_id("CNNReporterType", cnn_reporter.id),
"fans": [{"name": f"{cnn_fan.name}"}],
}
}
]
},
"apNewsReporters": {
"edges": [
{
"node": {
"id": to_global_id("APNewsReporterType", ap_news_reporter.id),
"fans": [{"name": f"{ap_news_fan.name}"}],
}
}
]
},
}
result = schema.execute(query)
assert result.data == expected
def test_should_resolve_get_queryset_connectionfields():
Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
)
CNNReporter.objects.create(
first_name="Some", first_name="Some",
last_name="Guy", last_name="Guy",
email="someguy@cnn.com", email="someguy@cnn.com",
@ -1111,10 +1415,10 @@ def test_should_resolve_get_queryset_connectionfields():
def test_connection_should_limit_after_to_list_length(): def test_connection_should_limit_after_to_list_length():
reporter_1 = Reporter.objects.create( Reporter.objects.create(
first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1
) )
reporter_2 = Reporter.objects.create( Reporter.objects.create(
first_name="Some", last_name="Guy", email="someguy@cnn.com", a_choice=1 first_name="Some", last_name="Guy", email="someguy@cnn.com", a_choice=1
) )
@ -1141,19 +1445,19 @@ def test_connection_should_limit_after_to_list_length():
""" """
after = base64.b64encode(b"arrayconnection:10").decode() after = base64.b64encode(b"arrayconnection:10").decode()
result = schema.execute(query, variable_values=dict(after=after)) result = schema.execute(query, variable_values={"after": after})
expected = {"allReporters": {"edges": []}} expected = {"allReporters": {"edges": []}}
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
REPORTERS = [ REPORTERS = [
dict( {
first_name=f"First {i}", "first_name": f"First {i}",
last_name=f"Last {i}", "last_name": f"Last {i}",
email=f"johndoe+{i}@example.com", "email": f"johndoe+{i}@example.com",
a_choice=1, "a_choice": 1,
) }
for i in range(6) for i in range(6)
] ]
@ -1228,7 +1532,7 @@ def test_should_have_next_page(graphene_settings):
assert result.data["allReporters"]["pageInfo"]["hasNextPage"] assert result.data["allReporters"]["pageInfo"]["hasNextPage"]
last_result = result.data["allReporters"]["pageInfo"]["endCursor"] last_result = result.data["allReporters"]["pageInfo"]["endCursor"]
result2 = schema.execute(query, variable_values=dict(first=4, after=last_result)) result2 = schema.execute(query, variable_values={"first": 4, "after": last_result})
assert not result2.errors assert not result2.errors
assert len(result2.data["allReporters"]["edges"]) == 2 assert len(result2.data["allReporters"]["edges"]) == 2
assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"] assert not result2.data["allReporters"]["pageInfo"]["hasNextPage"]
@ -1319,7 +1623,7 @@ class TestBackwardPagination:
after = base64.b64encode(b"arrayconnection:0").decode() after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute( result = schema.execute(
query_first_last_and_after, query_first_last_and_after,
variable_values=dict(after=after), variable_values={"after": after},
) )
assert not result.errors assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 3 assert len(result.data["allReporters"]["edges"]) == 3
@ -1351,7 +1655,7 @@ class TestBackwardPagination:
before = base64.b64encode(b"arrayconnection:5").decode() before = base64.b64encode(b"arrayconnection:5").decode()
result = schema.execute( result = schema.execute(
query_first_last_and_after, query_first_last_and_after,
variable_values=dict(before=before), variable_values={"before": before},
) )
assert not result.errors assert not result.errors
assert len(result.data["allReporters"]["edges"]) == 1 assert len(result.data["allReporters"]["edges"]) == 1
@ -1407,7 +1711,7 @@ def test_should_preserve_prefetch_related(django_assert_num_queries):
""" """
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
with django_assert_num_queries(3) as captured: with django_assert_num_queries(3):
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
@ -1574,7 +1878,7 @@ def test_connection_should_forbid_offset_filtering_with_before():
} }
""" """
before = base64.b64encode(b"arrayconnection:2").decode() before = base64.b64encode(b"arrayconnection:2").decode()
result = schema.execute(query, variable_values=dict(before=before)) result = schema.execute(query, variable_values={"before": before})
expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection." expected_error = "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `allReporters` connection."
assert len(result.errors) == 1 assert len(result.errors) == 1
assert result.errors[0].message == expected_error assert result.errors[0].message == expected_error
@ -1610,7 +1914,7 @@ def test_connection_should_allow_offset_filtering_with_after():
""" """
after = base64.b64encode(b"arrayconnection:0").decode() after = base64.b64encode(b"arrayconnection:0").decode()
result = schema.execute(query, variable_values=dict(after=after)) result = schema.execute(query, variable_values={"after": after})
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": { "allReporters": {
@ -1646,7 +1950,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
} }
""" """
result = schema.execute(query, variable_values=dict(last=2)) result = schema.execute(query, variable_values={"last": 2})
assert not result.errors assert not result.errors
expected = {"allReporters": {"edges": []}} expected = {"allReporters": {"edges": []}}
assert result.data == expected assert result.data == expected
@ -1656,7 +1960,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
Reporter.objects.create(first_name="Jane", last_name="Roe") Reporter.objects.create(first_name="Jane", last_name="Roe")
Reporter.objects.create(first_name="Some", last_name="Lady") Reporter.objects.create(first_name="Some", last_name="Lady")
result = schema.execute(query, variable_values=dict(last=2)) result = schema.execute(query, variable_values={"last": 2})
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": { "allReporters": {
@ -1668,7 +1972,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
} }
assert result.data == expected assert result.data == expected
result = schema.execute(query, variable_values=dict(last=4)) result = schema.execute(query, variable_values={"last": 4})
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": { "allReporters": {
@ -1682,7 +1986,7 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
} }
assert result.data == expected assert result.data == expected
result = schema.execute(query, variable_values=dict(last=20)) result = schema.execute(query, variable_values={"last": 20})
assert not result.errors assert not result.errors
expected = { expected = {
"allReporters": { "allReporters": {
@ -1697,14 +2001,62 @@ def test_connection_should_succeed_if_last_higher_than_number_of_objects():
assert result.data == expected assert result.data == expected
def test_connection_should_call_resolver_function():
resolver_mock = Mock(
name="resolver",
return_value=[
Reporter(first_name="Some", last_name="One"),
Reporter(first_name="John", last_name="Doe"),
],
)
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
fields = "__all__"
interfaces = [Node]
class Query(graphene.ObjectType):
reporters = DjangoConnectionField(ReporterType, resolver=resolver_mock)
schema = graphene.Schema(query=Query)
result = schema.execute(
"""
query {
reporters {
edges {
node {
firstName
lastName
}
}
}
}
"""
)
resolver_mock.assert_called_once_with(None, ANY)
assert not result.errors
assert result.data == {
"reporters": {
"edges": [
{"node": {"firstName": "Some", "lastName": "One"}},
{"node": {"firstName": "John", "lastName": "Doe"}},
],
},
}
def test_should_query_nullable_foreign_key(): def test_should_query_nullable_foreign_key():
class PetType(DjangoObjectType): class PetType(DjangoObjectType):
class Meta: class Meta:
model = Pet model = Pet
fields = "__all__"
class PersonType(DjangoObjectType): class PersonType(DjangoObjectType):
class Meta: class Meta:
model = Person model = Person
fields = "__all__"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
pet = graphene.Field(PetType, name=graphene.String(required=True)) pet = graphene.Field(PetType, name=graphene.String(required=True))
@ -1719,10 +2071,8 @@ def test_should_query_nullable_foreign_key():
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
person = Person.objects.create(name="Jane") person = Person.objects.create(name="Jane")
pets = [ Pet.objects.create(name="Stray dog", age=1)
Pet.objects.create(name="Stray dog", age=1), Pet.objects.create(name="Jane's dog", owner=person, age=1)
Pet.objects.create(name="Jane's dog", owner=person, age=1),
]
query_pet = """ query_pet = """
query getPet($name: String!) { query getPet($name: String!) {
@ -1759,3 +2109,76 @@ def test_should_query_nullable_foreign_key():
assert result.data["person"] == { assert result.data["person"] == {
"pets": [{"name": "Jane's dog"}], "pets": [{"name": "Jane's dog"}],
} }
def test_should_query_nullable_one_to_one_relation_with_custom_resolver():
class FilmType(DjangoObjectType):
class Meta:
model = Film
fields = "__all__"
@classmethod
def get_queryset(cls, queryset, info):
return queryset
class FilmDetailsType(DjangoObjectType):
class Meta:
model = FilmDetails
fields = "__all__"
@classmethod
def get_queryset(cls, queryset, info):
return queryset
class Query(graphene.ObjectType):
film = graphene.Field(FilmType, genre=graphene.String(required=True))
film_details = graphene.Field(
FilmDetailsType, location=graphene.String(required=True)
)
def resolve_film(self, info, genre):
return Film.objects.filter(genre=genre).first()
def resolve_film_details(self, info, location):
return FilmDetails.objects.filter(location=location).first()
schema = graphene.Schema(query=Query)
Film.objects.create(genre="do")
FilmDetails.objects.create(location="London")
query_film = """
query getFilm($genre: String!) {
film(genre: $genre) {
genre
details {
location
}
}
}
"""
query_film_details = """
query getFilmDetails($location: String!) {
filmDetails(location: $location) {
location
film {
genre
}
}
}
"""
result = schema.execute(query_film, variables={"genre": "do"})
assert not result.errors
assert result.data["film"] == {
"genre": "DO",
"details": None,
}
result = schema.execute(query_film_details, variables={"location": "London"})
assert not result.errors
assert result.data["filmDetails"] == {
"location": "London",
"film": None,
}

View File

@ -33,17 +33,21 @@ def test_should_map_fields_correctly():
fields = "__all__" fields = "__all__"
fields = list(ReporterType2._meta.fields.keys()) fields = list(ReporterType2._meta.fields.keys())
assert fields[:-2] == [ assert fields[:-3] == [
"id", "id",
"first_name", "first_name",
"last_name", "last_name",
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"typed_choice",
"class_choice",
"callable_choice",
"fans",
"reporter_type", "reporter_type",
] ]
assert sorted(fields[-2:]) == ["articles", "films"] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():

View File

@ -1,9 +1,10 @@
import warnings
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from textwrap import dedent from textwrap import dedent
from unittest.mock import patch
import pytest import pytest
from django.db import models from django.db import models
from unittest.mock import patch
from graphene import Connection, Field, Interface, ObjectType, Schema, String from graphene import Connection, Field, Interface, ObjectType, Schema, String
from graphene.relay import Node from graphene.relay import Node
@ -11,8 +12,10 @@ from graphene.relay import Node
from .. import registry from .. import registry
from ..filter import DjangoFilterConnectionField from ..filter import DjangoFilterConnectionField
from ..types import DjangoObjectType, DjangoObjectTypeOptions from ..types import DjangoObjectType, DjangoObjectTypeOptions
from .models import Article as ArticleModel from .models import (
from .models import Reporter as ReporterModel Article as ArticleModel,
Reporter as ReporterModel,
)
class Reporter(DjangoObjectType): class Reporter(DjangoObjectType):
@ -67,16 +70,20 @@ def test_django_get_node(get):
def test_django_objecttype_map_correct_fields(): def test_django_objecttype_map_correct_fields():
fields = Reporter._meta.fields fields = Reporter._meta.fields
fields = list(fields.keys()) fields = list(fields.keys())
assert fields[:-2] == [ assert fields[:-3] == [
"id", "id",
"first_name", "first_name",
"last_name", "last_name",
"email", "email",
"pets", "pets",
"a_choice", "a_choice",
"typed_choice",
"class_choice",
"callable_choice",
"fans",
"reporter_type", "reporter_type",
] ]
assert sorted(fields[-2:]) == ["articles", "films"] assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"]
def test_django_objecttype_with_node_have_correct_fields(): def test_django_objecttype_with_node_have_correct_fields():
@ -182,6 +189,9 @@ def test_schema_representation():
email: String! email: String!
pets: [Reporter!]! pets: [Reporter!]!
aChoice: TestsReporterAChoiceChoices aChoice: TestsReporterAChoiceChoices
typedChoice: TestsReporterTypedChoiceChoices
classChoice: TestsReporterClassChoiceChoices
callableChoice: TestsReporterCallableChoiceChoices
reporterType: TestsReporterReporterTypeChoices reporterType: TestsReporterReporterTypeChoices
articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection! articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection!
} }
@ -195,6 +205,33 @@ def test_schema_representation():
A_2 A_2
} }
\"""An enumeration.\"""
enum TestsReporterTypedChoiceChoices {
\"""Choice This\"""
A_1
\"""Choice That\"""
A_2
}
\"""An enumeration.\"""
enum TestsReporterClassChoiceChoices {
\"""Choice This\"""
A_1
\"""Choice That\"""
A_2
}
\"""An enumeration.\"""
enum TestsReporterCallableChoiceChoices {
\"""Choice This\"""
THIS
\"""Choice That\"""
THAT
}
\"""An enumeration.\""" \"""An enumeration.\"""
enum TestsReporterReporterTypeChoices { enum TestsReporterReporterTypeChoices {
\"""Regular\""" \"""Regular\"""
@ -396,7 +433,7 @@ def test_django_objecttype_fields_exist_on_model():
with pytest.warns( with pytest.warns(
UserWarning, UserWarning,
match=r"Field name .* matches an attribute on Django model .* but it's not a model field", match=r"Field name .* matches an attribute on Django model .* but it's not a model field",
) as record: ):
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
@ -404,7 +441,8 @@ def test_django_objecttype_fields_exist_on_model():
fields = ["first_name", "some_method", "email"] fields = ["first_name", "some_method", "email"]
# Don't warn if selecting a custom field # Don't warn if selecting a custom field
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
custom_field = String() custom_field = String()
@ -413,8 +451,6 @@ def test_django_objecttype_fields_exist_on_model():
model = ReporterModel model = ReporterModel
fields = ["first_name", "custom_field", "email"] fields = ["first_name", "custom_field", "email"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_exclude_fields_exist_on_model(): def test_django_objecttype_exclude_fields_exist_on_model():
@ -442,15 +478,14 @@ def test_django_objecttype_exclude_fields_exist_on_model():
exclude = ["custom_field"] exclude = ["custom_field"]
# Don't warn on exclude fields # Don't warn on exclude fields
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter4(DjangoObjectType): class Reporter4(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email", "first_name"] exclude = ["email", "first_name"]
assert len(record) == 0
@with_local_registry @with_local_registry
def test_django_objecttype_neither_fields_nor_exclude(): def test_django_objecttype_neither_fields_nor_exclude():
@ -464,24 +499,22 @@ def test_django_objecttype_neither_fields_nor_exclude():
class Meta: class Meta:
model = ReporterModel model = ReporterModel
with pytest.warns(None) as record: with warnings.catch_warnings():
warnings.simplefilter("error")
class Reporter2(DjangoObjectType): class Reporter2(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
fields = ["email"] fields = ["email"]
assert len(record) == 0 with warnings.catch_warnings():
warnings.simplefilter("error")
with pytest.warns(None) as record:
class Reporter3(DjangoObjectType): class Reporter3(DjangoObjectType):
class Meta: class Meta:
model = ReporterModel model = ReporterModel
exclude = ["email"] exclude = ["email"]
assert len(record) == 0
def custom_enum_name(field): def custom_enum_name(field):
return f"CustomEnum{field.name.title()}" return f"CustomEnum{field.name.title()}"
@ -658,6 +691,122 @@ class TestDjangoObjectType:
}""" }"""
) )
def test_django_objecttype_convert_choices_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: String!
cuteness: Int!
}"""
)
def test_django_objecttype_convert_choices_true_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
fields = "__all__"
convert_choices_to_enum = True
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: TestsPetModelKindChoices!
cuteness: TestsPetModelCutenessChoices!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}
\"""An enumeration.\"""
enum TestsPetModelCutenessChoices {
\"""Kind of cute\"""
A_1
\"""Pretty cute\"""
A_2
\"""OMG SO CUTE!!!\"""
A_3
}"""
)
def test_django_objecttype_convert_choices_enum_list_global_false(
self, graphene_settings, PetModel
):
graphene_settings.DJANGO_CHOICE_FIELD_ENUM_CONVERT = False
class Pet(DjangoObjectType):
class Meta:
model = PetModel
convert_choices_to_enum = ["kind"]
fields = "__all__"
class Query(ObjectType):
pet = Field(Pet)
schema = Schema(query=Query)
assert str(schema) == dedent(
"""\
type Query {
pet: Pet
}
type Pet {
id: ID!
kind: TestsPetModelKindChoices!
cuteness: Int!
}
\"""An enumeration.\"""
enum TestsPetModelKindChoices {
\"""Cat\"""
CAT
\"""Dog\"""
DOG
}"""
)
@with_local_registry @with_local_registry
def test_django_objecttype_name_connection_propagation(): def test_django_objecttype_name_connection_propagation():

View File

@ -1,12 +1,12 @@
import json import json
from unittest.mock import patch
import pytest import pytest
from django.utils.translation import gettext_lazy from django.utils.translation import gettext_lazy
from unittest.mock import patch
from ..utils import camelize, get_model_fields, GraphQLTestCase from ..utils import GraphQLTestCase, camelize, get_model_fields, get_reverse_fields
from .models import Film, Reporter
from ..utils.testing import graphql_query from ..utils.testing import graphql_query
from .models import APNewsReporter, CNNReporter, Film, Reporter
def test_get_model_fields_no_duplication(): def test_get_model_fields_no_duplication():
@ -19,6 +19,18 @@ def test_get_model_fields_no_duplication():
assert len(film_fields) == len(film_name_set) assert len(film_fields) == len(film_name_set)
def test_get_reverse_fields_includes_proxied_models():
reporter_fields = get_reverse_fields(Reporter, [])
cnn_reporter_fields = get_reverse_fields(CNNReporter, [])
ap_news_reporter_fields = get_reverse_fields(APNewsReporter, [])
assert (
len(list(reporter_fields))
== len(list(cnn_reporter_fields))
== len(list(ap_news_reporter_fields))
)
def test_camelize(): def test_camelize():
assert camelize({}) == {} assert camelize({}) == {}
assert camelize("value_a") == "value_a" assert camelize("value_a") == "value_a"

View File

@ -1,13 +1,10 @@
import json import json
from http import HTTPStatus
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
from django.db import connection from django.db import connection
from graphene_django.settings import graphene_settings
from .models import Pet from .models import Pet
try: try:
@ -31,13 +28,17 @@ def response_json(response):
return json.loads(response.content.decode()) return json.loads(response.content.decode())
j = lambda **kwargs: json.dumps(kwargs) def j(**kwargs):
jl = lambda **kwargs: json.dumps([kwargs]) return json.dumps(kwargs)
def jl(**kwargs):
return json.dumps([kwargs])
def test_graphiql_is_enabled(client): def test_graphiql_is_enabled(client):
response = client.get(url_string(), HTTP_ACCEPT="text/html") response = client.get(url_string(), HTTP_ACCEPT="text/html")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response["Content-Type"].split(";")[0] == "text/html" assert response["Content-Type"].split(";")[0] == "text/html"
@ -46,7 +47,7 @@ def test_qfactor_graphiql(client):
url_string(query="{test}"), url_string(query="{test}"),
HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9", HTTP_ACCEPT="application/json;q=0.8, text/html;q=0.9",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response["Content-Type"].split(";")[0] == "text/html" assert response["Content-Type"].split(";")[0] == "text/html"
@ -55,7 +56,7 @@ def test_qfactor_json(client):
url_string(query="{test}"), url_string(query="{test}"),
HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9", HTTP_ACCEPT="text/html;q=0.8, application/json;q=0.9",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response["Content-Type"].split(";")[0] == "application/json" assert response["Content-Type"].split(";")[0] == "application/json"
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -63,7 +64,7 @@ def test_qfactor_json(client):
def test_allows_get_with_query_param(client): def test_allows_get_with_query_param(client):
response = client.get(url_string(query="{test}")) response = client.get(url_string(query="{test}"))
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -75,7 +76,7 @@ def test_allows_get_with_variable_values(client):
) )
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -94,7 +95,7 @@ def test_allows_get_with_operation_name(client):
) )
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -103,7 +104,7 @@ def test_allows_get_with_operation_name(client):
def test_reports_validation_errors(client): def test_reports_validation_errors(client):
response = client.get(url_string(query="{ test, unknownOne, unknownTwo }")) response = client.get(url_string(query="{ test, unknownOne, unknownTwo }"))
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -128,7 +129,7 @@ def test_errors_when_missing_operation_name(client):
) )
) )
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -146,7 +147,7 @@ def test_errors_when_sending_a_mutation_via_get(client):
""" """
) )
) )
assert response.status_code == 405 assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{"message": "Can only perform a mutation operation from a POST request."} {"message": "Can only perform a mutation operation from a POST request."}
@ -165,7 +166,7 @@ def test_errors_when_selecting_a_mutation_within_a_get(client):
) )
) )
assert response.status_code == 405 assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{"message": "Can only perform a mutation operation from a POST request."} {"message": "Can only perform a mutation operation from a POST request."}
@ -184,14 +185,14 @@ def test_allows_mutation_to_exist_within_a_get(client):
) )
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
def test_allows_post_with_json_encoding(client): def test_allows_post_with_json_encoding(client):
response = client.post(url_string(), j(query="{test}"), "application/json") response = client.post(url_string(), j(query="{test}"), "application/json")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -200,7 +201,7 @@ def test_batch_allows_post_with_json_encoding(client):
batch_url_string(), jl(id=1, query="{test}"), "application/json" batch_url_string(), jl(id=1, query="{test}"), "application/json"
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello World"}, "status": 200} {"id": 1, "data": {"test": "Hello World"}, "status": 200}
] ]
@ -209,7 +210,7 @@ def test_batch_allows_post_with_json_encoding(client):
def test_batch_fails_if_is_empty(client): def test_batch_fails_if_is_empty(client):
response = client.post(batch_url_string(), "[]", "application/json") response = client.post(batch_url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Received an empty list in the batch request."}] "errors": [{"message": "Received an empty list in the batch request."}]
} }
@ -222,18 +223,18 @@ def test_allows_sending_a_mutation_via_post(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}} assert response_json(response) == {"data": {"writeTest": {"test": "Hello World"}}}
def test_allows_post_with_url_encoding(client): def test_allows_post_with_url_encoding(client):
response = client.post( response = client.post(
url_string(), url_string(),
urlencode(dict(query="{test}")), urlencode({"query": "{test}"}),
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello World"}} assert response_json(response) == {"data": {"test": "Hello World"}}
@ -247,7 +248,7 @@ def test_supports_post_json_query_with_string_variables(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -262,7 +263,7 @@ def test_batch_supports_post_json_query_with_string_variables(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
] ]
@ -278,7 +279,7 @@ def test_supports_post_json_query_with_json_variables(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -293,7 +294,7 @@ def test_batch_supports_post_json_query_with_json_variables(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == [ assert response_json(response) == [
{"id": 1, "data": {"test": "Hello Dolly"}, "status": 200} {"id": 1, "data": {"test": "Hello Dolly"}, "status": 200}
] ]
@ -303,15 +304,15 @@ def test_supports_post_url_encoded_query_with_string_variables(client):
response = client.post( response = client.post(
url_string(), url_string(),
urlencode( urlencode(
dict( {
query="query helloWho($who: String){ test(who: $who) }", "query": "query helloWho($who: String){ test(who: $who) }",
variables=json.dumps({"who": "Dolly"}), "variables": json.dumps({"who": "Dolly"}),
) }
), ),
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -322,18 +323,18 @@ def test_supports_post_json_quey_with_get_variable_values(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
def test_post_url_encoded_query_with_get_variable_values(client): def test_post_url_encoded_query_with_get_variable_values(client):
response = client.post( response = client.post(
url_string(variables=json.dumps({"who": "Dolly"})), url_string(variables=json.dumps({"who": "Dolly"})),
urlencode(dict(query="query helloWho($who: String){ test(who: $who) }")), urlencode({"query": "query helloWho($who: String){ test(who: $who) }"}),
"application/x-www-form-urlencoded", "application/x-www-form-urlencoded",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -344,7 +345,7 @@ def test_supports_post_raw_text_query_with_get_variable_values(client):
"application/graphql", "application/graphql",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"test": "Hello Dolly"}} assert response_json(response) == {"data": {"test": "Hello Dolly"}}
@ -365,7 +366,7 @@ def test_allows_post_with_operation_name(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -389,7 +390,7 @@ def test_batch_allows_post_with_operation_name(client):
"application/json", "application/json",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == [ assert response_json(response) == [
{ {
"id": 1, "id": 1,
@ -413,7 +414,7 @@ def test_allows_post_with_get_operation_name(client):
"application/graphql", "application/graphql",
) )
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == { assert response_json(response) == {
"data": {"test": "Hello World", "shared": "Hello Everyone"} "data": {"test": "Hello World", "shared": "Hello Everyone"}
} }
@ -430,7 +431,7 @@ def test_inherited_class_with_attributes_works(client):
# Check graphiql works # Check graphiql works
response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html") response = client.get(url_string(inherited_url), HTTP_ACCEPT="text/html")
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
@pytest.mark.urls("graphene_django.tests.urls_pretty") @pytest.mark.urls("graphene_django.tests.urls_pretty")
@ -452,7 +453,7 @@ def test_supports_pretty_printing_by_request(client):
def test_handles_field_errors_caught_by_graphql(client): def test_handles_field_errors_caught_by_graphql(client):
response = client.get(url_string(query="{thrower}")) response = client.get(url_string(query="{thrower}"))
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == { assert response_json(response) == {
"data": None, "data": None,
"errors": [ "errors": [
@ -467,7 +468,7 @@ def test_handles_field_errors_caught_by_graphql(client):
def test_handles_syntax_errors_caught_by_graphql(client): def test_handles_syntax_errors_caught_by_graphql(client):
response = client.get(url_string(query="syntaxerror")) response = client.get(url_string(query="syntaxerror"))
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [ "errors": [
{ {
@ -481,7 +482,7 @@ def test_handles_syntax_errors_caught_by_graphql(client):
def test_handles_errors_caused_by_a_lack_of_query(client): def test_handles_errors_caused_by_a_lack_of_query(client):
response = client.get(url_string()) response = client.get(url_string())
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Must provide query string."}] "errors": [{"message": "Must provide query string."}]
} }
@ -490,7 +491,7 @@ def test_handles_errors_caused_by_a_lack_of_query(client):
def test_handles_not_expected_json_bodies(client): def test_handles_not_expected_json_bodies(client):
response = client.post(url_string(), "[]", "application/json") response = client.post(url_string(), "[]", "application/json")
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "The received data is not a valid JSON query."}] "errors": [{"message": "The received data is not a valid JSON query."}]
} }
@ -499,7 +500,7 @@ def test_handles_not_expected_json_bodies(client):
def test_handles_invalid_json_bodies(client): def test_handles_invalid_json_bodies(client):
response = client.post(url_string(), "[oh}", "application/json") response = client.post(url_string(), "[oh}", "application/json")
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -511,17 +512,17 @@ def test_handles_django_request_error(client, monkeypatch):
monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read) monkeypatch.setattr("django.http.request.HttpRequest.read", mocked_read)
valid_json = json.dumps(dict(foo="bar")) valid_json = json.dumps({"foo": "bar"})
response = client.post(url_string(), valid_json, "application/json") response = client.post(url_string(), valid_json, "application/json")
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == {"errors": [{"message": "foo-bar"}]} assert response_json(response) == {"errors": [{"message": "foo-bar"}]}
def test_handles_incomplete_json_bodies(client): def test_handles_incomplete_json_bodies(client):
response = client.post(url_string(), '{"query":', "application/json") response = client.post(url_string(), '{"query":', "application/json")
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "POST body sent invalid JSON."}] "errors": [{"message": "POST body sent invalid JSON."}]
} }
@ -533,7 +534,7 @@ def test_handles_plain_post_text(client):
"query helloWho($who: String){ test(who: $who) }", "query helloWho($who: String){ test(who: $who) }",
"text/plain", "text/plain",
) )
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Must provide query string."}] "errors": [{"message": "Must provide query string."}]
} }
@ -545,7 +546,7 @@ def test_handles_poorly_formed_variables(client):
query="query helloWho($who: String){ test(who: $who) }", variables="who:You" query="query helloWho($who: String){ test(who: $who) }", variables="who:You"
) )
) )
assert response.status_code == 400 assert response.status_code == HTTPStatus.BAD_REQUEST
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "Variables are invalid JSON."}] "errors": [{"message": "Variables are invalid JSON."}]
} }
@ -553,7 +554,7 @@ def test_handles_poorly_formed_variables(client):
def test_handles_unsupported_http_methods(client): def test_handles_unsupported_http_methods(client):
response = client.put(url_string(query="{test}")) response = client.put(url_string(query="{test}"))
assert response.status_code == 405 assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
assert response["Allow"] == "GET, POST" assert response["Allow"] == "GET, POST"
assert response_json(response) == { assert response_json(response) == {
"errors": [{"message": "GraphQL only supports GET and POST requests."}] "errors": [{"message": "GraphQL only supports GET and POST requests."}]
@ -563,7 +564,7 @@ def test_handles_unsupported_http_methods(client):
def test_passes_request_into_context_request(client): def test_passes_request_into_context_request(client):
response = client.get(url_string(query="{request}", q="testing")) response = client.get(url_string(query="{request}", q="testing"))
assert response.status_code == 200 assert response.status_code == HTTPStatus.OK
assert response_json(response) == {"data": {"request": "testing"}} assert response_json(response) == {"data": {"request": "testing"}}
@ -827,3 +828,97 @@ def test_query_errors_atomic_request(set_rollback_mock, client):
def test_query_errors_non_atomic(set_rollback_mock, client): def test_query_errors_non_atomic(set_rollback_mock, client):
client.get(url_string(query="force error")) client.get(url_string(query="force error"))
set_rollback_mock.assert_not_called() set_rollback_mock.assert_not_called()
VALIDATION_URLS = [
"/graphql/validation/",
"/graphql/validation/alternative/",
"/graphql/validation/inherited/",
]
QUERY_WITH_TWO_INTROSPECTIONS = """
query Instrospection {
queryType: __schema {
queryType {name}
}
mutationType: __schema {
mutationType {name}
}
}
"""
N_INTROSPECTIONS = 2
INTROSPECTION_DISALLOWED_ERROR_MESSAGE = "introspection is disabled"
MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE = "too many validation errors"
@pytest.mark.urls("graphene_django.tests.urls_validation")
def test_allow_introspection(client):
response = client.post(
url_string("/graphql/", query="{__schema {queryType {name}}}")
)
assert response.status_code == HTTPStatus.OK
assert response_json(response) == {
"data": {"__schema": {"queryType": {"name": "QueryRoot"}}}
}
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
def test_validation_disallow_introspection(client, url):
response = client.post(url_string(url, query="{__schema {queryType {name}}}"))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
assert len(json_response["errors"]) == 1
error_message = json_response["errors"][0]["message"]
assert INTROSPECTION_DISALLOWED_ERROR_MESSAGE in error_message
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
@patch(
"graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", N_INTROSPECTIONS
)
def test_within_max_validation_errors(client, url):
response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
assert len(json_response["errors"]) == N_INTROSPECTIONS
error_messages = [error["message"].lower() for error in json_response["errors"]]
n_introspection_error_messages = sum(
INTROSPECTION_DISALLOWED_ERROR_MESSAGE in msg for msg in error_messages
)
assert n_introspection_error_messages == N_INTROSPECTIONS
assert all(
MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE not in msg for msg in error_messages
)
@pytest.mark.parametrize("url", VALIDATION_URLS)
@pytest.mark.urls("graphene_django.tests.urls_validation")
@patch("graphene_django.settings.graphene_settings.MAX_VALIDATION_ERRORS", 1)
def test_exceeds_max_validation_errors(client, url):
response = client.post(url_string(url, query=QUERY_WITH_TWO_INTROSPECTIONS))
assert response.status_code == HTTPStatus.BAD_REQUEST
json_response = response_json(response)
assert "data" not in json_response
assert "errors" in json_response
error_messages = (error["message"].lower() for error in json_response["errors"])
assert any(MAX_VALIDATION_ERRORS_EXCEEDED_MESSAGE in msg for msg in error_messages)

View File

@ -0,0 +1,26 @@
from django.urls import path
from graphene.validation import DisableIntrospection
from ..views import GraphQLView
from .schema_view import schema
class View(GraphQLView):
schema = schema
class NoIntrospectionView(View):
validation_rules = (DisableIntrospection,)
class NoIntrospectionViewInherited(NoIntrospectionView):
pass
urlpatterns = [
path("graphql/", View.as_view()),
path("graphql/validation/", View.as_view(validation_rules=(DisableIntrospection,))),
path("graphql/validation/alternative/", NoIntrospectionView.as_view()),
path("graphql/validation/inherited/", NoIntrospectionViewInherited.as_view()),
]

View File

@ -1,9 +1,10 @@
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Type from typing import Type # noqa: F401
from django.db.models import Model # noqa: F401
import graphene import graphene
from django.db.models import Model
from graphene.relay import Connection, Node from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs from graphene.types.utils import yank_fields_from_attrs
@ -22,7 +23,7 @@ ALL_FIELDS = "__all__"
def construct_fields( def construct_fields(
model, registry, only_fields, exclude_fields, convert_choices_to_enum model, registry, only_fields, exclude_fields, convert_choices_to_enum=None
): ):
_model_fields = get_model_fields(model) _model_fields = get_model_fields(model)
@ -46,7 +47,7 @@ def construct_fields(
continue continue
_convert_choices_to_enum = convert_choices_to_enum _convert_choices_to_enum = convert_choices_to_enum
if not isinstance(_convert_choices_to_enum, bool): if isinstance(_convert_choices_to_enum, list):
# then `convert_choices_to_enum` is a list of field names to convert # then `convert_choices_to_enum` is a list of field names to convert
if name in _convert_choices_to_enum: if name in _convert_choices_to_enum:
_convert_choices_to_enum = True _convert_choices_to_enum = True
@ -101,10 +102,8 @@ def validate_fields(type_, model, fields, only_fields, exclude_fields):
if name in all_field_names: if name in all_field_names:
# Field is a custom field # Field is a custom field
warnings.warn( warnings.warn(
( f'Excluding the custom field "{name}" on DjangoObjectType "{type_}" has no effect. '
'Excluding the custom field "{field_name}" on DjangoObjectType "{type_}" has no effect. '
'Either remove the custom field or remove the field from the "exclude" list.' 'Either remove the custom field or remove the field from the "exclude" list.'
).format(field_name=name, type_=type_)
) )
else: else:
if not hasattr(model, name): if not hasattr(model, name):
@ -147,9 +146,9 @@ class DjangoObjectType(ObjectType):
connection_class=None, connection_class=None,
use_connection=None, use_connection=None,
interfaces=(), interfaces=(),
convert_choices_to_enum=True, convert_choices_to_enum=None,
_meta=None, _meta=None,
**options **options,
): ):
assert is_valid_django_model(model), ( assert is_valid_django_model(model), (
'You need to pass a valid Django Model in {}.Meta, received "{}".' 'You need to pass a valid Django Model in {}.Meta, received "{}".'
@ -159,9 +158,9 @@ class DjangoObjectType(ObjectType):
registry = get_global_registry() registry = get_global_registry()
assert isinstance(registry, Registry), ( assert isinstance(registry, Registry), (
"The attribute registry in {} needs to be an instance of " f"The attribute registry in {cls.__name__} needs to be an instance of "
'Registry, received "{}".' f'Registry, received "{registry}".'
).format(cls.__name__, registry) )
if filter_fields and filterset_class: if filter_fields and filterset_class:
raise Exception("Can't set both filter_fields and filterset_class") raise Exception("Can't set both filter_fields and filterset_class")
@ -174,7 +173,7 @@ class DjangoObjectType(ObjectType):
assert not (fields and exclude), ( assert not (fields and exclude), (
"Cannot set both 'fields' and 'exclude' options on " "Cannot set both 'fields' and 'exclude' options on "
"DjangoObjectType {class_name}.".format(class_name=cls.__name__) f"DjangoObjectType {cls.__name__}."
) )
# Alias only_fields -> fields # Alias only_fields -> fields
@ -213,8 +212,8 @@ class DjangoObjectType(ObjectType):
warnings.warn( warnings.warn(
"Creating a DjangoObjectType without either the `fields` " "Creating a DjangoObjectType without either the `fields` "
"or the `exclude` option is deprecated. Add an explicit `fields " "or the `exclude` option is deprecated. Add an explicit `fields "
"= '__all__'` option on DjangoObjectType {class_name} to use all " f"= '__all__'` option on DjangoObjectType {cls.__name__} to use all "
"fields".format(class_name=cls.__name__), "fields",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
@ -239,9 +238,9 @@ class DjangoObjectType(ObjectType):
) )
if connection is not None: if connection is not None:
assert issubclass(connection, Connection), ( assert issubclass(
"The connection must be a Connection. Received {}" connection, Connection
).format(connection.__name__) ), f"The connection must be a Connection. Received {connection.__name__}"
if not _meta: if not _meta:
_meta = DjangoObjectTypeOptions(cls) _meta = DjangoObjectTypeOptions(cls)
@ -272,7 +271,7 @@ class DjangoObjectType(ObjectType):
if isinstance(root, cls): if isinstance(root, cls):
return True return True
if not is_valid_django_model(root.__class__): if not is_valid_django_model(root.__class__):
raise Exception(('Received incompatible instance "{}".').format(root)) raise Exception(f'Received incompatible instance "{root}".')
if cls._meta.model._meta.proxy: if cls._meta.model._meta.proxy:
model = root._meta.model model = root._meta.model

View File

@ -1,6 +1,7 @@
from .testing import GraphQLTestCase from .testing import GraphQLTestCase
from .utils import ( from .utils import (
DJANGO_FILTER_INSTALLED, DJANGO_FILTER_INSTALLED,
bypass_get_queryset,
camelize, camelize,
get_model_fields, get_model_fields,
get_reverse_fields, get_reverse_fields,
@ -16,4 +17,5 @@ __all__ = [
"camelize", "camelize",
"is_valid_django_model", "is_valid_django_model",
"GraphQLTestCase", "GraphQLTestCase",
"bypass_get_queryset",
] ]

View File

@ -1,4 +1,5 @@
import re import re
from text_unidecode import unidecode from text_unidecode import unidecode

View File

@ -4,6 +4,7 @@ import warnings
from django.test import Client, TestCase, TransactionTestCase from django.test import Client, TestCase, TransactionTestCase
from graphene_django.settings import graphene_settings from graphene_django.settings import graphene_settings
from graphene_django.utils.utils import _DJANGO_VERSION_AT_LEAST_4_2
DEFAULT_GRAPHQL_URL = "/graphql" DEFAULT_GRAPHQL_URL = "/graphql"
@ -55,8 +56,14 @@ def graphql_query(
else: else:
body["variables"] = {"input": input_data} body["variables"] = {"input": input_data}
if headers: if headers:
header_params = (
{"headers": headers} if _DJANGO_VERSION_AT_LEAST_4_2 else headers
)
resp = client.post( resp = client.post(
graphql_url, json.dumps(body), content_type="application/json", **headers graphql_url,
json.dumps(body),
content_type="application/json",
**header_params,
) )
else: else:
resp = client.post( resp = client.post(

View File

@ -1,10 +1,10 @@
import pytest import pytest
from .. import GraphQLTestCase
from ...tests.test_types import with_local_registry
from ...settings import graphene_settings
from django.test import Client from django.test import Client
from ...settings import graphene_settings
from ...tests.test_types import with_local_registry
from .. import GraphQLTestCase
@with_local_registry @with_local_registry
def test_graphql_test_case_deprecated_client_getter(): def test_graphql_test_case_deprecated_client_getter():
@ -23,7 +23,7 @@ def test_graphql_test_case_deprecated_client_getter():
tc.setUpClass() tc.setUpClass()
with pytest.warns(PendingDeprecationWarning): with pytest.warns(PendingDeprecationWarning):
tc._client tc._client # noqa: B018
@with_local_registry @with_local_registry

View File

@ -1,5 +1,6 @@
import inspect import inspect
import django
from django.db import connection, models, transaction from django.db import connection, models, transaction
from django.db.models.manager import Manager from django.db.models.manager import Manager
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -37,8 +38,24 @@ def camelize(data):
return data return data
def _get_model_ancestry(model):
model_ancestry = [model]
for base in model.__bases__:
if is_valid_django_model(base) and getattr(base, "_meta", False):
model_ancestry.append(base)
return model_ancestry
def get_reverse_fields(model, local_field_names): def get_reverse_fields(model, local_field_names):
for name, attr in model.__dict__.items(): """
Searches through the model's ancestry and gets reverse relationships the models
Yields a tuple of (field.name, field)
"""
model_ancestry = _get_model_ancestry(model)
for _model in model_ancestry:
for name, attr in _model.__dict__.items():
# Don't duplicate any local fields # Don't duplicate any local fields
if name in local_field_names: if name in local_field_names:
continue continue
@ -51,6 +68,24 @@ def get_reverse_fields(model, local_field_names):
yield (name, related) yield (name, related)
def get_local_fields(model):
"""
Searches through the model's ancestry and gets the fields on the models
Returns a dict of {field.name: field}
"""
model_ancestry = _get_model_ancestry(model)
local_fields_dict = {}
for _model in model_ancestry:
for field in sorted(
list(_model._meta.fields) + list(_model._meta.local_many_to_many)
):
if field.name not in local_fields_dict:
local_fields_dict[field.name] = field
return list(local_fields_dict.items())
def maybe_queryset(value): def maybe_queryset(value):
if isinstance(value, Manager): if isinstance(value, Manager):
value = value.get_queryset() value = value.get_queryset()
@ -58,17 +93,14 @@ def maybe_queryset(value):
def get_model_fields(model): def get_model_fields(model):
local_fields = [ """
(field.name, field) Gets all the fields and relationships on the Django model and its ancestry.
for field in sorted( Prioritizes local fields and relationships over the reverse relationships of the same name
list(model._meta.fields) + list(model._meta.local_many_to_many) Returns a tuple of (field.name, field)
) """
] local_fields = get_local_fields(model)
local_field_names = {field[0] for field in local_fields}
# Make sure we don't duplicate local fields with "reverse" version
local_field_names = [field[0] for field in local_fields]
reverse_fields = get_reverse_fields(model, local_field_names) reverse_fields = get_reverse_fields(model, local_field_names)
all_fields = local_fields + list(reverse_fields) all_fields = local_fields + list(reverse_fields)
return all_fields return all_fields
@ -79,24 +111,7 @@ def is_valid_django_model(model):
def import_single_dispatch(): def import_single_dispatch():
try:
from functools import singledispatch from functools import singledispatch
except ImportError:
singledispatch = None
if not singledispatch:
try:
from singledispatch import singledispatch
except ImportError:
pass
if not singledispatch:
raise Exception(
"It seems your python version does not include "
"functools.singledispatch. Please install the 'singledispatch' "
"package. More information here: "
"https://pypi.python.org/pypi/singledispatch"
)
return singledispatch return singledispatch
@ -105,3 +120,17 @@ def set_rollback():
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False) atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block: if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True) transaction.set_rollback(True)
def bypass_get_queryset(resolver):
"""
Adds a bypass_get_queryset attribute to the resolver, which is used to
bypass any custom get_queryset method of the DjangoObjectType.
"""
resolver._bypass_get_queryset = True
return resolver
_DJANGO_VERSION_AT_LEAST_4_2 = django.VERSION[0] > 4 or (
django.VERSION[0] >= 4 and django.VERSION[1] >= 2
)

View File

@ -9,13 +9,19 @@ from django.shortcuts import render
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import View from django.views.generic import View
from graphql import OperationType, get_operation_ast, parse from graphql import (
ExecutionResult,
OperationType,
execute,
get_operation_ast,
parse,
validate_schema,
)
from graphql.error import GraphQLError from graphql.error import GraphQLError
from graphql.execution import ExecutionResult from graphql.execution.middleware import MiddlewareManager
from graphql.validation import validate
from graphene import Schema from graphene import Schema
from graphql.execution.middleware import MiddlewareManager
from graphene_django.constants import MUTATION_ERRORS_FLAG from graphene_django.constants import MUTATION_ERRORS_FLAG
from graphene_django.utils.utils import set_rollback from graphene_django.utils.utils import set_rollback
@ -40,9 +46,9 @@ def get_accepted_content_types(request):
raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",") raw_content_types = request.META.get("HTTP_ACCEPT", "*/*").split(",")
qualified_content_types = map(qualify, raw_content_types) qualified_content_types = map(qualify, raw_content_types)
return list( return [
x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True) x[0] for x in sorted(qualified_content_types, key=lambda x: x[1], reverse=True)
) ]
def instantiate_middleware(middlewares): def instantiate_middleware(middlewares):
@ -66,18 +72,21 @@ class GraphQLView(View):
react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=" react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="
# The GraphiQL React app. # The GraphiQL React app.
graphiql_version = "2.4.1" # "1.0.3" graphiql_version = "2.4.7"
graphiql_sri = "sha256-s+f7CFAPSUIygFnRC2nfoiEKd3liCUy+snSdYFAoLUc=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk=" graphiql_sri = "sha256-n/LKaELupC1H/PU6joz+ybeRJHT2xCdekEt6OYMOOZU="
graphiql_css_sri = "sha256-88yn8FJMyGboGs4Bj+Pbb3kWOWXo7jmb+XCRHE+282k=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E=" graphiql_css_sri = "sha256-OsbM+LQHcnFHi0iH7AUKueZvDcEBoy/z4hJ7jx1cpsM="
# The websocket transport library for subscriptions. # The websocket transport library for subscriptions.
subscriptions_transport_ws_version = "5.12.1" subscriptions_transport_ws_version = "5.13.1"
subscriptions_transport_ws_sri = ( subscriptions_transport_ws_sri = (
"sha256-EZhvg6ANJrBsgLvLAa0uuHNLepLJVCFYS+xlb5U/bqw=" "sha256-EZhvg6ANJrBsgLvLAa0uuHNLepLJVCFYS+xlb5U/bqw="
) )
graphiql_plugin_explorer_version = "0.1.15" graphiql_plugin_explorer_version = "0.1.15"
graphiql_plugin_explorer_sri = "sha256-3hUuhBXdXlfCj6RTeEkJFtEh/kUG+TCDASFpFPLrzvE=" graphiql_plugin_explorer_sri = "sha256-3hUuhBXdXlfCj6RTeEkJFtEh/kUG+TCDASFpFPLrzvE="
graphiql_plugin_explorer_css_sri = (
"sha256-fA0LPUlukMNR6L4SPSeFqDTYav8QdWjQ2nr559Zln1U="
)
schema = None schema = None
graphiql = False graphiql = False
@ -87,6 +96,7 @@ class GraphQLView(View):
batch = False batch = False
subscription_path = None subscription_path = None
execution_context_class = None execution_context_class = None
validation_rules = None
def __init__( def __init__(
self, self,
@ -98,6 +108,7 @@ class GraphQLView(View):
batch=False, batch=False,
subscription_path=None, subscription_path=None,
execution_context_class=None, execution_context_class=None,
validation_rules=None,
): ):
if not schema: if not schema:
schema = graphene_settings.SCHEMA schema = graphene_settings.SCHEMA
@ -126,6 +137,8 @@ class GraphQLView(View):
), "A Schema is required to be provided to GraphQLView." ), "A Schema is required to be provided to GraphQLView."
assert not all((graphiql, batch)), "Use either graphiql or batch processing" assert not all((graphiql, batch)), "Use either graphiql or batch processing"
self.validation_rules = validation_rules or self.validation_rules
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def get_root_value(self, request): def get_root_value(self, request):
return self.root_value return self.root_value
@ -165,11 +178,13 @@ class GraphQLView(View):
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version, graphiql_plugin_explorer_version=self.graphiql_plugin_explorer_version,
graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri, graphiql_plugin_explorer_sri=self.graphiql_plugin_explorer_sri,
graphiql_plugin_explorer_css_sri=self.graphiql_plugin_explorer_css_sri,
# The SUBSCRIPTION_PATH setting. # The SUBSCRIPTION_PATH setting.
subscription_path=self.subscription_path, subscription_path=self.subscription_path,
# GraphiQL headers tab, # GraphiQL headers tab,
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
graphiql_input_value_deprecation=graphene_settings.GRAPHIQL_INPUT_VALUE_DEPRECATION,
) )
if self.batch: if self.batch:
@ -291,14 +306,24 @@ class GraphQLView(View):
return None return None
raise HttpError(HttpResponseBadRequest("Must provide query string.")) raise HttpError(HttpResponseBadRequest("Must provide query string."))
schema = self.schema.graphql_schema
schema_validation_errors = validate_schema(schema)
if schema_validation_errors:
return ExecutionResult(data=None, errors=schema_validation_errors)
try: try:
document = parse(query) document = parse(query)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e]) return ExecutionResult(errors=[e])
if request.method.lower() == "get":
operation_ast = get_operation_ast(document, operation_name) operation_ast = get_operation_ast(document, operation_name)
if operation_ast and operation_ast.operation != OperationType.QUERY:
if (
request.method.lower() == "get"
and operation_ast is not None
and operation_ast.operation != OperationType.QUERY
):
if show_graphiql: if show_graphiql:
return None return None
@ -310,24 +335,32 @@ class GraphQLView(View):
), ),
) )
) )
try:
extra_options = {}
if self.execution_context_class:
extra_options["execution_context_class"] = self.execution_context_class
options = { validation_errors = validate(
"source": query, schema,
document,
self.validation_rules,
graphene_settings.MAX_VALIDATION_ERRORS,
)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
try:
execute_options = {
"root_value": self.get_root_value(request), "root_value": self.get_root_value(request),
"context_value": self.get_context(request),
"variable_values": variables, "variable_values": variables,
"operation_name": operation_name, "operation_name": operation_name,
"context_value": self.get_context(request),
"middleware": self.get_middleware(request), "middleware": self.get_middleware(request),
} }
options.update(extra_options) if self.execution_context_class:
execute_options[
"execution_context_class"
] = self.execution_context_class
operation_ast = get_operation_ast(document, operation_name)
if ( if (
operation_ast operation_ast is not None
and operation_ast.operation == OperationType.MUTATION and operation_ast.operation == OperationType.MUTATION
and ( and (
graphene_settings.ATOMIC_MUTATIONS is True graphene_settings.ATOMIC_MUTATIONS is True
@ -335,12 +368,12 @@ class GraphQLView(View):
) )
): ):
with transaction.atomic(): with transaction.atomic():
result = self.schema.execute(**options) result = execute(schema, document, **execute_options)
if getattr(request, MUTATION_ERRORS_FLAG, False) is True: if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
transaction.set_rollback(True) transaction.set_rollback(True)
return result return result
return self.schema.execute(**options) return execute(schema, document, **execute_options)
except Exception as e: except Exception as e:
return ExecutionResult(errors=[e]) return ExecutionResult(errors=[e])

View File

@ -4,46 +4,13 @@ test=pytest
[bdist_wheel] [bdist_wheel]
universal=1 universal=1
[flake8]
exclude = docs,graphene_django/debug/sql/*
max-line-length = 120
select =
# Dictionary key repeated
F601,
# Ensure use of ==/!= to compare with str, bytes and int literals
F632,
# Redefinition of unused name
F811,
# Using an undefined variable
F821,
# Defining an undefined variable in __all__
F822,
# Using a variable before it is assigned
F823,
# Duplicate argument in function declaration
F831,
# Black would format this line
BLK,
# Do not use bare except
B001,
# Don't allow ++n. You probably meant n += 1
B002,
# Do not use mutable structures for argument defaults
B006,
# Do not perform calls in argument defaults
B008
[coverage:run] [coverage:run]
omit = */tests/* omit = */tests/*
[isort]
known_first_party=graphene,graphene_django
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
[tool:pytest] [tool:pytest]
DJANGO_SETTINGS_MODULE = examples.django_test_settings DJANGO_SETTINGS_MODULE = examples.django_test_settings
addopts = --random-order addopts = --random-order
filterwarnings =
error
# we can't do anything about the DeprecationWarning about typing.ByteString in graphql
default:'typing\.ByteString' is deprecated:DeprecationWarning:graphql\.pyutils\.is_iterable

View File

@ -26,10 +26,7 @@ tests_require = [
dev_requires = [ dev_requires = [
"black==23.3.0", "ruff==0.1.2",
"flake8==6.0.0",
"flake8-black==0.3.6",
"flake8-bugbear==23.3.23",
"pre-commit", "pre-commit",
] + tests_require ] + tests_require
@ -38,7 +35,7 @@ setup(
version=version, version=version,
description="Graphene Django integration", description="Graphene Django integration",
long_description=open("README.md").read(), long_description=open("README.md").read(),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
url="https://github.com/graphql-python/graphene-django", url="https://github.com/graphql-python/graphene-django",
author="Syrus Akbary", author="Syrus Akbary",
author_email="me@syrusakbary.com", author_email="me@syrusakbary.com",
@ -48,16 +45,18 @@ setup(
"Intended Audience :: Developers", "Intended Audience :: Developers",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
"Framework :: Django", "Framework :: Django",
"Framework :: Django :: 3.2", "Framework :: Django :: 3.2",
"Framework :: Django :: 4.0",
"Framework :: Django :: 4.1", "Framework :: Django :: 4.1",
"Framework :: Django :: 4.2",
"Framework :: Django :: 5.1",
"Framework :: Django :: 5.2",
], ],
keywords="api graphql protocol rest relay graphene", keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests", "examples", "examples.*"]), packages=find_packages(exclude=["tests", "examples", "examples.*"]),

24
tox.ini
View File

@ -1,23 +1,25 @@
[tox] [tox]
envlist = envlist =
py{37,38,39,310}-django32, py{38,39,310}-django32
py{38,39,310}-django{40,41,main}, py{38,39}-django42
py311-django{41,main} py{310,311,312}-django{42,50,51,main}
pre-commit pre-commit
[gh-actions] [gh-actions]
python = python =
3.7: py37
3.8: py38 3.8: py38
3.9: py39 3.9: py39
3.10: py310 3.10: py310
3.11: py311 3.11: py311
3.12: py312
[gh-actions:env] [gh-actions:env]
DJANGO = DJANGO =
3.2: django32 3.2: django32
4.0: django40 4.2: django42
4.1: django41 5.0: django50
5.1: django51
5.2: django52
main: djangomain main: djangomain
[testenv] [testenv]
@ -30,13 +32,15 @@ deps =
-e.[test] -e.[test]
psycopg2-binary psycopg2-binary
django32: Django>=3.2,<4.0 django32: Django>=3.2,<4.0
django40: Django>=4.0,<4.1 django42: Django>=4.2,<4.3
django41: Django>=4.1,<4.2 django50: Django>=5.0,<5.1
django51: Django>=5.1,<5.2
django52: Django>=5.2,<6.0
djangomain: https://github.com/django/django/archive/main.zip djangomain: https://github.com/django/django/archive/main.zip
commands = {posargs:py.test --cov=graphene_django graphene_django examples} commands = {posargs:pytest --cov=graphene_django graphene_django examples}
[testenv:pre-commit] [testenv:pre-commit]
skip_install = true skip_install = true
deps = pre-commit deps = pre-commit
commands = commands =
pre-commit run --all-files --show-diff-on-failure pre-commit run {posargs:--all-files --show-diff-on-failure}