Merge branch 'master' into update-dataloaderdocs

Erik Wrede 2022-08-13 15:11:12 +02:00 committed by GitHub
commit 8e1c3d3102
103 changed files with 2366 additions and 652 deletions

.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file

@@ -0,0 +1,34 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: "\U0001F41B bug"
assignees: ''
---
**Note: for support questions, please use stackoverflow**. This repository's issues are reserved for feature requests and bug reports.
* **What is the current behavior?**
* **If the current behavior is a bug, please provide the steps to reproduce and if possible a minimal demo of the problem** via
a github repo, https://repl.it or similar.
* **What is the expected behavior?**
* **What is the motivation / use case for changing the behavior?**
* **Please tell us about your environment:**
- Version:
- Platform:
* **Other information** (e.g. detailed explanation, stacktraces, related issues, suggestions how to fix, links for us to have context, eg. stackoverflow)

.github/ISSUE_TEMPLATE/config.yml vendored Normal file

@@ -0,0 +1 @@
blank_issues_enabled: false


@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: "✨ enhancement"
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

.github/stale.yml vendored

@@ -1,7 +1,7 @@
 # Number of days of inactivity before an issue becomes stale
-daysUntilStale: 90
+daysUntilStale: false
 # Number of days of inactivity before a stale issue is closed
-daysUntilClose: 14
+daysUntilClose: false
 # Issues with these labels will never be considered stale
 exemptLabels:
   - pinned
@@ -15,9 +15,10 @@ exemptLabels:
 # Label to use when marking an issue as stale
 staleLabel: wontfix
 # Comment to post when marking an issue as stale. Set to `false` to disable
-markComment: >
-  This issue has been automatically marked as stale because it has not had
-  recent activity. It will be closed if no further activity occurs. Thank you
-  for your contributions.
+markComment: false
+# markComment: >
+#   This issue has been automatically marked as stale because it has not had
+#   recent activity. It will be closed if no further activity occurs. Thank you
+#   for your contributions.
 # Comment to post when closing a stale issue. Set to `false` to disable
 closeComment: false

.github/workflows/coveralls.yml vendored Normal file

@@ -0,0 +1,25 @@
name: 📊 Check Coverage
on:
  push:
    branches:
      - master
      - '*.x'
    paths-ignore:
      - 'docs/**'
      - '*.md'
      - '*.rst'
  pull_request:
    branches:
      - master
      - '*.x'
    paths-ignore:
      - 'docs/**'
      - '*.md'
      - '*.rst'

jobs:
  coveralls_finish:
    # check coverage increase/decrease
    runs-on: ubuntu-latest
    steps:
      - name: Coveralls Finished
        uses: AndreMiras/coveralls-python-action@develop

.github/workflows/deploy.yml vendored Normal file

@@ -0,0 +1,26 @@
name: 🚀 Deploy to PyPI

on:
  push:
    tags:
      - 'v*'

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.9
        uses: actions/setup-python@v2
        with:
          python-version: 3.9
      - name: Build wheel and source tarball
        run: |
          pip install wheel
          python setup.py sdist bdist_wheel
      - name: Publish a Python distribution to PyPI
        uses: pypa/gh-action-pypi-publish@v1.1.0
        with:
          user: __token__
          password: ${{ secrets.pypi_password }}

.github/workflows/lint.yml vendored Normal file

@@ -0,0 +1,26 @@
name: 💅 Lint

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.9
        uses: actions/setup-python@v2
        with:
          python-version: 3.9
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tox
      - name: Run lint
        run: tox
        env:
          TOXENV: pre-commit
      - name: Run mypy
        run: tox
        env:
          TOXENV: mypy

.github/workflows/tests.yml vendored Normal file

@@ -0,0 +1,66 @@
name: 📄 Tests

on:
  push:
    branches:
      - master
      - '*.x'
    paths-ignore:
      - 'docs/**'
      - '*.md'
      - '*.rst'
  pull_request:
    branches:
      - master
      - '*.x'
    paths-ignore:
      - 'docs/**'
      - '*.md'
      - '*.rst'

jobs:
  tests:
    # runs the test suite
    name: ${{ matrix.name }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        include:
          - {name: '3.10', python: '3.10', os: ubuntu-latest, tox: py310}
          - {name: '3.9', python: '3.9', os: ubuntu-latest, tox: py39}
          - {name: '3.8', python: '3.8', os: ubuntu-latest, tox: py38}
          - {name: '3.7', python: '3.7', os: ubuntu-latest, tox: py37}
          - {name: '3.6', python: '3.6', os: ubuntu-latest, tox: py36}

    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python }}
      - name: update pip
        run: |
          pip install -U wheel
          pip install -U setuptools
          python -m pip install -U pip
      - name: get pip cache dir
        id: pip-cache
        run: echo "::set-output name=dir::$(pip cache dir)"
      - name: cache pip dependencies
        uses: actions/cache@v3
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: pip|${{ runner.os }}|${{ matrix.python }}|${{ hashFiles('setup.py') }}
      - run: pip install tox
      - run: tox -e ${{ matrix.tox }}
      - name: Upload coverage.xml
        if: ${{ matrix.python == '3.10' }}
        uses: actions/upload-artifact@v3
        with:
          name: graphene-sqlalchemy-coverage
          path: coverage.xml
          if-no-files-found: error
      - name: Upload coverage.xml to codecov
        if: ${{ matrix.python == '3.10' }}
        uses: codecov/codecov-action@v3


@@ -1,6 +1,9 @@
+default_language_version:
+  python: python3.9
+
 repos:
-  - repo: git://github.com/pre-commit/pre-commit-hooks
-    rev: v2.1.0
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.2.0
     hooks:
       - id: check-merge-conflict
       - id: check-json
@@ -14,15 +17,14 @@ repos:
       - id: trailing-whitespace
        exclude: README.md
   - repo: https://github.com/asottile/pyupgrade
-    rev: v1.12.0
+    rev: v2.32.1
     hooks:
       - id: pyupgrade
   - repo: https://github.com/ambv/black
-    rev: 19.10b0
+    rev: 22.3.0
     hooks:
       - id: black
-        language_version: python3
   - repo: https://github.com/PyCQA/flake8
-    rev: 3.7.8
+    rev: 4.0.1
     hooks:
       - id: flake8


@ -1,42 +0,0 @@
language: python
dist: xenial
python:
- "3.6"
- "3.7"
- "3.8"
install:
- pip install tox tox-travis
script: tox
after_success:
- pip install coveralls
- coveralls
cache:
directories:
- $HOME/.cache/pip
- $HOME/.cache/pre-commit
stages:
- test
- name: deploy
if: tag IS present
jobs:
fast_finish: true
include:
- env: TOXENV=pre-commit
python: 3.7
- env: TOXENV=mypy
python: 3.7
- stage: deploy
python: 3.7
after_success: true
deploy:
provider: pypi
user: syrusakbary
on:
tags: true
password:
secure: LHOp9DvYR+70vj4YVY8+JRNCKUOfYZREEUY3+4lMUpY7Zy5QwDfgEMXG64ybREH9dFldpUqVXRj53eeU3spfudSfh8NHkgqW7qihez2AhSnRc4dK6ooNfB+kLcSoJ4nUFGxdYImABc4V1hJvflGaUkTwDNYVxJF938bPaO797IvSbuI86llwqkvuK2Vegv9q/fy9sVGaF9VZIs4JgXwR5AyDR7FBArl+S84vWww4vTFD33hoE88VR4QvFY3/71BwRtQrnCMm7AOm31P9u29yi3bpzQpiOR2rHsgrsYdm597QzFKVxYwsmf9uAx2bpbSPy2WibunLePIvOFwm8xcfwnz4/J4ONBc5PSFmUytTWpzEnxb0bfUNLuYloIS24V6OZ8BfAhiYZ1AwySeJCQDM4Vk1V8IF6trTtyx5EW/uV9jsHCZ3LFsAD7UnFRTosIgN3SAK3ZWCEk5oF2IvjecsolEfkRXB3q9EjMkkuXRUeFDH2lWJLgNE27BzY6myvZVzPmfwZUsPBlPD/6w+WLSp97Rjgr9zS3T1d4ddqFM4ZYu04f2i7a/UUQqG+itzzuX5DWLPvzuNt37JB45mB9IsvxPyXZ6SkAcLl48NGyKok1f3vQnvphkfkl4lni29woKhaau8xlsuEDrcwOoeAsVcZXiItg+l+z2SlIwM0A06EvQ=
distributions: "sdist bdist_wheel"


@@ -1,3 +0,0 @@
* @ekampf @dan98765 @projectcheshire @jkimbo
/docs/ @dvndrsn @phalt @changeling
/examples/ @dvndrsn @phalt @changeling


@@ -8,7 +8,7 @@ install-dev:
 	pip install -e ".[dev]"

 test:
-	py.test graphene examples tests_asyncio
+	py.test graphene examples

 .PHONY: docs ## Generate docs
 docs: install-dev
@@ -20,8 +20,8 @@ docs-live: install-dev
 .PHONY: format
 format:
-	black graphene examples setup.py tests_asyncio
+	black graphene examples setup.py

 .PHONY: lint
 lint:
-	flake8 graphene examples setup.py tests_asyncio
+	flake8 graphene examples setup.py


@@ -4,12 +4,6 @@
 **We are looking for contributors**! Please check the [ROADMAP](https://github.com/graphql-python/graphene/blob/master/ROADMAP.md) to see how you can help ❤️

----
-
-**The below readme is the documentation for the `dev` (prerelease) version of Graphene. To view the documentation for the latest stable Graphene version go to the [v2 docs](https://docs.graphene-python.org/en/stable/)**
-
----
-
 ## Introduction

 [Graphene](http://graphene-python.org) is an opinionated Python library for building GraphQL schemas/types fast and easily.
@@ -37,7 +31,7 @@ Also, Graphene is fully compatible with the GraphQL spec, working seamlessly wit
 For instaling graphene, just run this command in your shell

 ```bash
-pip install "graphene>=2.0"
+pip install "graphene>=3.0"
 ```

 ## Examples


@@ -1,18 +1,18 @@
+|Graphene Logo| `Graphene <http://graphene-python.org>`__ |Build Status| |PyPI version| |Coverage Status|
+=========================================================================================================
+
+`💬 Join the community on
+Slack <https://join.slack.com/t/graphenetools/shared_invite/enQtOTE2MDQ1NTg4MDM1LTA4Nzk0MGU0NGEwNzUxZGNjNDQ4ZjAwNDJjMjY0OGE1ZDgxZTg4YjM2ZTc4MjE2ZTAzZjE2ZThhZTQzZTkyMmM>`__
+
 **We are looking for contributors**! Please check the
 `ROADMAP <https://github.com/graphql-python/graphene/blob/master/ROADMAP.md>`__
 to see how you can help ❤️

---------------
-
-|Graphene Logo| `Graphene <http://graphene-python.org>`__ |Build Status| |PyPI version| |Coverage Status|
-=========================================================================================================
-
 Introduction
 ------------

-`Graphene <http://graphene-python.org>`__ is a Python library for
-building GraphQL schemas/types fast and easily.
+`Graphene <http://graphene-python.org>`__ is an opinionated Python
+library for building GraphQL schemas/types fast and easily.

 - **Easy to use:** Graphene helps you use GraphQL in Python without
   effort.
@@ -27,17 +27,18 @@ Integrations

 Graphene has multiple integrations with different frameworks:

-+---------------------+----------------------------------------------------------------------------------------------+
-| integration         | Package                                                                                      |
-+=====================+==============================================================================================+
-| Django              | `graphene-django <https://github.com/graphql-python/graphene-django/>`__                    |
-+---------------------+----------------------------------------------------------------------------------------------+
-| SQLAlchemy          | `graphene-sqlalchemy <https://github.com/graphql-python/graphene-sqlalchemy/>`__            |
-+---------------------+----------------------------------------------------------------------------------------------+
-| Google App Engine   | `graphene-gae <https://github.com/graphql-python/graphene-gae/>`__                          |
-+---------------------+----------------------------------------------------------------------------------------------+
-| Peewee              | *In progress* (`Tracking Issue <https://github.com/graphql-python/graphene/issues/289>`__)  |
-+---------------------+----------------------------------------------------------------------------------------------+
++-------------------+-------------------------------------------------+
+| integration       | Package                                         |
++===================+=================================================+
+| Django            | `graphene-django <https:/                       |
+|                   | /github.com/graphql-python/graphene-django/>`__ |
++-------------------+-------------------------------------------------+
+| SQLAlchemy        | `graphene-sqlalchemy <https://git               |
+|                   | hub.com/graphql-python/graphene-sqlalchemy/>`__ |
++-------------------+-------------------------------------------------+
+| Google App Engine | `graphene-gae <http                             |
+|                   | s://github.com/graphql-python/graphene-gae/>`__ |
++-------------------+-------------------------------------------------+

 Also, Graphene is fully compatible with the GraphQL spec, working
 seamlessly with all GraphQL clients, such as
@@ -52,13 +53,7 @@ For instaling graphene, just run this command in your shell

 .. code:: bash

-    pip install "graphene>=2.0"
+    pip install "graphene>=3.0"

-2.0 Upgrade Guide
------------------
-
-Please read `UPGRADE-v2.0.md </UPGRADE-v2.0.md>`__ to learn how to
-upgrade.

 Examples
 --------
@@ -123,7 +118,7 @@ this project. While developing, run new and existing tests with:

     py.test graphene/relay # All tests in directory

 Add the ``-s`` flag if you have introduced breakpoints into the code for
-debugging. Add the ``-v`` ("verbose") flag to get more detailed test
+debugging. Add the ``-v`` (“verbose”) flag to get more detailed test
 output. For even more detailed output, use ``-vv``. Check out the
 `pytest documentation <https://docs.pytest.org/en/latest/>`__ for more
 options and test running controls.


@@ -153,7 +153,7 @@ class Query(ObjectType):
 ```

 Also, if you wanted to create an `ObjectType` that implements `Node`, you have to do it
-explicity.
+explicitly.

 ## Django


@@ -123,7 +123,7 @@ def resolve_my_field(root, info, my_arg):
     return ...
 ```

-**PS.: Take care with receiving args like `my_arg` as above. This doesn't work for optional (non-required) arguments as stantard `Connection`'s arguments (first, before, after, before).**
+**PS.: Take care with receiving args like `my_arg` as above. This doesn't work for optional (non-required) arguments as standard `Connection`'s arguments (first, last, after, before).**

 You may need something like this:

 ```python


@@ -64,6 +64,8 @@ Graphene Scalars
 .. autoclass:: graphene.JSONString()

+.. autoclass:: graphene.Base64()
+
 Enum
 ----


@ -64,18 +64,18 @@ source_suffix = ".rst"
master_doc = "index" master_doc = "index"
# General information about the project. # General information about the project.
project = u"Graphene" project = "Graphene"
copyright = u"Graphene 2016" copyright = "Graphene 2016"
author = u"Syrus Akbary" author = "Syrus Akbary"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = u"1.0" version = "1.0"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = u"1.0" release = "1.0"
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
@ -278,7 +278,7 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, "Graphene.tex", u"Graphene Documentation", u"Syrus Akbary", "manual") (master_doc, "Graphene.tex", "Graphene Documentation", "Syrus Akbary", "manual")
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
@ -318,7 +318,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "graphene", u"Graphene Documentation", [author], 1)] man_pages = [(master_doc, "graphene", "Graphene Documentation", [author], 1)]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
# #
@ -334,7 +334,7 @@ texinfo_documents = [
( (
master_doc, master_doc,
"Graphene", "Graphene",
u"Graphene Documentation", "Graphene Documentation",
author, author,
"Graphene", "Graphene",
"One line description of project.", "One line description of project.",


@@ -25,10 +25,10 @@ Create loaders by providing a batch loading function.

 A batch loading async function accepts a list of keys, and returns a list of ``values``.

-Then load individual values from the loader. ``DataLoader`` will coalesce all
-individual loads which occur within a single frame of execution (executed once
-the wrapping event loop is resolved) and then call your batch function with all
-requested keys.
+``DataLoader`` will coalesce all individual loads which occur within a
+single frame of execution (executed once the wrapping event loop is resolved)
+and then call your batch function with all requested keys.

 .. code:: python

@@ -95,7 +95,7 @@ Consider the following GraphQL request:

     }

-Naively, if ``me``, ``bestFriend`` and ``friends`` each need to request the backend,
+If ``me``, ``bestFriend`` and ``friends`` each need to send a request to the backend,
 there could be at most 13 database requests!
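A minimal sketch of such a batch loading function, assuming the ``aiodataloader`` package and a hypothetical ``get_users`` helper, might look like:

.. code:: python

    from aiodataloader import DataLoader

    class UserLoader(DataLoader):
        async def batch_load_fn(self, keys):
            # get_users is a hypothetical helper that fetches all
            # requested users from the backend in a single query.
            users = {user.id: user for user in await get_users(keys)}
            # The returned list must match the order and length of `keys`.
            return [users.get(key) for key in keys]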


@@ -3,7 +3,6 @@

 Executing a query
 =================

-
 For executing a query against a schema, you can directly call the ``execute`` method on it.
@@ -17,43 +16,6 @@ For executing a query against a schema, you can directly call the ``execute`` me

 ``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred.

-For executing a subscription, you can directly call the ``subscribe`` method on it.
-This method is async and must be awaited.
-
-.. code:: python
-
-    import asyncio
-    from datetime import datetime
-    from graphene import ObjectType, String, Schema, Field
-
-    # All schema require a query.
-    class Query(ObjectType):
-        hello = String()
-
-        def resolve_hello(root, info):
-            return 'Hello, world!'
-
-    class Subscription(ObjectType):
-        time_of_day = Field(String)
-
-        async def subscribe_time_of_day(root, info):
-            while True:
-                yield { 'time_of_day': datetime.now().isoformat()}
-                await asyncio.sleep(1)
-
-    SCHEMA = Schema(query=Query, subscription=Subscription)
-
-    async def main(schema):
-        subscription = 'subscription { timeOfDay }'
-        result = await schema.subscribe(subscription)
-        async for item in result:
-            print(item.data['timeOfDay'])
-
-    asyncio.run(main(SCHEMA))
-
-The ``result`` is an async iterator which yields items in the same manner as a query.
-
 .. _SchemaExecuteContext:

 Context
@@ -123,7 +85,7 @@ Value used for :ref:`ResolverParamParent` in root queries and mutations can be o

             return {'id': root.id, 'firstName': root.name}

     schema = Schema(Query)
-    user_root = User(id=12, name='bob'}
+    user_root = User(id=12, name='bob')
     result = schema.execute(
         '''
         query getUser {
@@ -148,7 +110,7 @@ If there are multiple operations defined in a query string, ``operation_name`` s

     from graphene import ObjectType, Field, Schema

     class Query(ObjectType):
-        me = Field(User)
+        user = Field(User)

         def resolve_user(root, info):
             return get_user_by_id(12)


@@ -4,5 +4,5 @@ File uploading
 File uploading is not part of the official GraphQL spec yet and is not natively
 implemented in Graphene.

-If your server needs to support file uploading then you can use the libary: `graphene-file-upload <https://github.com/lmcgartland/graphene-file-upload>`_ which enhances Graphene to add file
+If your server needs to support file uploading then you can use the library: `graphene-file-upload <https://github.com/lmcgartland/graphene-file-upload>`_ which enhances Graphene to add file
 uploads and conforms to the unoffical GraphQL `multipart request spec <https://github.com/jaydenseric/graphql-multipart-request-spec>`_.


@@ -9,3 +9,5 @@ Execution
    middleware
    dataloader
    fileuploading
+   subscriptions
+   queryvalidation


@@ -46,7 +46,7 @@ Functional example
 ------------------

 Middleware can also be defined as a function. Here we define a middleware that
-logs the time it takes to resolve each field
+logs the time it takes to resolve each field:

 .. code:: python
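The middleware body itself falls outside this hunk; a minimal sketch of such a timing middleware, assuming graphene's functional middleware signature of ``(next, root, info, **args)``, might look like:

.. code:: python

    from time import time as timer

    def timing_middleware(next, root, info, **args):
        start = timer()
        return_value = next(root, info, **args)
        duration = round((timer() - start) * 1000, 2)
        # Log which field was resolved and how long it took.
        print(f"{info.field_name}: {duration} ms")
        return return_value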


@@ -0,0 +1,123 @@
Query Validation
================

GraphQL uses query validators to check if a query AST is valid and can be executed. Every GraphQL server implements
standard query validators. For example, there is a validator that tests whether a queried field exists on the queried
type, failing the query with a "Cannot query field on type" error if it doesn't.

To help with common use cases, graphene provides a few validation rules out of the box.

Depth limit Validator
---------------------

The depth limit validator helps to prevent execution of malicious
queries. It takes the following arguments:

- ``max_depth`` is the maximum allowed depth for any operation in a GraphQL document.
- ``ignore`` stops recursive depth checking based on a field name. Either a string or regexp to match the name, or a function that returns a boolean.
- ``callback`` is called each time validation runs. Receives an Object which is a map of the depths for each operation.

Usage
-----

Here is how you would implement depth-limiting on your schema.

.. code:: python

    from graphql import validate, parse
    from graphene import ObjectType, Schema, String
    from graphene.validation import depth_limit_validator


    class MyQuery(ObjectType):
        name = String(required=True)


    schema = Schema(query=MyQuery)

    # Queries which have a depth more than 20
    # will not be executed.
    validation_errors = validate(
        schema=schema.graphql_schema,
        document_ast=parse('THE QUERY'),
        rules=(
            depth_limit_validator(
                max_depth=20
            ),
        )
    )

Disable Introspection
---------------------

The disable introspection validation rule ensures that your schema cannot be introspected.
This is a useful security measure in production environments.

Usage
-----

Here is how you would disable introspection for your schema.

.. code:: python

    from graphql import validate, parse
    from graphene import ObjectType, Schema, String
    from graphene.validation import DisableIntrospection


    class MyQuery(ObjectType):
        name = String(required=True)


    schema = Schema(query=MyQuery)

    # Introspection queries will not be executed.
    validation_errors = validate(
        schema=schema.graphql_schema,
        document_ast=parse('THE QUERY'),
        rules=(
            DisableIntrospection,
        )
    )

Implementing custom validators
------------------------------

All custom query validators should extend the `ValidationRule <https://github.com/graphql-python/graphql-core/blob/v3.0.5/src/graphql/validation/rules/__init__.py#L37>`_
base class importable from the ``graphql.validation.rules`` module. Query validators are visitor classes. They are
instantiated at the time of query validation with one required argument (``context: ASTValidationContext``). In order to
perform validation, your validator class should define one or more of ``enter_*`` and ``leave_*`` methods. For possible
enter/leave items as well as details on function documentation, please see the contents of the visitor module. To make
validation fail, you should call the validator's ``report_error`` method with an instance of ``GraphQLError`` describing the
failure reason. Here is an example query validator that visits field definitions in a GraphQL query and fails query validation
if any of those fields are blacklisted:

.. code:: python

    from graphql import GraphQLError
    from graphql.language import FieldNode
    from graphql.validation import ValidationRule


    my_blacklist = (
        "disallowed_field",
    )


    def is_blacklisted_field(field_name: str):
        return field_name.lower() in my_blacklist


    class BlackListRule(ValidationRule):
        def enter_field(self, node: FieldNode, *_args):
            field_name = node.name.value
            if not is_blacklisted_field(field_name):
                return

            self.report_error(
                GraphQLError(
                    f"Cannot query '{field_name}': field is blacklisted.", node,
                )
            )
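Mirroring the ``validate()`` calls above, a custom rule such as ``BlackListRule`` can then be applied the same way (a sketch, assuming ``schema``, ``validate`` and ``parse`` as in the earlier examples):

.. code:: python

    # BlackListRule is passed like any built-in validation rule.
    validation_errors = validate(
        schema=schema.graphql_schema,
        document_ast=parse('THE QUERY'),
        rules=(BlackListRule,)
    )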


@@ -0,0 +1,40 @@
.. _SchemaSubscription:

Subscriptions
=============

To create a subscription, you can directly call the ``subscribe`` method on the
schema. This method is async and must be awaited.

.. code:: python

    import asyncio
    from datetime import datetime
    from graphene import ObjectType, String, Schema, Field

    # Every schema requires a query.
    class Query(ObjectType):
        hello = String()

        def resolve_hello(root, info):
            return "Hello, world!"

    class Subscription(ObjectType):
        time_of_day = String()

        async def subscribe_time_of_day(root, info):
            while True:
                yield datetime.now().isoformat()
                await asyncio.sleep(1)

    schema = Schema(query=Query, subscription=Subscription)

    async def main(schema):
        subscription = 'subscription { timeOfDay }'
        result = await schema.subscribe(subscription)
        async for item in result:
            print(item.data['timeOfDay'])

    asyncio.run(main(schema))

The ``result`` is an async iterator which yields items in the same manner as a query.


@@ -60,14 +60,14 @@ Requirements
 ~~~~~~~~~~~~

 - Python (2.7, 3.4, 3.5, 3.6, pypy)
-- Graphene (2.0)
+- Graphene (3.0)

 Project setup
 ~~~~~~~~~~~~~

 .. code:: bash

-    pip install "graphene>=2.0"
+    pip install "graphene>=3.0"

 Creating a basic Schema
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -103,7 +103,7 @@ For each **Field** in our **Schema**, we write a **Resolver** method to fetch da
 Schema Definition Language (SDL)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-In the `GraphQL Schema Definition Language`_, we could describe the fields defined by our example code as show below.
+In the `GraphQL Schema Definition Language`_, we could describe the fields defined by our example code as shown below.

 .. _GraphQL Schema Definition Language: https://graphql.org/learn/schema/


@@ -19,11 +19,8 @@ Useful links
 - `Getting started with Relay`_
 - `Relay Global Identification Specification`_
 - `Relay Cursor Connection Specification`_
-- `Relay input Object Mutation`_

-.. _Relay: https://facebook.github.io/relay/docs/en/graphql-server-specification.html
-.. _Relay specification: https://facebook.github.io/relay/graphql/objectidentification.htm#sec-Node-root-field
-.. _Getting started with Relay: https://facebook.github.io/relay/docs/en/quick-start-guide.html
-.. _Relay Global Identification Specification: https://facebook.github.io/relay/graphql/objectidentification.htm
-.. _Relay Cursor Connection Specification: https://facebook.github.io/relay/graphql/connections.htm
-.. _Relay input Object Mutation: https://facebook.github.io/relay/graphql/mutations.htm
+.. _Relay: https://relay.dev/docs/guides/graphql-server-specification/
+.. _Getting started with Relay: https://relay.dev/docs/getting-started/step-by-step-guide/
+.. _Relay Global Identification Specification: https://relay.dev/graphql/objectidentification.htm
+.. _Relay Cursor Connection Specification: https://relay.dev/graphql/connections.htm


@@ -51,20 +51,20 @@ Example of a custom node:
             name = 'Node'

             @staticmethod
-            def to_global_id(type, id):
-                return f"{type}:{id}"
+            def to_global_id(type_, id):
+                return f"{type_}:{id}"

             @staticmethod
             def get_node_from_global_id(info, global_id, only_type=None):
-                type, id = global_id.split(':')
+                type_, id = global_id.split(':')
                 if only_type:
                     # We assure that the node type that we want to retrieve
                     # is the same that was indicated in the field type
-                    assert type == only_type._meta.name, 'Received not compatible node.'
+                    assert type_ == only_type._meta.name, 'Received not compatible node.'

-                if type == 'User':
+                if type_ == 'User':
                     return get_user(id)
-                elif type == 'Photo':
+                elif type_ == 'Photo':
                     return get_photo(id)


@@ -77,13 +77,13 @@ Snapshot testing
 As our APIs evolve, we need to know when our changes introduce any breaking changes that might break
 some of the clients of our GraphQL app.

-However, writing tests and replicate the same response we expect from our GraphQL application can be
-tedious and repetitive task, and sometimes it's easier to skip this process.
+However, writing tests and replicating the same response we expect from our GraphQL application can be a
+tedious and repetitive task, and sometimes it's easier to skip this process.

 Because of that, we recommend the usage of `SnapshotTest <https://github.com/syrusakbary/snapshottest/>`_.

-SnapshotTest let us write all this tests in a breeze, as creates automatically the ``snapshots`` for us
-the first time the test is executed.
+SnapshotTest lets us write all these tests in a breeze, as it automatically creates the ``snapshots`` for us
+the first time the tests are executed.

 Here is a simple example on how our tests will look if we use ``pytest``:
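The example itself is outside this hunk; a minimal sketch of such a pytest snapshot test, assuming ``graphene.test.Client``, the ``snapshot`` fixture provided by SnapshotTest, and a hypothetical ``my_schema``, might look like:

.. code:: python

    from graphene.test import Client

    def test_hey(snapshot):
        client = Client(my_schema)  # my_schema is a hypothetical Schema
        # The first run records the response as a snapshot;
        # later runs fail if the response no longer matches it.
        snapshot.assert_match(client.execute('{ hey }'))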


@@ -61,7 +61,8 @@ you can add description etc. to your enum without changing the original:
     graphene.Enum.from_enum(
         AlreadyExistingPyEnum,
-        description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar')
+        description=lambda v: return 'foo' if v == AlreadyExistingPyEnum.Foo else 'bar'
+    )

 Notes
@@ -76,6 +77,7 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
 .. code:: python

     from enum import Enum
+
     class Color(Enum):
        RED = 1
        GREEN = 2
@@ -84,11 +86,12 @@ In the Python ``Enum`` implementation you can access a member by initing the Enu
     assert Color(1) == Color.RED

-However, in Graphene ``Enum`` you need to call get to have the same effect:
+However, in Graphene ``Enum`` you need to call `.get` to have the same effect:

 .. code:: python

     from graphene import Enum
+
     class Color(Enum):
        RED = 1
        GREEN = 2


@@ -44,7 +44,7 @@ Both of these types have all of the fields from the ``Character`` interface,
 but also bring in extra fields, ``home_planet``, ``starships`` and
 ``primary_function``, that are specific to that particular type of character.

-The full GraphQL schema defition will look like this:
+The full GraphQL schema definition will look like this:

 .. code::


@@ -85,9 +85,9 @@ We should receive:
 InputFields and InputObjectTypes
 ----------------------------------

-InputFields are used in mutations to allow nested input data for mutations
+InputFields are used in mutations to allow nested input data for mutations.

-To use an InputField you define an InputObjectType that specifies the structure of your input data
+To use an InputField you define an InputObjectType that specifies the structure of your input data:

 .. code:: python

@@ -104,7 +104,6 @@ To use an InputField you define an InputObjectType that specifies the structure
         person = graphene.Field(Person)

-        @staticmethod
         def mutate(root, info, person_data=None):
             person = Person(
                 name=person_data.name,
@@ -113,7 +112,7 @@ To use an InputField you define an InputObjectType that specifies the structure
             return CreatePerson(person=person)

-Note that **name** and **age** are part of **person_data** now
+Note that **name** and **age** are part of **person_data** now.

 Using the above mutation your new query would look like this:

@@ -129,7 +128,7 @@ Using the above mutation your new query would look like this:
     }

 InputObjectTypes can also be fields of InputObjectTypes allowing you to have
-as complex of input data as you need
+as complex of input data as you need:

 .. code:: python

@@ -161,7 +160,7 @@ To return an existing ObjectType instead of a mutation-specific type, set the **
         def mutate(root, info, name):
             return Person(name=name)

-Then, if we query (``schema.execute(query_str)``) the following:
+Then, if we query (``schema.execute(query_str)``) with the following:

 .. code::


@@ -102,7 +102,7 @@ When we execute a query against that schema.
     query_string = "{ me { fullName } }"
     result = schema.execute(query_string)

-    assert result.data["me"] == {"fullName": "Luke Skywalker")
+    assert result.data["me"] == {"fullName": "Luke Skywalker"}

 Then we go through the following steps to resolve this query:


@@ -3,6 +3,11 @@
 Scalars
 =======

+Scalar types represent concrete values at the leaves of a query. There are
+several built in types that Graphene provides out of the box which represent common
+values in Python. You can also create your own Scalar types to better express
+values that you might have in your data model.
+
 All Scalar types accept the following arguments. All are optional:

 ``name``: *string*
@@ -27,34 +32,39 @@ All Scalar types accept the following arguments. All are optional:

-Base scalars
-------------
+Built in scalars
+----------------

-Graphene defines the following base Scalar Types:
+Graphene defines the following base Scalar Types that match the default `GraphQL types <https://graphql.org/learn/schema/#scalar-types>`_:

 ``graphene.String``
+^^^^^^^^^^^^^^^^^^^

     Represents textual data, represented as UTF-8
     character sequences. The String type is most often used by GraphQL to
     represent free-form human-readable text.

 ``graphene.Int``
+^^^^^^^^^^^^^^^^

     Represents non-fractional signed whole numeric
     values. Int is a signed 32bit integer per the
     `GraphQL spec <https://facebook.github.io/graphql/June2018/#sec-Int>`_

 ``graphene.Float``
+^^^^^^^^^^^^^^^^^^

     Represents signed double-precision fractional
     values as specified by
     `IEEE 754 <http://en.wikipedia.org/wiki/IEEE_floating_point>`_.

 ``graphene.Boolean``
+^^^^^^^^^^^^^^^^^^^^

     Represents `true` or `false`.

 ``graphene.ID``
+^^^^^^^^^^^^^^^

     Represents a unique identifier, often used to
     refetch an object or as key for a cache. The ID type appears in a JSON
@@ -62,24 +72,183 @@ Graphene defines the following base Scalar Types:
 When expected as an input type, any string (such as `"4"`) or integer
 (such as `4`) input value will be accepted as an ID.

-Graphene also provides custom scalars for Dates, Times, and JSON:
+----
+
+Graphene also provides custom scalars for common values:

-``graphene.types.datetime.Date``
+``graphene.Date``
+^^^^^^^^^^^^^^^^^

     Represents a Date value as specified by `iso8601 <https://en.wikipedia.org/wiki/ISO_8601>`_.

+.. code:: python
+
+    import datetime
+    from graphene import Schema, ObjectType, Date
+
+    class Query(ObjectType):
+        one_week_from = Date(required=True, date_input=Date(required=True))
+
+        def resolve_one_week_from(root, info, date_input):
+            assert date_input == datetime.date(2006, 1, 2)
+            return date_input + datetime.timedelta(weeks=1)
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            oneWeekFrom(dateInput: "2006-01-02")
+        }
+    """)
+
+    assert results.data == {"oneWeekFrom": "2006-01-09"}
+
-``graphene.types.datetime.DateTime``
+``graphene.DateTime``
+^^^^^^^^^^^^^^^^^^^^^

     Represents a DateTime value as specified by `iso8601 <https://en.wikipedia.org/wiki/ISO_8601>`_.

+.. code:: python
+
+    import datetime
+    from graphene import Schema, ObjectType, DateTime
+
+    class Query(ObjectType):
+        one_hour_from = DateTime(required=True, datetime_input=DateTime(required=True))
+
+        def resolve_one_hour_from(root, info, datetime_input):
+            assert datetime_input == datetime.datetime(2006, 1, 2, 15, 4, 5)
+            return datetime_input + datetime.timedelta(hours=1)
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            oneHourFrom(datetimeInput: "2006-01-02T15:04:05")
+        }
+    """)
+
+    assert results.data == {"oneHourFrom": "2006-01-02T16:04:05"}
+
-``graphene.types.datetime.Time``
+``graphene.Time``
+^^^^^^^^^^^^^^^^^

     Represents a Time value as specified by `iso8601 <https://en.wikipedia.org/wiki/ISO_8601>`_.

+.. code:: python
+
+    import datetime
+    from graphene import Schema, ObjectType, Time
+
+    class Query(ObjectType):
+        one_hour_from = Time(required=True, time_input=Time(required=True))
+
+        def resolve_one_hour_from(root, info, time_input):
+            assert time_input == datetime.time(15, 4, 5)
+            tmp_time_input = datetime.datetime.combine(datetime.date(1, 1, 1), time_input)
+            return (tmp_time_input + datetime.timedelta(hours=1)).time()
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            oneHourFrom(timeInput: "15:04:05")
+        }
+    """)
+
+    assert results.data == {"oneHourFrom": "16:04:05"}
+
+``graphene.Decimal``
+^^^^^^^^^^^^^^^^^^^^
+
+    Represents a Python Decimal value.
+
+.. code:: python
+
+    import decimal
+    from graphene import Schema, ObjectType, Decimal
+
+    class Query(ObjectType):
+        add_one_to = Decimal(required=True, decimal_input=Decimal(required=True))
+
+        def resolve_add_one_to(root, info, decimal_input):
+            assert decimal_input == decimal.Decimal("10.50")
+            return decimal_input + decimal.Decimal("1")
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            addOneTo(decimalInput: "10.50")
+        }
+    """)
+
+    assert results.data == {"addOneTo": "11.50"}
+
-``graphene.types.json.JSONString``
+``graphene.JSONString``
+^^^^^^^^^^^^^^^^^^^^^^^

     Represents a JSON string.

+.. code:: python
+
+    from graphene import Schema, ObjectType, JSONString, String
+
+    class Query(ObjectType):
+        update_json_key = JSONString(
+            required=True,
+            json_input=JSONString(required=True),
+            key=String(required=True),
+            value=String(required=True)
+        )
+
+        def resolve_update_json_key(root, info, json_input, key, value):
+            assert json_input == {"name": "Jane"}
+            json_input[key] = value
+            return json_input
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            updateJsonKey(jsonInput: "{\\"name\\": \\"Jane\\"}", key: "name", value: "Beth")
+        }
+    """)
+
+    assert results.data == {"updateJsonKey": "{\"name\": \"Beth\"}"}
+
+``graphene.Base64``
+^^^^^^^^^^^^^^^^^^^
+
+    Represents a Base64 encoded string.
+
+.. code:: python
+
+    from graphene import Schema, ObjectType, Base64
+
+    class Query(ObjectType):
+        increment_encoded_id = Base64(
+            required=True,
+            base64_input=Base64(required=True),
+        )
+
+        def resolve_increment_encoded_id(root, info, base64_input):
+            assert base64_input == "4"
+            return int(base64_input) + 1
+
+    schema = Schema(query=Query)
+
+    results = schema.execute("""
+        query {
+            incrementEncodedId(base64Input: "NA==")
+        }
+    """)
+
+    assert results.data == {"incrementEncodedId": "NQ=="}
+
 Custom scalars
 --------------


@@ -1,11 +1,11 @@
 Schema
 ======

-A GraphQL **Schema** defines the types and relationship between **Fields** in your API.
+A GraphQL **Schema** defines the types and relationships between **Fields** in your API.

 A Schema is created by supplying the root :ref:`ObjectType` of each operation, query (mandatory), mutation and subscription.

-Schema will collect all type definitions related to the root operations and then supplied to the validator and executor.
+Schema will collect all type definitions related to the root operations and then supply them to the validator and executor.

 .. code:: python

@@ -15,11 +15,11 @@
         subscription=MyRootSubscription
     )

-A Root Query is just a special :ref:`ObjectType` that :ref:`defines the fields <Scalars>` that are the entrypoint for your API. Root Mutation and Root Subscription are similar to Root Query, but for different operation types:
+A Root Query is just a special :ref:`ObjectType` that defines the fields that are the entrypoint for your API. Root Mutation and Root Subscription are similar to Root Query, but for different operation types:

 * Query fetches data
-* Mutation to changes data and retrieve the changes
-* Subscription to sends changes to clients in real time
+* Mutation changes data and retrieves the changes
+* Subscription sends changes to clients in real-time

 Review the `GraphQL documentation on Schema`_ for a brief overview of fields, schema and operations.

@@ -44,7 +44,7 @@ There are some cases where the schema cannot access all of the types that we pla
 For example, when a field returns an ``Interface``, the schema doesn't know about any of the
 implementations.

-In this case, we need to use the ``types`` argument when creating the Schema.
+In this case, we need to use the ``types`` argument when creating the Schema:

 .. code:: python
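The code block is elided from this hunk; a sketch of passing ``types``, with hypothetical ``MyRootQuery`` and ``SomeExtraObjectType`` names, might look like:

.. code:: python

    # SomeExtraObjectType implements an Interface returned by a field,
    # so the schema cannot discover it on its own.
    my_schema = Schema(
        query=MyRootQuery,
        types=[SomeExtraObjectType]
    )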
@@ -56,14 +56,14 @@
 .. _SchemaAutoCamelCase:

-Auto CamelCase field names
+Auto camelCase field names
 --------------------------

 By default all field and argument names (that are not
 explicitly set with the ``name`` arg) will be converted from
 ``snake_case`` to ``camelCase`` (as the API is usually being consumed by a js/mobile client)

-For example with the ObjectType
+For example with the ObjectType the ``last_name`` field name is converted to ``lastName``:

 .. code:: python

@@ -71,12 +71,10 @@ For example with the ObjectType
         last_name = graphene.String()
         other_name = graphene.String(name='_other_Name')

-the ``last_name`` field name is converted to ``lastName``.
-
 In case you don't want to apply this transformation, provide a ``name`` argument to the field constructor.
 ``other_name`` converts to ``_other_Name`` (without further transformations).

-Your query should look like
+Your query should look like:

 .. code::

@@ -86,7 +84,7 @@ Your query should look like
     }

-To disable this behavior, set the ``auto_camelcase`` to ``False`` upon schema instantiation.
+To disable this behavior, set the ``auto_camelcase`` to ``False`` upon schema instantiation:

 .. code:: python
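The code block is likewise elided here; presumably it passes the flag along these lines (``MyRootQuery`` hypothetical):

.. code:: python

    # Field names are kept exactly as defined, e.g. last_name stays last_name.
    my_schema = Schema(
        query=MyRootQuery,
        auto_camelcase=False
    )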


@@ -7,7 +7,7 @@ to specify any common fields between the types.
 The basics:

 - Each Union is a Python class that inherits from ``graphene.Union``.
-- Unions don't have any fields on it, just links to the possible objecttypes.
+- Unions don't have any fields on it, just links to the possible ObjectTypes.

 Quick example
 -------------


@@ -49,7 +49,7 @@ type Faction implements Node {
   name: String

   """The ships used by the faction."""
-  ships(before: String = null, after: String = null, first: Int = null, last: Int = null): ShipConnection
+  ships(before: String, after: String, first: Int, last: Int): ShipConnection
 }

 """An object with an ID"""
@@ -115,5 +115,4 @@ input IntroduceShipInput {
   shipName: String!
   factionId: String!
   clientMutationId: String
-}
-'''
+}'''


@@ -9,7 +9,7 @@ client = Client(schema)

 def test_str_schema(snapshot):
-    snapshot.assert_match(str(schema))
+    snapshot.assert_match(str(schema).strip())


 def test_correctly_fetches_id_name_rebels(snapshot):


@@ -1,88 +1,88 @@
 from .pyutils.version import get_version

-from .types import (
-    ObjectType,
-    InputObjectType,
-    Interface,
-    Mutation,
-    Field,
-    InputField,
-    Schema,
-    Scalar,
-    String,
-    ID,
-    Int,
-    Float,
-    Boolean,
-    Date,
-    DateTime,
-    Time,
-    Decimal,
-    JSONString,
-    UUID,
-    List,
-    NonNull,
-    Enum,
-    Argument,
-    Dynamic,
-    Union,
-    Context,
-    ResolveInfo,
-)
 from .relay import (
-    Node,
-    is_node,
-    GlobalID,
     ClientIDMutation,
     Connection,
     ConnectionField,
+    GlobalID,
+    Node,
     PageInfo,
+    is_node,
+)
+from .types import (
+    ID,
+    UUID,
+    Argument,
+    Base64,
+    Boolean,
+    Context,
+    Date,
+    DateTime,
+    Decimal,
+    Dynamic,
+    Enum,
+    Field,
+    Float,
+    InputField,
+    InputObjectType,
+    Int,
+    Interface,
+    JSONString,
+    List,
+    Mutation,
+    NonNull,
+    ObjectType,
+    ResolveInfo,
+    Scalar,
+    Schema,
+    String,
+    Time,
+    Union,
 )
-from .utils.resolve_only_args import resolve_only_args
 from .utils.module_loading import lazy_import
+from .utils.resolve_only_args import resolve_only_args

-VERSION = (3, 0, 0, "beta", 1)
+VERSION = (3, 1, 0, "final", 0)
 __version__ = get_version(VERSION)

 __all__ = [
     "__version__",
-    "ObjectType",
-    "InputObjectType",
-    "Interface",
-    "Mutation",
-    "Field",
-    "InputField",
-    "Schema",
-    "Scalar",
-    "String",
-    "ID",
-    "Int",
-    "Float",
-    "Enum",
-    "Boolean",
-    "Date",
-    "DateTime",
-    "Time",
-    "Decimal",
-    "JSONString",
-    "UUID",
-    "List",
-    "NonNull",
     "Argument",
-    "Dynamic",
-    "Union",
-    "resolve_only_args",
-    "Node",
-    "is_node",
-    "GlobalID",
+    "Base64",
+    "Boolean",
     "ClientIDMutation",
     "Connection",
     "ConnectionField",
-    "PageInfo",
-    "lazy_import",
     "Context",
+    "Date",
+    "DateTime",
+    "Decimal",
+    "Dynamic",
+    "Enum",
+    "Field",
+    "Float",
+    "GlobalID",
+    "ID",
+    "InputField",
+    "InputObjectType",
+    "Int",
+    "Interface",
+    "JSONString",
+    "List",
+    "Mutation",
+    "Node",
+    "NonNull",
+    "ObjectType",
+    "PageInfo",
     "ResolveInfo",
+    "Scalar",
+    "Schema",
+    "String",
+    "Time",
+    "UUID",
+    "Union",
+    "is_node",
+    "lazy_import",
+    "resolve_only_args",
 ]


@ -291,14 +291,7 @@ class Field:
class _DataclassParams: class _DataclassParams:
__slots__ = ( __slots__ = ("init", "repr", "eq", "order", "unsafe_hash", "frozen")
"init",
"repr",
"eq",
"order",
"unsafe_hash",
"frozen",
)
def __init__(self, init, repr, eq, order, unsafe_hash, frozen): def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
self.init = init self.init = init
@ -442,13 +435,11 @@ def _field_init(f, frozen, globals, self_name):
# This field does not need initialization. Signify that # This field does not need initialization. Signify that
# to the caller by returning None. # to the caller by returning None.
return None return None
# Only test this now, so that we can create variables for the # Only test this now, so that we can create variables for the
# default. However, return None to signify that we're not going # default. However, return None to signify that we're not going
# to actually do the assignment statement for InitVars. # to actually do the assignment statement for InitVars.
if f._field_type == _FIELD_INITVAR: if f._field_type == _FIELD_INITVAR:
return None return None
# Now, actually generate the field assignment. # Now, actually generate the field assignment.
return _field_assign(frozen, f.name, value, self_name) return _field_assign(frozen, f.name, value, self_name)
@ -490,7 +481,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
raise TypeError( raise TypeError(
f"non-default argument {f.name!r} " "follows default argument" f"non-default argument {f.name!r} " "follows default argument"
) )
globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY} globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY}
body_lines = [] body_lines = []
@ -500,16 +490,13 @@ def _init_fn(fields, frozen, has_post_init, self_name):
# initialization (it's a pseudo-field). Just skip it. # initialization (it's a pseudo-field). Just skip it.
if line: if line:
body_lines.append(line) body_lines.append(line)
# Does this class have a post-init function? # Does this class have a post-init function?
if has_post_init: if has_post_init:
params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR) params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR)
body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})") body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})")
# If no body lines, use 'pass'. # If no body lines, use 'pass'.
if not body_lines: if not body_lines:
body_lines = ["pass"] body_lines = ["pass"]
locals = {f"_type_{f.name}": f.type for f in fields} locals = {f"_type_{f.name}": f.type for f in fields}
return _create_fn( return _create_fn(
"__init__", "__init__",
@ -674,7 +661,6 @@ def _get_field(cls, a_name, a_type):
# This is a field in __slots__, so it has no default value. # This is a field in __slots__, so it has no default value.
default = MISSING default = MISSING
f = field(default=default) f = field(default=default)
# Only at this point do we know the name and the type. Set them. # Only at this point do we know the name and the type. Set them.
f.name = a_name f.name = a_name
f.type = a_type f.type = a_type
@ -705,7 +691,6 @@ def _get_field(cls, a_name, a_type):
and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar) and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar)
): ):
f._field_type = _FIELD_CLASSVAR f._field_type = _FIELD_CLASSVAR
# If the type is InitVar, or if it's a matching string annotation, # If the type is InitVar, or if it's a matching string annotation,
# then it's an InitVar. # then it's an InitVar.
if f._field_type is _FIELD: if f._field_type is _FIELD:
@ -717,7 +702,6 @@ def _get_field(cls, a_name, a_type):
and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar) and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar)
): ):
f._field_type = _FIELD_INITVAR f._field_type = _FIELD_INITVAR
# Validations for individual fields. This is delayed until now, # Validations for individual fields. This is delayed until now,
# instead of in the Field() constructor, since only here do we # instead of in the Field() constructor, since only here do we
# know the field name, which allows for better error reporting. # know the field name, which allows for better error reporting.
@ -731,14 +715,12 @@ def _get_field(cls, a_name, a_type):
# example, how about init=False (or really, # example, how about init=False (or really,
# init=<not-the-default-init-value>)? It makes no sense for # init=<not-the-default-init-value>)? It makes no sense for
# ClassVar and InitVar to specify init=<anything>. # ClassVar and InitVar to specify init=<anything>.
# For real fields, disallow mutable defaults for known types. # For real fields, disallow mutable defaults for known types.
if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)): if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
raise ValueError( raise ValueError(
f"mutable default {type(f.default)} for field " f"mutable default {type(f.default)} for field "
f"{f.name} is not allowed: use default_factory" f"{f.name} is not allowed: use default_factory"
) )
return f return f
@@ -827,7 +809,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
             fields[f.name] = f
             if getattr(b, _PARAMS).frozen:
                 any_frozen_base = True
-
     # Annotations that are defined in this class (not in base
     # classes). If __annotations__ isn't present, then this class
     # adds no new annotations. We use this to compute fields that are

@@ -845,7 +826,9 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
     # Now find fields in our class. While doing so, validate some
     # things, and set the default values (as class attributes) where
     # we can.
-    cls_fields = [_get_field(cls, name, type) for name, type in cls_annotations.items()]
+    cls_fields = [
+        _get_field(cls, name, type_) for name, type_ in cls_annotations.items()
+    ]
     for f in cls_fields:
         fields[f.name] = f

@@ -864,22 +847,18 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
                 delattr(cls, f.name)
             else:
                 setattr(cls, f.name, f.default)
-
     # Do we have any Field members that don't also have annotations?
     for name, value in cls.__dict__.items():
         if isinstance(value, Field) and not name in cls_annotations:
             raise TypeError(f"{name!r} is a field but has no type annotation")
-
     # Check rules that apply if we are derived from any dataclasses.
     if has_dataclass_bases:
         # Raise an exception if any of our bases are frozen, but we're not.
         if any_frozen_base and not frozen:
             raise TypeError("cannot inherit non-frozen dataclass from a " "frozen one")
-
         # Raise an exception if we're frozen, but none of our bases are.
         if not any_frozen_base and frozen:
             raise TypeError("cannot inherit frozen dataclass from a " "non-frozen one")
-
     # Remember all of the fields on our class (including bases). This
     # also marks this class as being a dataclass.
     setattr(cls, _FIELDS, fields)

@@ -898,7 +877,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
     # eq methods.
     if order and not eq:
         raise ValueError("eq must be true if order is true")
-
     if init:
         # Does this class have a post-init function?
         has_post_init = hasattr(cls, _POST_INIT_NAME)

@@ -918,7 +896,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
                 "__dataclass_self__" if "self" in fields else "self",
             ),
         )
-
     # Get the fields as a list, and include only real fields. This is
     # used in all of the following methods.
     field_list = [f for f in fields.values() if f._field_type is _FIELD]

@@ -926,7 +903,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
     if repr:
         flds = [f for f in field_list if f.repr]
         _set_new_attribute(cls, "__repr__", _repr_fn(flds))
-
     if eq:
         # Create __eq__ method. There's no need for a __ne__ method,
         # since python will call __eq__ and negate it.

@@ -936,7 +912,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
         _set_new_attribute(
             cls, "__eq__", _cmp_fn("__eq__", "==", self_tuple, other_tuple)
         )
-
     if order:
         # Create and set the ordering methods.
         flds = [f for f in field_list if f.compare]

@@ -956,7 +931,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
                     f"in class {cls.__name__}. Consider using "
                     "functools.total_ordering"
                 )
-
     if frozen:
         for fn in _frozen_get_del_attr(cls, field_list):
             if _set_new_attribute(cls, fn.__name__, fn):

@@ -964,7 +938,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
                     f"Cannot overwrite attribute {fn.__name__} "
                     f"in class {cls.__name__}"
                 )
-
     # Decide if/how we're going to create a hash function.
     hash_action = _hash_action[
         bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash

@@ -973,11 +946,9 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
         # No need to call _set_new_attribute here, since by the time
         # we're here the overwriting is unconditional.
         cls.__hash__ = hash_action(cls, field_list)
-
     if not getattr(cls, "__doc__"):
         # Create a class doc-string.
         cls.__doc__ = cls.__name__ + str(inspect.signature(cls)).replace(" -> None", "")
-
     return cls

@@ -1013,7 +984,6 @@ def dataclass(
     if _cls is None:
         # We're called with parens.
         return wrap
-
     # We're called as @dataclass without parens.
     return wrap(_cls)

@@ -1030,7 +1000,6 @@ def fields(class_or_instance):
         fields = getattr(class_or_instance, _FIELDS)
     except AttributeError:
         raise TypeError("must be called with a dataclass type or instance")
-
     # Exclude pseudo-fields. Note that fields is sorted by insertion
     # order, so the order of the tuple is as the fields were defined.
     return tuple(f for f in fields.values() if f._field_type is _FIELD)

@@ -1172,7 +1141,6 @@ def make_dataclass(
     else:
         # Copy namespace since we're going to mutate it.
         namespace = namespace.copy()
-
     # While we're looking through the field names, validate that they
     # are identifiers, are not keywords, and not duplicates.
     seen = set()

@@ -1182,23 +1150,20 @@ def make_dataclass(
             name = item
             tp = "typing.Any"
         elif len(item) == 2:
-            name, tp, = item
+            (name, tp) = item
         elif len(item) == 3:
             name, tp, spec = item
             namespace[name] = spec
         else:
             raise TypeError(f"Invalid field: {item!r}")
-
         if not isinstance(name, str) or not name.isidentifier():
             raise TypeError(f"Field names must be valid identifers: {name!r}")
         if keyword.iskeyword(name):
             raise TypeError(f"Field names must not be keywords: {name!r}")
         if name in seen:
             raise TypeError(f"Field name duplicated: {name!r}")
-
         seen.add(name)
         anns[name] = tp
-
     namespace["__annotations__"] = anns
     # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
     # of generic dataclassses.

@@ -1234,7 +1199,6 @@ def replace(obj, **changes):
     if not _is_dataclass_instance(obj):
         raise TypeError("replace() should be called on dataclass instances")
-
     # It's an error to have init=False fields in 'changes'.
     # If a field is not in 'changes', read its value from the provided obj.

@@ -1248,10 +1212,8 @@ def replace(obj, **changes):
                 "replace()"
             )
             continue
-
         if f.name not in changes:
             changes[f.name] = getattr(obj, f.name)
-
     # Create the new object, which calls __init__() and
     # __post_init__() (if defined), using all of the init fields we've
     # added and/or left in 'changes'. If there are values supplied in
@@ -19,10 +19,7 @@ def get_version(version=None):
     sub = ""
     if version[3] == "alpha" and version[4] == 0:
         git_changeset = get_git_changeset()
-        if git_changeset:
-            sub = ".dev%s" % git_changeset
-        else:
-            sub = ".dev"
+        sub = ".dev%s" % git_changeset if git_changeset else ".dev"
     elif version[3] != "final":
         mapping = {"alpha": "a", "beta": "b", "rc": "rc"}
         sub = mapping[version[3]] + str(version[4])
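The change collapses a four-line if/else into one conditional expression with identical semantics. A self-contained restatement (the helper name is invented for the example):

def dev_suffix(git_changeset):
    # Equivalent to the removed if/else branch in get_version().
    return ".dev%s" % git_changeset if git_changeset else ".dev"


assert dev_suffix("20220813") == ".dev20220813"
assert dev_suffix(None) == ".dev"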
@@ -117,19 +117,19 @@ def connection_adapter(cls, edges, pageInfo):
 class IterableConnectionField(Field):
-    def __init__(self, type, *args, **kwargs):
+    def __init__(self, type_, *args, **kwargs):
         kwargs.setdefault("before", String())
         kwargs.setdefault("after", String())
         kwargs.setdefault("first", Int())
         kwargs.setdefault("last", Int())
-        super(IterableConnectionField, self).__init__(type, *args, **kwargs)
+        super(IterableConnectionField, self).__init__(type_, *args, **kwargs)

     @property
     def type(self):
-        type = super(IterableConnectionField, self).type
-        connection_type = type
-        if isinstance(type, NonNull):
-            connection_type = type.of_type
+        type_ = super(IterableConnectionField, self).type
+        connection_type = type_
+        if isinstance(type_, NonNull):
+            connection_type = type_.of_type

         if is_node(connection_type):
             raise Exception(

@@ -140,7 +140,7 @@ class IterableConnectionField(Field):
         assert issubclass(
             connection_type, Connection
         ), f'{self.__class__.__name__} type has to be a subclass of Connection. Received "{connection_type}".'
-        return type
+        return type_

     @classmethod
     def resolve_connection(cls, connection_type, args, resolved):

@@ -171,8 +171,8 @@ class IterableConnectionField(Field):
         on_resolve = partial(cls.resolve_connection, connection_type, args)
         return maybe_thenable(resolved, on_resolve)

-    def get_resolver(self, parent_resolver):
-        resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
+    def wrap_resolve(self, parent_resolver):
+        resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
         return partial(self.connection_resolver, resolver, self.type)
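The `type` → `type_` rename stops shadowing the builtin, and `get_resolver` → `wrap_resolve` is the new extension hook. For orientation, a minimal sketch of this field in use (the Ship types below are invented for the example):

import graphene
from graphene import relay


class Ship(graphene.ObjectType):
    class Meta:
        interfaces = (relay.Node,)

    name = graphene.String()


class ShipConnection(relay.Connection):
    class Meta:
        node = Ship


class Query(graphene.ObjectType):
    ships = relay.ConnectionField(ShipConnection)

    def resolve_ships(root, info, **kwargs):
        # ConnectionField slices this iterable using first/last/before/after.
        return [Ship(name="Falcon"), Ship(name="Rocinante")]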
@@ -18,11 +18,7 @@ def is_node(objecttype):
     if not issubclass(objecttype, ObjectType):
         return False

-    for i in objecttype._meta.interfaces:
-        if issubclass(i, Node):
-            return True
-    return False
+    return any(issubclass(i, Node) for i in objecttype._meta.interfaces)


 class GlobalID(Field):

@@ -37,7 +33,7 @@ class GlobalID(Field):
         parent_type_name = parent_type_name or info.parent_type.name
         return node.to_global_id(parent_type_name, type_id)  # root._meta.name

-    def get_resolver(self, parent_resolver):
+    def wrap_resolve(self, parent_resolver):
         return partial(
             self.id_resolver,
             parent_resolver,

@@ -47,20 +43,20 @@ class GlobalID(Field):
 class NodeField(Field):
-    def __init__(self, node, type=False, **kwargs):
+    def __init__(self, node, type_=False, **kwargs):
         assert issubclass(node, Node), "NodeField can only operate in Nodes"
         self.node_type = node
-        self.field_type = type
+        self.field_type = type_

         super(NodeField, self).__init__(
             # If we don's specify a type, the field type will be the node
             # interface
-            type or node,
+            type_ or node,
             id=ID(required=True, description="The ID of the object"),
             **kwargs,
         )

-    def get_resolver(self, parent_resolver):
+    def wrap_resolve(self, parent_resolver):
         return partial(self.node_type.node_resolver, get_type(self.field_type))

@@ -90,13 +86,13 @@ class Node(AbstractNode):
     def get_node_from_global_id(cls, info, global_id, only_type=None):
         try:
             _type, _id = cls.from_global_id(global_id)
+            if not _type:
+                raise ValueError("Invalid Global ID")
         except Exception as e:
             raise Exception(
-                (
-                    f'Unable to parse global ID "{global_id}". '
-                    'Make sure it is a base64 encoded string in the format: "TypeName:id". '
-                    f"Exception message: {str(e)}"
-                )
+                f'Unable to parse global ID "{global_id}". '
+                'Make sure it is a base64 encoded string in the format: "TypeName:id". '
+                f"Exception message: {e}"
             )

         graphene_type = info.schema.get_type(_type)

@@ -125,5 +121,5 @@ class Node(AbstractNode):
         return from_global_id(global_id)

     @classmethod
-    def to_global_id(cls, type, id):
-        return to_global_id(type, id)
+    def to_global_id(cls, type_, id):
+        return to_global_id(type_, id)
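The error message documents the wire format: a global ID is base64("TypeName:id"). A standalone sketch of that encoding (graphql_relay ships the real to_global_id/from_global_id helpers; these local functions only illustrate the format):

from base64 import b64decode, b64encode


def to_global_id(type_, id):
    # base64("TypeName:id"), the format the error message above describes.
    return b64encode(f"{type_}:{id}".encode("utf-8")).decode("utf-8")


def from_global_id(global_id):
    type_, _, id_ = b64decode(global_id).decode("utf-8").partition(":")
    return type_, id_


assert from_global_id(to_global_id("Ship", "1")) == ("Ship", "1")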
@@ -51,10 +51,10 @@ letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter
 def edges(selected_letters):
     return [
         {
-            "node": {"id": base64("Letter:%s" % l.id), "letter": l.letter},
-            "cursor": base64("arrayconnection:%s" % l.id),
+            "node": {"id": base64("Letter:%s" % letter.id), "letter": letter.letter},
+            "cursor": base64("arrayconnection:%s" % letter.id),
         }
-        for l in [letters[i] for i in selected_letters]
+        for letter in [letters[i] for i in selected_letters]
     ]
@@ -51,10 +51,10 @@ letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter
 def edges(selected_letters):
     return [
         {
-            "node": {"id": base64("Letter:%s" % l.id), "letter": l.letter},
-            "cursor": base64("arrayconnection:%s" % l.id),
+            "node": {"id": base64("Letter:%s" % letter.id), "letter": letter.letter},
+            "cursor": base64("arrayconnection:%s" % letter.id),
         }
-        for l in [letters[i] for i in selected_letters]
+        for letter in [letters[i] for i in selected_letters]
     ]

@@ -66,7 +66,6 @@ def cursor_for(ltr):
 async def execute(args=""):
     if args:
         args = "(" + args + ")"
-
     return await schema.execute_async(
         """
         {

@@ -164,14 +163,14 @@ async def test_respects_first_and_after_and_before_too_few():
 @mark.asyncio
 async def test_respects_first_and_after_and_before_too_many():
     await check(
-        f'first: 4, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD",
+        f'first: 4, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD"
     )


 @mark.asyncio
 async def test_respects_first_and_after_and_before_exactly_right():
     await check(
-        f'first: 3, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD",
+        f'first: 3, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD"
     )

@@ -187,14 +186,14 @@ async def test_respects_last_and_after_and_before_too_few():
 @mark.asyncio
 async def test_respects_last_and_after_and_before_too_many():
     await check(
-        f'last: 4, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD",
+        f'last: 4, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD"
     )


 @mark.asyncio
 async def test_respects_last_and_after_and_before_exactly_right():
     await check(
-        f'last: 3, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD",
+        f'last: 3, after: "{cursor_for("A")}", before: "{cursor_for("E")}"', "BCD"
     )
@@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required():
 def test_global_id_defaults_to_info_parent_type():
     my_id = "1"
     gid = GlobalID()
-    id_resolver = gid.get_resolver(lambda *_: my_id)
+    id_resolver = gid.wrap_resolve(lambda *_: my_id)
     my_global_id = id_resolver(None, Info(User))
     assert my_global_id == to_global_id(User._meta.name, my_id)

@@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type():
 def test_global_id_allows_setting_customer_parent_type():
     my_id = "1"
     gid = GlobalID(parent_type=User)
-    id_resolver = gid.get_resolver(lambda *_: my_id)
+    id_resolver = gid.wrap_resolve(lambda *_: my_id)
     my_global_id = id_resolver(None, None)
     assert my_global_id == to_global_id(User._meta.name, my_id)
@@ -1,7 +1,7 @@
 import re

-from graphql_relay import to_global_id
-from graphql.pyutils import dedent
+from textwrap import dedent
+from graphql_relay import to_global_id

 from ...types import ObjectType, Schema, String
 from ..node import Node, is_node

@@ -171,7 +171,9 @@ def test_node_field_only_lazy_type_wrong():
 def test_str_schema():
-    assert str(schema) == dedent(
+    assert (
+        str(schema).strip()
+        == dedent(
         '''
         schema {
           query: RootQuery

@@ -213,4 +215,5 @@ def test_str_schema():
           ): MyNode
         }
         '''
+        ).strip()
     )
@@ -1,5 +1,6 @@
+from textwrap import dedent
+
 from graphql import graphql_sync
-from graphql.pyutils import dedent

 from ...types import Interface, ObjectType, Schema
 from ...types.scalars import Int, String

@@ -11,7 +12,7 @@ class CustomNode(Node):
         name = "Node"

     @staticmethod
-    def to_global_id(type, id):
+    def to_global_id(type_, id):
         return id

     @staticmethod

@@ -53,7 +54,9 @@ graphql_schema = schema.graphql_schema
 def test_str_schema_correct():
-    assert str(schema) == dedent(
+    assert (
+        str(schema).strip()
+        == dedent(
         '''
         schema {
           query: RootQuery

@@ -92,6 +95,7 @@ def test_str_schema_correct():
           ): Node
         }
         '''
+        ).strip()
     )
@@ -1,5 +1,4 @@
 from promise import Promise, is_thenable
-from graphql.error import format_error as format_graphql_error
 from graphql.error import GraphQLError

 from graphene.types.schema import Schema

@@ -7,7 +6,7 @@ from graphene.types.schema import Schema
 def default_format_error(error):
     if isinstance(error, GraphQLError):
-        return format_graphql_error(error)
+        return error.formatted

     return {"message": str(error)}
@@ -0,0 +1,36 @@
+from ...types import ObjectType, Schema, String, NonNull
+
+
+class Query(ObjectType):
+    hello = String(input=NonNull(String))
+
+    def resolve_hello(self, info, input):
+        if input == "nothing":
+            return None
+        return f"Hello {input}!"
+
+
+schema = Schema(query=Query)
+
+
+def test_required_input_provided():
+    """
+    Test that a required argument works when provided.
+    """
+    input_value = "Potato"
+    result = schema.execute('{ hello(input: "%s") }' % input_value)
+    assert not result.errors
+    assert result.data == {"hello": "Hello Potato!"}
+
+
+def test_required_input_missing():
+    """
+    Test that a required argument raised an error if not provided.
+    """
+    result = schema.execute("{ hello }")
+    assert result.errors
+    assert len(result.errors) == 1
+    assert (
+        result.errors[0].message
+        == "Field 'hello' argument 'input' of type 'String!' is required, but it was not provided."
+    )
@@ -0,0 +1,53 @@
+import pytest
+
+from ...types.base64 import Base64
+from ...types.datetime import Date, DateTime
+from ...types.decimal import Decimal
+from ...types.generic import GenericScalar
+from ...types.json import JSONString
+from ...types.objecttype import ObjectType
+from ...types.scalars import ID, BigInt, Boolean, Float, Int, String
+from ...types.schema import Schema
+from ...types.uuid import UUID
+
+
+@pytest.mark.parametrize(
+    "input_type,input_value",
+    [
+        (Date, '"2022-02-02"'),
+        (GenericScalar, '"foo"'),
+        (Int, "1"),
+        (BigInt, "12345678901234567890"),
+        (Float, "1.1"),
+        (String, '"foo"'),
+        (Boolean, "true"),
+        (ID, "1"),
+        (DateTime, '"2022-02-02T11:11:11"'),
+        (UUID, '"cbebbc62-758e-4f75-a890-bc73b5017d81"'),
+        (Decimal, "1.1"),
+        (JSONString, '{key:"foo",value:"bar"}'),
+        (Base64, '"Q2hlbG8gd29ycmxkCg=="'),
+    ],
+)
+def test_parse_literal_with_variables(input_type, input_value):
+    # input_b needs to be evaluated as literal while the variable dict for
+    # input_a is passed along.
+    class Query(ObjectType):
+        generic = GenericScalar(input_a=GenericScalar(), input_b=input_type())
+
+        def resolve_generic(self, info, input_a=None, input_b=None):
+            return input
+
+    schema = Schema(query=Query)
+
+    query = f"""
+        query Test($a: GenericScalar){{
+            generic(inputA: $a, inputB: {input_value})
+        }}
+    """
+    result = schema.execute(
+        query,
+        variables={"a": "bar"},
+    )
+    assert not result.errors
@@ -1,52 +1,53 @@
 # flake8: noqa
 from graphql import GraphQLResolveInfo as ResolveInfo

-from .objecttype import ObjectType
-from .interface import Interface
-from .mutation import Mutation
-from .scalars import Scalar, String, ID, Int, Float, Boolean
+from .argument import Argument
+from .base64 import Base64
+from .context import Context
 from .datetime import Date, DateTime, Time
 from .decimal import Decimal
-from .json import JSONString
-from .uuid import UUID
-from .schema import Schema
-from .structures import List, NonNull
+from .dynamic import Dynamic
 from .enum import Enum
 from .field import Field
 from .inputfield import InputField
-from .argument import Argument
 from .inputobjecttype import InputObjectType
-from .dynamic import Dynamic
+from .interface import Interface
+from .json import JSONString
+from .mutation import Mutation
+from .objecttype import ObjectType
+from .scalars import ID, Boolean, Float, Int, Scalar, String
+from .schema import Schema
+from .structures import List, NonNull
 from .union import Union
-from .context import Context
+from .uuid import UUID

 __all__ = [
-    "ObjectType",
-    "InputObjectType",
-    "Interface",
-    "Mutation",
-    "Enum",
-    "Field",
-    "InputField",
-    "Schema",
-    "Scalar",
-    "String",
-    "ID",
-    "Int",
-    "Float",
+    "Argument",
+    "Base64",
+    "Boolean",
+    "Context",
     "Date",
     "DateTime",
-    "Time",
     "Decimal",
-    "JSONString",
-    "UUID",
-    "Boolean",
-    "List",
-    "NonNull",
-    "Argument",
     "Dynamic",
-    "Union",
-    "Context",
+    "Enum",
+    "Field",
+    "Float",
+    "ID",
+    "InputField",
+    "InputObjectType",
+    "Int",
+    "Interface",
+    "JSONString",
+    "List",
+    "Mutation",
+    "NonNull",
+    "ObjectType",
     "ResolveInfo",
+    "Scalar",
+    "Schema",
+    "String",
+    "Time",
+    "UUID",
+    "Union",
 ]
@@ -1,4 +1,5 @@
 from itertools import chain

+from graphql import Undefined
+
 from .dynamic import Dynamic
 from .mountedtype import MountedType

@@ -40,8 +41,8 @@ class Argument(MountedType):
     def __init__(
         self,
-        type,
-        default_value=None,
+        type_,
+        default_value=Undefined,
         description=None,
         name=None,
         required=False,

@@ -50,10 +51,10 @@ class Argument(MountedType):
         super(Argument, self).__init__(_creation_counter=_creation_counter)

         if required:
-            type = NonNull(type)
+            type_ = NonNull(type_)

         self.name = name
-        self._type = type
+        self._type = type_
         self.default_value = default_value
         self.description = description
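Switching the default from None to Undefined lets the schema tell "no default declared" apart from "the default is null". A small sketch of that distinction (the helper is invented for the example):

from graphql import Undefined


def describe_default(default_value):
    # Undefined is falsy and distinct from None, so "no default was declared"
    # (Undefined) can be told apart from "the default is null" (None).
    if default_value is Undefined:
        return "no default"
    return f"default = {default_value!r}"


assert describe_default(Undefined) == "no default"
assert describe_default(None) == "default = None"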
@@ -38,7 +38,7 @@ class BaseType(SubclassWithMeta):
     def __init_subclass_with_meta__(
         cls, name=None, description=None, _meta=None, **_kwargs
     ):
-        assert "_meta" not in cls.__dict__, "Can't assign directly meta"
+        assert "_meta" not in cls.__dict__, "Can't assign meta directly"
         if not _meta:
             return
         _meta.name = name or cls.__name__
graphene/types/base64.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from binascii import Error as _Error
+from base64 import b64decode, b64encode
+
+from graphql.error import GraphQLError
+from graphql.language import StringValueNode, print_ast
+
+from .scalars import Scalar
+
+
+class Base64(Scalar):
+    """
+    The `Base64` scalar type represents a base64-encoded String.
+    """
+
+    @staticmethod
+    def serialize(value):
+        if not isinstance(value, bytes):
+            if isinstance(value, str):
+                value = value.encode("utf-8")
+            else:
+                value = str(value).encode("utf-8")
+        return b64encode(value).decode("utf-8")
+
+    @classmethod
+    def parse_literal(cls, node, _variables=None):
+        if not isinstance(node, StringValueNode):
+            raise GraphQLError(
+                f"Base64 cannot represent non-string value: {print_ast(node)}"
+            )
+        return cls.parse_value(node.value)
+
+    @staticmethod
+    def parse_value(value):
+        if not isinstance(value, bytes):
+            if not isinstance(value, str):
+                raise GraphQLError(
+                    f"Base64 cannot represent non-string value: {repr(value)}"
+                )
+            value = value.encode("utf-8")
+        try:
+            return b64decode(value, validate=True).decode("utf-8")
+        except _Error:
+            raise GraphQLError(f"Base64 cannot decode value: {repr(value)}")
@@ -25,7 +25,7 @@ class Date(Scalar):
         return date.isoformat()

     @classmethod
-    def parse_literal(cls, node):
+    def parse_literal(cls, node, _variables=None):
         if not isinstance(node, StringValueNode):
             raise GraphQLError(
                 f"Date cannot represent non-string value: {print_ast(node)}"

@@ -58,7 +58,7 @@ class DateTime(Scalar):
         return dt.isoformat()

     @classmethod
-    def parse_literal(cls, node):
+    def parse_literal(cls, node, _variables=None):
         if not isinstance(node, StringValueNode):
             raise GraphQLError(
                 f"DateTime cannot represent non-string value: {print_ast(node)}"

@@ -93,7 +93,7 @@ class Time(Scalar):
         return time.isoformat()

     @classmethod
-    def parse_literal(cls, node):
+    def parse_literal(cls, node, _variables=None):
         if not isinstance(node, StringValueNode):
             raise GraphQLError(
                 f"Time cannot represent non-string value: {print_ast(node)}"
@@ -2,7 +2,7 @@ from __future__ import absolute_import

 from decimal import Decimal as _Decimal

-from graphql.language.ast import StringValueNode
+from graphql.language.ast import StringValueNode, IntValueNode

 from .scalars import Scalar

@@ -22,8 +22,8 @@ class Decimal(Scalar):
         return str(dec)

     @classmethod
-    def parse_literal(cls, node):
-        if isinstance(node, StringValueNode):
+    def parse_literal(cls, node, _variables=None):
+        if isinstance(node, (StringValueNode, IntValueNode)):
             return cls.parse_value(node.value)

     @staticmethod
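With IntValueNode accepted, an integer literal now parses into the Decimal scalar; a bare float literal is a FloatValueNode and still falls through (i.e. stays invalid). A small sketch, assuming the top-level graphene exports used elsewhere in this changeset:

import graphene


class Query(graphene.ObjectType):
    echo = graphene.Decimal(value=graphene.Decimal())

    def resolve_echo(root, info, value=None):
        return value


schema = graphene.Schema(query=Query)

# String literals always worked; int literals work after this change.
assert not schema.execute('{ echo(value: "1.1") }').errors
assert not schema.execute("{ echo(value: 42) }").errors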
@@ -1,3 +1,5 @@
+from enum import Enum as PyEnum
+
 from graphql import (
     GraphQLEnumType,
     GraphQLInputObjectType,

@@ -36,7 +38,19 @@ class GrapheneScalarType(GrapheneGraphQLType, GraphQLScalarType):


 class GrapheneEnumType(GrapheneGraphQLType, GraphQLEnumType):
-    pass
+    def serialize(self, value):
+        if not isinstance(value, PyEnum):
+            enum = self.graphene_type._meta.enum
+            try:
+                # Try and get enum by value
+                value = enum(value)
+            except ValueError:
+                # Try and get enum by name
+                try:
+                    value = enum[value]
+                except KeyError:
+                    pass
+        return super(GrapheneEnumType, self).serialize(value)


 class GrapheneInputObjectType(GrapheneGraphQLType, GraphQLInputObjectType):
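The fallback order the new serialize implements, restated in plain Python (the `coerce` helper and `Color` enum are invented for the example):

from enum import Enum


class Color(Enum):
    RED = 1


def coerce(enum, value):
    if isinstance(value, Enum):
        return value
    try:
        return enum(value)      # by value: 1 -> Color.RED
    except ValueError:
        try:
            return enum[value]  # by name: "RED" -> Color.RED
        except KeyError:
            return value        # left as-is; super().serialize decides


assert coerce(Color, Color.RED) is Color.RED
assert coerce(Color, 1) is Color.RED
assert coerce(Color, "RED") is Color.RED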
@@ -10,10 +10,10 @@ class Dynamic(MountedType):
     the schema. So we can have lazy fields.
     """

-    def __init__(self, type, with_schema=False, _creation_counter=None):
+    def __init__(self, type_, with_schema=False, _creation_counter=None):
         super(Dynamic, self).__init__(_creation_counter=_creation_counter)
-        assert inspect.isfunction(type) or isinstance(type, partial)
-        self.type = type
+        assert inspect.isfunction(type_) or isinstance(type_, partial)
+        self.type = type_
         self.with_schema = with_schema

     def get_type(self, schema=None):
@@ -21,14 +21,14 @@ class EnumOptions(BaseOptions):
 class EnumMeta(SubclassWithMeta_Meta):
-    def __new__(cls, name, bases, classdict, **options):
+    def __new__(cls, name_, bases, classdict, **options):
         enum_members = dict(classdict, __eq__=eq_enum)
         # We remove the Meta attribute from the class to not collide
         # with the enum values.
         enum_members.pop("Meta", None)
         enum = PyEnum(cls.__name__, enum_members)
         return SubclassWithMeta_Meta.__new__(
-            cls, name, bases, dict(classdict, __enum__=enum), **options
+            cls, name_, bases, dict(classdict, __enum__=enum), **options
         )

     def get(cls, value):

@@ -52,7 +52,10 @@ class EnumMeta(SubclassWithMeta_Meta):
             return super(EnumMeta, cls).__call__(*args, **kwargs)
             # return cls._meta.enum(*args, **kwargs)

-    def from_enum(cls, enum, description=None, deprecation_reason=None):  # noqa: N805
+    def from_enum(
+        cls, enum, name=None, description=None, deprecation_reason=None
+    ):  # noqa: N805
+        name = name or enum.__name__
         description = description or enum.__doc__
         meta_dict = {
             "enum": enum,

@@ -60,7 +63,7 @@ class EnumMeta(SubclassWithMeta_Meta):
             "deprecation_reason": deprecation_reason,
         }
         meta_class = type("Meta", (object,), meta_dict)
-        return type(meta_class.enum.__name__, (Enum,), {"Meta": meta_class})
+        return type(name, (Enum,), {"Meta": meta_class})


 class Enum(UnmountedType, BaseType, metaclass=EnumMeta):
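A sketch of the new keyword in use. Previously the graphene type was always named after the Python enum; the `name` argument allows wrapping the same PyEnum under a second name:

from enum import Enum as PyEnum

import graphene


class Color(PyEnum):
    RED = 1
    GREEN = 2


PaintColor = graphene.Enum.from_enum(Color, name="PaintColor")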
@@ -8,6 +8,7 @@ from .resolver import default_resolver
 from .structures import NonNull
 from .unmountedtype import UnmountedType
 from .utils import get_type
+from ..utils.deprecated import warn_deprecation

 base_type = type

@@ -64,7 +65,7 @@ class Field(MountedType):
     def __init__(
         self,
-        type,
+        type_,
         args=None,
         resolver=None,
         source=None,

@@ -88,7 +89,7 @@ class Field(MountedType):
         ), f'The default value can not be a function but received "{base_type(default_value)}".'

         if required:
-            type = NonNull(type)
+            type_ = NonNull(type_)

         # Check if name is actually an argument of the field
         if isinstance(name, (Argument, UnmountedType)):

@@ -101,7 +102,7 @@ class Field(MountedType):
             source = None

         self.name = name
-        self._type = type
+        self._type = type_
         self.args = to_arguments(args or {}, extra_args)
         if source:
             resolver = partial(source_resolver, source)

@@ -114,5 +115,24 @@ class Field(MountedType):
     def type(self):
         return get_type(self._type)

-    def get_resolver(self, parent_resolver):
+    get_resolver = None
+
+    def wrap_resolve(self, parent_resolver):
+        """
+        Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
+        (parent_resolver) if the Field definition has no resolver.
+        """
+        if self.get_resolver is not None:
+            warn_deprecation(
+                "The get_resolver method is being deprecated, please rename it to wrap_resolve."
+            )
+            return self.get_resolver(parent_resolver)
+
         return self.resolver or parent_resolver
+
+    def wrap_subscribe(self, parent_subscribe):
+        """
+        Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
+        (parent_subscribe) if the Field definition has no subscribe.
+        """
+        return parent_subscribe
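A sketch of a custom Field subclass using the new hook (the `UpperField` class is invented for the example). Assigning `get_resolver = None` on Field is what lets the shim above detect old-style overrides and warn:

from functools import wraps

from graphene import Field


class UpperField(Field):
    # Subclasses now override wrap_resolve instead of get_resolver.
    def wrap_resolve(self, parent_resolver):
        resolver = super().wrap_resolve(parent_resolver)

        @wraps(resolver)
        def wrapped(root, info, **args):
            # Post-process whatever the underlying resolver returns.
            return str(resolver(root, info, **args)).upper()

        return wrapped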
@@ -29,7 +29,7 @@ class GenericScalar(Scalar):
     parse_value = identity

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, (StringValueNode, BooleanValueNode)):
             return ast.value
         elif isinstance(ast, IntValueNode):
@@ -48,7 +48,7 @@ class InputField(MountedType):
     def __init__(
         self,
-        type,
+        type_,
         name=None,
         default_value=Undefined,
         deprecation_reason=None,

@@ -60,8 +60,8 @@ class InputField(MountedType):
         super(InputField, self).__init__(_creation_counter=_creation_counter)
         self.name = name
         if required:
-            type = NonNull(type)
-        self._type = type
+            type_ = NonNull(type_)
+        self._type = type_
         self.deprecation_reason = deprecation_reason
         self.default_value = default_value
         self.description = description
@@ -5,11 +5,12 @@ from .utils import yank_fields_from_attrs

 # For static type checking with Mypy
 MYPY = False
 if MYPY:
-    from typing import Dict  # NOQA
+    from typing import Dict, Iterable, Type  # NOQA


 class InterfaceOptions(BaseOptions):
     fields = None  # type: Dict[str, Field]
+    interfaces = ()  # type: Iterable[Type[Interface]]


 class Interface(BaseType):

@@ -45,7 +46,7 @@ class Interface(BaseType):
     """

     @classmethod
-    def __init_subclass_with_meta__(cls, _meta=None, **options):
+    def __init_subclass_with_meta__(cls, _meta=None, interfaces=(), **options):
         if not _meta:
             _meta = InterfaceOptions(cls)

@@ -58,6 +59,9 @@ class Interface(BaseType):
         else:
             _meta.fields = fields

+        if not _meta.interfaces:
+            _meta.interfaces = interfaces
+
         super(Interface, cls).__init_subclass_with_meta__(_meta=_meta, **options)

     @classmethod
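With this change an Interface can itself declare interfaces. A sketch (the `HasId`/`Named` types are invented for the example; with graphql-core 3 this should render in the SDL as an interface implementing another interface):

import graphene


class HasId(graphene.Interface):
    id = graphene.ID(required=True)


class Named(graphene.Interface):
    # New in this change: an Interface can declare Meta.interfaces.
    class Meta:
        interfaces = (HasId,)

    name = graphene.String()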
@@ -20,7 +20,7 @@ class JSONString(Scalar):
         return json.dumps(dt)

     @staticmethod
-    def parse_literal(node):
+    def parse_literal(node, _variables=None):
         if isinstance(node, StringValueNode):
             return json.loads(node.value)
@@ -76,7 +76,6 @@ class Mutation(ObjectType):
     ):
         if not _meta:
             _meta = MutationOptions(cls)
-
         output = output or getattr(cls, "Output", None)
         fields = {}

@@ -85,43 +84,32 @@ class Mutation(ObjectType):
                 interface, Interface
             ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
             fields.update(interface._meta.fields)
-
         if not output:
             # If output is defined, we don't need to get the fields
             fields = {}
             for base in reversed(cls.__mro__):
                 fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
             output = cls
-
         if not arguments:
             input_class = getattr(cls, "Arguments", None)
             if not input_class:
                 input_class = getattr(cls, "Input", None)
                 if input_class:
                     warn_deprecation(
-                        (
                         f"Please use {cls.__name__}.Arguments instead of {cls.__name__}.Input."
                         " Input is now only used in ClientMutationID.\n"
                         "Read more:"
                         " https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"
-                        )
                     )
-
-            if input_class:
-                arguments = props(input_class)
-            else:
-                arguments = {}
+            arguments = props(input_class) if input_class else {}
         if not resolver:
             mutate = getattr(cls, "mutate", None)
             assert mutate, "All mutations must define a mutate method in it"
             resolver = get_unbound_function(mutate)
-
         if _meta.fields:
             _meta.fields.update(fields)
         else:
             _meta.fields = fields
-
         _meta.interfaces = interfaces
         _meta.output = output
         _meta.resolver = resolver

@@ -133,7 +121,7 @@ class Mutation(ObjectType):
     def Field(
         cls, name=None, description=None, deprecation_reason=None, required=False
     ):
-        """ Mount instance of mutation Field. """
+        """Mount instance of mutation Field."""
         return Field(
             cls._meta.output,
             args=cls._meta.arguments,
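For orientation, a minimal mutation using the supported `Arguments` container rather than the deprecated `Input` one (the `CreateShip` type is invented for the example):

import graphene


class CreateShip(graphene.Mutation):
    # "Arguments" is the supported container; a nested "Input" class only
    # triggers the deprecation warning shown in the hunk above.
    class Arguments:
        name = graphene.String(required=True)

    ok = graphene.Boolean()

    def mutate(root, info, name):
        return CreateShip(ok=True)


class Mutation(graphene.ObjectType):
    create_ship = CreateShip.Field()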
@@ -7,7 +7,6 @@ try:
     from dataclasses import make_dataclass, field
 except ImportError:
     from ..pyutils.dataclasses import make_dataclass, field  # type: ignore
-
 # For static type checking with Mypy
 MYPY = False
 if MYPY:

@@ -20,12 +19,16 @@ class ObjectTypeOptions(BaseOptions):
 class ObjectTypeMeta(BaseTypeMeta):
-    def __new__(cls, name, bases, namespace):
+    def __new__(cls, name_, bases, namespace, **options):
+        # Note: it's safe to pass options as keyword arguments as they are still type-checked by ObjectTypeOptions.
+
         # We create this type, to then overload it with the dataclass attrs
         class InterObjectType:
             pass

-        base_cls = super().__new__(cls, name, (InterObjectType,) + bases, namespace)
+        base_cls = super().__new__(
+            cls, name_, (InterObjectType,) + bases, namespace, **options
+        )
         if base_cls._meta:
             fields = [
                 (

@@ -39,7 +42,7 @@ class ObjectTypeMeta(BaseTypeMeta):
                 )
                 for key, field_value in base_cls._meta.fields.items()
             ]
-            dataclass = make_dataclass(name, fields, bases=())
+            dataclass = make_dataclass(name_, fields, bases=())
             InterObjectType.__init__ = dataclass.__init__
             InterObjectType.__eq__ = dataclass.__eq__
             InterObjectType.__repr__ = dataclass.__repr__

@@ -62,7 +65,7 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
     Methods starting with ``resolve_<field_name>`` are bound as resolvers of the matching Field
     name. If no resolver is provided, the default resolver is used.

-    Ambiguous types with Interface and Union can be determined through``is_type_of`` method and
+    Ambiguous types with Interface and Union can be determined through ``is_type_of`` method and
     ``Meta.possible_types`` attribute.

     .. code:: python

@@ -129,7 +132,6 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
     ):
         if not _meta:
             _meta = ObjectTypeOptions(cls)
-
         fields = {}

         for interface in interfaces:

@@ -137,10 +139,8 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
                 interface, Interface
             ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".'
             fields.update(interface._meta.fields)
-
         for base in reversed(cls.__mro__):
             fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
-
         assert not (possible_types and cls.is_type_of), (
             f"{cls.__name__}.Meta.possible_types will cause type collision with {cls.__name__}.is_type_of. "
             "Please use one or other."

@@ -150,7 +150,6 @@ class ObjectType(BaseType, metaclass=ObjectTypeMeta):
             _meta.fields.update(fields)
         else:
             _meta.fields = fields
-
         if not _meta.interfaces:
             _meta.interfaces = interfaces
         _meta.possible_types = possible_types
@@ -7,9 +7,7 @@ def dict_resolver(attname, default_value, root, info, **args):

 def dict_or_attr_resolver(attname, default_value, root, info, **args):
-    resolver = attr_resolver
-    if isinstance(root, dict):
-        resolver = dict_resolver
+    resolver = dict_resolver if isinstance(root, dict) else attr_resolver
     return resolver(attname, default_value, root, info, **args)
@@ -75,13 +75,40 @@ class Int(Scalar):
     parse_value = coerce_int

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, IntValueNode):
             num = int(ast.value)
             if MIN_INT <= num <= MAX_INT:
                 return num


+class BigInt(Scalar):
+    """
+    The `BigInt` scalar type represents non-fractional whole numeric values.
+    `BigInt` is not constrained to 32-bit like the `Int` type and thus is a less
+    compatible type.
+    """
+
+    @staticmethod
+    def coerce_int(value):
+        try:
+            num = int(value)
+        except ValueError:
+            try:
+                num = int(float(value))
+            except ValueError:
+                return None
+        return num
+
+    serialize = coerce_int
+    parse_value = coerce_int
+
+    @staticmethod
+    def parse_literal(ast, _variables=None):
+        if isinstance(ast, IntValueNode):
+            return int(ast.value)
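A sketch of the new scalar in use, importing it the same way the test file earlier in this changeset does. Unlike `Int`, values beyond the 32-bit range pass through:

from graphene import ObjectType, Schema
from graphene.types.scalars import BigInt


class Query(ObjectType):
    big = BigInt()

    def resolve_big(root, info):
        # Larger than 2**31 - 1, which the regular Int scalar rejects.
        return 12345678901234567890


schema = Schema(query=Query)
assert schema.execute("{ big }").data == {"big": 12345678901234567890}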
 class Float(Scalar):
     """
     The `Float` scalar type represents signed double-precision fractional

@@ -101,7 +128,7 @@ class Float(Scalar):
     parse_value = coerce_float

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, (FloatValueNode, IntValueNode)):
             return float(ast.value)

@@ -116,14 +143,14 @@ class String(Scalar):
     @staticmethod
     def coerce_string(value):
         if isinstance(value, bool):
-            return u"true" if value else u"false"
+            return "true" if value else "false"
         return str(value)

     serialize = coerce_string
     parse_value = coerce_string

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, StringValueNode):
             return ast.value

@@ -137,7 +164,7 @@ class Boolean(Scalar):
     parse_value = bool

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, BooleanValueNode):
             return ast.value

@@ -155,6 +182,6 @@ class ID(Scalar):
     parse_value = str

     @staticmethod
-    def parse_literal(ast):
+    def parse_literal(ast, _variables=None):
         if isinstance(ast, (StringValueNode, IntValueNode)):
             return ast.value
@ -10,8 +10,11 @@ from graphql import (
parse, parse,
print_schema, print_schema,
subscribe, subscribe,
validate,
ExecutionResult,
GraphQLArgument, GraphQLArgument,
GraphQLBoolean, GraphQLBoolean,
GraphQLError,
GraphQLEnumValue, GraphQLEnumValue,
GraphQLField, GraphQLField,
GraphQLFloat, GraphQLFloat,
@ -23,7 +26,6 @@ from graphql import (
GraphQLObjectType, GraphQLObjectType,
GraphQLSchema, GraphQLSchema,
GraphQLString, GraphQLString,
Undefined,
) )
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
@ -76,6 +78,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types) return isinstance(root, possible_types)
# We use this resolver for subscriptions
def identity_resolve(root, info, **arguments):
return root
class TypeMap(dict): class TypeMap(dict):
def __init__( def __init__(
self, self,
@ -172,7 +179,7 @@ class TypeMap(dict):
deprecation_reason = graphene_type._meta.deprecation_reason(value) deprecation_reason = graphene_type._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue( values[name] = GraphQLEnumValue(
value=value.value, value=value,
description=description, description=description,
deprecation_reason=deprecation_reason, deprecation_reason=deprecation_reason,
) )
@ -226,11 +233,20 @@ class TypeMap(dict):
else None else None
) )
def interfaces():
interfaces = []
for graphene_interface in graphene_type._meta.interfaces:
interface = self.add_type(graphene_interface)
assert interface.graphene_type == graphene_interface
interfaces.append(interface)
return interfaces
return GrapheneInterfaceType( return GrapheneInterfaceType(
graphene_type=graphene_type, graphene_type=graphene_type,
name=graphene_type._meta.name, name=graphene_type._meta.name,
description=graphene_type._meta.description, description=graphene_type._meta.description,
fields=partial(self.create_fields_for_type, graphene_type), fields=partial(self.create_fields_for_type, graphene_type),
interfaces=interfaces,
resolve_type=resolve_type, resolve_type=resolve_type,
) )
@ -303,26 +319,41 @@ class TypeMap(dict):
arg_type, arg_type,
out_name=arg_name, out_name=arg_name,
description=arg.description, description=arg.description,
default_value=Undefined default_value=arg.default_value,
if isinstance(arg.type, NonNull)
else arg.default_value,
) )
subscribe = field.wrap_subscribe(
self.get_function_for_type(
graphene_type, f"subscribe_{name}", name, field.default_value
)
)
# If we are in a subscription, we use (by default) an
# identity-based resolver for the root, rather than the
# default resolver for objects/dicts.
if subscribe:
field_default_resolver = identity_resolve
elif issubclass(graphene_type, ObjectType):
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
field_default_resolver = partial(
default_resolver, name, field.default_value
)
else:
field_default_resolver = None
resolve = field.wrap_resolve(
self.get_function_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
)
or field_default_resolver
)
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolve=field.get_resolver( resolve=resolve,
self.get_resolver_for_type( subscribe=subscribe,
graphene_type, f"resolve_{name}", name, field.default_value
)
),
subscribe=field.get_resolver(
self.get_resolver_for_type(
graphene_type,
f"subscribe_{name}",
name,
field.default_value,
)
),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description, description=field.description,
) )
@ -330,7 +361,8 @@ class TypeMap(dict):
fields[field_name] = _field fields[field_name] = _field
return fields return fields
def get_resolver_for_type(self, graphene_type, func_name, name, default_value): def get_function_for_type(self, graphene_type, func_name, name, default_value):
"""Gets a resolve or subscribe function for a given ObjectType"""
if not issubclass(graphene_type, ObjectType): if not issubclass(graphene_type, ObjectType):
return return
resolver = getattr(graphene_type, func_name, None) resolver = getattr(graphene_type, func_name, None)
@ -350,36 +382,21 @@ class TypeMap(dict):
if resolver: if resolver:
return get_unbound_function(resolver) return get_unbound_function(resolver)
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
return partial(default_resolver, name, default_value)
def resolve_type(self, resolve_type_func, type_name, root, info, _type): def resolve_type(self, resolve_type_func, type_name, root, info, _type):
type_ = resolve_type_func(root, info) type_ = resolve_type_func(root, info)
if not type_: if inspect.isclass(type_) and issubclass(type_, ObjectType):
return type_._meta.name
return_type = self[type_name] return_type = self[type_name]
return default_type_resolver(root, info, return_type) return default_type_resolver(root, info, return_type)
if inspect.isclass(type_) and issubclass(type_, ObjectType):
graphql_type = self.get(type_._meta.name)
assert graphql_type, f"Can't find type {type_._meta.name} in schema"
assert (
graphql_type.graphene_type == type_
), f"The type {type_} does not match with the associated graphene type {graphql_type.graphene_type}."
return graphql_type
return type_
class Schema: class Schema:
"""Schema Definition. """Schema Definition.
A Graphene Schema can execute operations (query, mutation, subscription) against the defined A Graphene Schema can execute operations (query, mutation, subscription) against the defined
types. For advanced purposes, the schema can be used to lookup type definitions and answer types. For advanced purposes, the schema can be used to lookup type definitions and answer
questions about the types through introspection. questions about the types through introspection.
Args: Args:
query (Type[ObjectType]): Root query *ObjectType*. Describes entry point for fields to *read* query (Type[ObjectType]): Root query *ObjectType*. Describes entry point for fields to *read*
data in your Schema. data in your Schema.
@ -426,7 +443,6 @@ class Schema:
""" """
This function let the developer select a type in a given schema This function let the developer select a type in a given schema
by accessing its attrs. by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema Example: using schema.Query for accessing the "Query" type in the Schema
""" """
_type = self.graphql_schema.get_type(type_name) _type = self.graphql_schema.get_type(type_name)
@@ -441,11 +457,9 @@ class Schema:
    def execute(self, *args, **kwargs):
        """Execute a GraphQL query on the schema.

        Use the `graphql_sync` function from `graphql-core` to provide the result
        for a query string. Most of the time this method will be called by one of the Graphene
        :ref:`Integrations` via a web request.

        Args:
            request_string (str or Document): GraphQL request (query, mutation or subscription)
                as string or parsed AST form from `graphql-core`.
@@ -460,7 +474,8 @@ class Schema:
                request_string, an operation name must be provided for the result to be provided.
            middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
                defined in `graphql-core`.
+            execution_context_class (ExecutionContext, optional): The execution context class
+                to use when resolving queries and mutations.

        Returns:
            :obj:`ExecutionResult` containing any data and errors for the operation.
        """
@@ -469,14 +484,25 @@ class Schema:
    async def execute_async(self, *args, **kwargs):
        """Execute a GraphQL query on the schema asynchronously.

        Same as `execute`, but uses `graphql` instead of `graphql_sync`.
        """
        kwargs = normalize_execute_kwargs(kwargs)
        return await graphql(self.graphql_schema, *args, **kwargs)

    async def subscribe(self, query, *args, **kwargs):
+        """Execute a GraphQL subscription on the schema asynchronously."""
+        # Do parsing
+        try:
            document = parse(query)
+        except GraphQLError as error:
+            return ExecutionResult(data=None, errors=[error])
+
+        # Do validation
+        validation_errors = validate(self.graphql_schema, document)
+        if validation_errors:
+            return ExecutionResult(data=None, errors=validation_errors)
+
+        # Execute the query
        kwargs = normalize_execute_kwargs(kwargs)
        return await subscribe(self.graphql_schema, document, *args, **kwargs)
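
Because subscribe now returns an ExecutionResult up front when parsing or validation fails (and an async iterator otherwise), callers should handle both shapes. A minimal sketch, assuming a schema that defines a countToTen subscription like the test file further below:

import asyncio
from graphql import ExecutionResult

async def main():
    result = await schema.subscribe("subscription { countToTen }")
    if isinstance(result, ExecutionResult):
        print(result.errors)  # parse/validation errors are reported eagerly
        return
    async for item in result:
        print(item.data)

asyncio.run(main())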
@@ -0,0 +1,97 @@
import base64
from graphql import GraphQLError
from ..objecttype import ObjectType
from ..scalars import String
from ..schema import Schema
from ..base64 import Base64
class Query(ObjectType):
base64 = Base64(_in=Base64(name="input"), _match=String(name="match"))
bytes_as_base64 = Base64()
string_as_base64 = Base64()
number_as_base64 = Base64()
def resolve_base64(self, info, _in=None, _match=None):
if _match:
assert _in == _match
return _in
def resolve_bytes_as_base64(self, info):
return b"Hello world"
def resolve_string_as_base64(self, info):
return "Spam and eggs"
def resolve_number_as_base64(self, info):
return 42
schema = Schema(query=Query)
def test_base64_query():
base64_value = base64.b64encode(b"Random string").decode("utf-8")
result = schema.execute(
"""{{ base64(input: "{}", match: "Random string") }}""".format(base64_value)
)
assert not result.errors
assert result.data == {"base64": base64_value}
def test_base64_query_with_variable():
base64_value = base64.b64encode(b"Another string").decode("utf-8")
# test base64 variable in string representation
result = schema.execute(
"""
query GetBase64($base64: Base64) {
base64(input: $base64, match: "Another string")
}
""",
variables={"base64": base64_value},
)
assert not result.errors
assert result.data == {"base64": base64_value}
def test_base64_query_none():
result = schema.execute("""{ base64 }""")
assert not result.errors
assert result.data == {"base64": None}
def test_base64_query_invalid():
bad_inputs = [dict(), 123, "This is not valid base64"]
for input_ in bad_inputs:
result = schema.execute(
"""{ base64(input: $input) }""", variables={"input": input_}
)
assert isinstance(result.errors, list)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError)
assert result.data is None
def test_base64_from_bytes():
base64_value = base64.b64encode(b"Hello world").decode("utf-8")
result = schema.execute("""{ bytesAsBase64 }""")
assert not result.errors
assert result.data == {"bytesAsBase64": base64_value}
def test_base64_from_string():
base64_value = base64.b64encode(b"Spam and eggs").decode("utf-8")
result = schema.execute("""{ stringAsBase64 }""")
assert not result.errors
assert result.data == {"stringAsBase64": base64_value}
def test_base64_from_number():
base64_value = base64.b64encode(b"42").decode("utf-8")
result = schema.execute("""{ numberAsBase64 }""")
assert not result.errors
assert result.data == {"numberAsBase64": base64_value}
@@ -60,6 +60,23 @@ def test_datetime_query(sample_datetime):
    assert result.data == {"datetime": isoformat}


def test_datetime_query_with_variables(sample_datetime):
    isoformat = sample_datetime.isoformat()

    result = schema.execute(
        """
        query GetDate($datetime: DateTime) {
            literal: datetime(in: "%s")
            value: datetime(in: $datetime)
        }
        """
        % isoformat,
        variable_values={"datetime": isoformat},
    )

    assert not result.errors
    assert result.data == {"literal": isoformat, "value": isoformat}
def test_date_query(sample_date):
    isoformat = sample_date.isoformat()
@@ -68,6 +85,23 @@ def test_date_query(sample_date):
    assert result.data == {"date": isoformat}


def test_date_query_with_variables(sample_date):
    isoformat = sample_date.isoformat()

    result = schema.execute(
        """
        query GetDate($date: Date) {
            literal: date(in: "%s")
            value: date(in: $date)
        }
        """
        % isoformat,
        variable_values={"date": isoformat},
    )

    assert not result.errors
    assert result.data == {"literal": isoformat, "value": isoformat}
def test_time_query(sample_time):
    isoformat = sample_time.isoformat()
@@ -76,6 +110,23 @@ def test_time_query(sample_time):
    assert result.data == {"time": isoformat}


def test_time_query_with_variables(sample_time):
    isoformat = sample_time.isoformat()

    result = schema.execute(
        """
        query GetTime($time: Time) {
            literal: time(at: "%s")
            value: time(at: $time)
        }
        """
        % isoformat,
        variable_values={"time": isoformat},
    )

    assert not result.errors
    assert result.data == {"literal": isoformat, "value": isoformat}
def test_bad_datetime_query():
    not_a_date = "Some string that's not a datetime"
@@ -41,3 +41,11 @@ def test_bad_decimal_query():
    result = schema.execute("""{ decimal(input: "%s") }""" % not_a_decimal)
    assert len(result.errors) == 1
    assert result.data is None


def test_decimal_string_query_integer():
    decimal_value = 1
    result = schema.execute("""{ decimal(input: %s) }""" % decimal_value)
    assert not result.errors
    assert result.data == {"decimal": str(decimal_value)}
    assert decimal.Decimal(result.data["decimal"]) == decimal_value
@@ -234,10 +234,10 @@ def test_stringifies_simple_types():
#     (InputObjectType, True)
# )

-# for type, answer in expected:
-#     assert is_input_type(type) == answer
-#     assert is_input_type(GraphQLList(type)) == answer
-#     assert is_input_type(GraphQLNonNull(type)) == answer
+# for type_, answer in expected:
+#     assert is_input_type(type_) == answer
+#     assert is_input_type(GraphQLList(type_)) == answer
+#     assert is_input_type(GraphQLNonNull(type_)) == answer


# def test_identifies_output_types():
@@ -1,7 +1,12 @@
+from textwrap import dedent
+
from ..argument import Argument
from ..enum import Enum, PyEnum
from ..field import Field
from ..inputfield import InputField
+from ..inputobjecttype import InputObjectType
+from ..mutation import Mutation
+from ..scalars import String
from ..schema import ObjectType, Schema
@@ -21,8 +26,8 @@ def test_enum_construction():
    assert RGB._meta.description == "Description"

    values = RGB._meta.enum.__members__.values()
-    assert sorted([v.name for v in values]) == ["BLUE", "GREEN", "RED"]
-    assert sorted([v.description for v in values]) == [
+    assert sorted(v.name for v in values) == ["BLUE", "GREEN", "RED"]
+    assert sorted(v.description for v in values) == [
        "Description BLUE",
        "Description GREEN",
        "Description RED",
@@ -47,7 +52,7 @@ def test_enum_instance_construction():
    RGB = Enum("RGB", "RED,GREEN,BLUE")

    values = RGB._meta.enum.__members__.values()
-    assert sorted([v.name for v in values]) == ["BLUE", "GREEN", "RED"]
+    assert sorted(v.name for v in values) == ["BLUE", "GREEN", "RED"]


def test_enum_from_builtin_enum():
@@ -224,3 +229,292 @@ def test_enum_skip_meta_from_members():
        "GREEN": RGB1.GREEN,
        "BLUE": RGB1.BLUE,
    }
def test_enum_types():
from enum import Enum as PyEnum
class Color(PyEnum):
"""Primary colors"""
RED = 1
YELLOW = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return Color.RED
schema = Schema(query=Query)
assert (
str(schema).strip()
== dedent(
'''
type Query {
color: Color!
}
"""Primary colors"""
enum Color {
RED
YELLOW
BLUE
}
'''
).strip()
)
def test_enum_resolver():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return Color.RED
schema = Schema(query=Query)
results = schema.execute("query { color }")
assert not results.errors
assert results.data["color"] == Color.RED.name
def test_enum_resolver_compat():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
color_by_name = GColor(required=True)
def resolve_color(_, info):
return Color.RED.value
def resolve_color_by_name(_, info):
return Color.RED.name
schema = Schema(query=Query)
results = schema.execute(
"""query {
color
colorByName
}"""
)
assert not results.errors
assert results.data["color"] == Color.RED.name
assert results.data["colorByName"] == Color.RED.name
def test_enum_with_name():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
YELLOW = 2
BLUE = 3
GColor = Enum.from_enum(Color, description="original colors")
UniqueGColor = Enum.from_enum(
Color, name="UniqueColor", description="unique colors"
)
class Query(ObjectType):
color = GColor(required=True)
unique_color = UniqueGColor(required=True)
schema = Schema(query=Query)
assert (
str(schema).strip()
== dedent(
'''
type Query {
color: Color!
uniqueColor: UniqueColor!
}
"""original colors"""
enum Color {
RED
YELLOW
BLUE
}
"""unique colors"""
enum UniqueColor {
RED
YELLOW
BLUE
}
'''
).strip()
)
def test_enum_resolver_invalid():
from enum import Enum as PyEnum
class Color(PyEnum):
RED = 1
GREEN = 2
BLUE = 3
GColor = Enum.from_enum(Color)
class Query(ObjectType):
color = GColor(required=True)
def resolve_color(_, info):
return "BLACK"
schema = Schema(query=Query)
results = schema.execute("query { color }")
assert results.errors
assert results.errors[0].message == "Enum 'Color' cannot represent value: 'BLACK'"
def test_field_enum_argument():
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
class Brick(ObjectType):
color = Color(required=True)
color_filter = None
class Query(ObjectType):
bricks_by_color = Field(Brick, color=Color(required=True))
def resolve_bricks_by_color(_, info, color):
nonlocal color_filter
color_filter = color
return Brick(color=color)
schema = Schema(query=Query)
results = schema.execute(
"""
query {
bricksByColor(color: RED) {
color
}
}
"""
)
assert not results.errors
assert results.data == {"bricksByColor": {"color": "RED"}}
assert color_filter == Color.RED
def test_mutation_enum_input():
class RGB(Enum):
"""Available colors"""
RED = 1
GREEN = 2
BLUE = 3
color_input = None
class CreatePaint(Mutation):
class Arguments:
color = RGB(required=True)
color = RGB(required=True)
def mutate(_, info, color):
nonlocal color_input
color_input = color
return CreatePaint(color=color)
class MyMutation(ObjectType):
create_paint = CreatePaint.Field()
class Query(ObjectType):
a = String()
schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
""" mutation MyMutation {
createPaint(color: RED) {
color
}
}
"""
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}
assert color_input == RGB.RED
def test_mutation_enum_input_type():
class RGB(Enum):
"""Available colors"""
RED = 1
GREEN = 2
BLUE = 3
class ColorInput(InputObjectType):
color = RGB(required=True)
color_input_value = None
class CreatePaint(Mutation):
class Arguments:
color_input = ColorInput(required=True)
color = RGB(required=True)
def mutate(_, info, color_input):
nonlocal color_input_value
color_input_value = color_input.color
return CreatePaint(color=color_input.color)
class MyMutation(ObjectType):
create_paint = CreatePaint.Field()
class Query(ObjectType):
a = String()
schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(
"""
mutation MyMutation {
createPaint(colorInput: { color: RED }) {
color
}
}
"""
)
assert not result.errors
assert result.data == {"createPaint": {"color": "RED"}}
assert color_input_value == RGB.RED
@@ -25,13 +25,18 @@ def test_generate_interface():

def test_generate_interface_with_meta():
+    class MyFirstInterface(Interface):
+        pass
+
    class MyInterface(Interface):
        class Meta:
            name = "MyOtherInterface"
            description = "Documentation"
+            interfaces = [MyFirstInterface]

    assert MyInterface._meta.name == "MyOtherInterface"
    assert MyInterface._meta.description == "Documentation"
+    assert MyInterface._meta.interfaces == [MyFirstInterface]


def test_generate_interface_with_fields():
@@ -191,21 +191,15 @@ def test_objecttype_as_container_all_kwargs():

def test_objecttype_as_container_extra_args():
-    with raises(TypeError) as excinfo:
-        Container("1", "2", "3")
-
-    assert "__init__() takes from 1 to 3 positional arguments but 4 were given" == str(
-        excinfo.value
-    )
+    msg = r"__init__\(\) takes from 1 to 3 positional arguments but 4 were given"
+    with raises(TypeError, match=msg):
+        Container("1", "2", "3")  # type: ignore


def test_objecttype_as_container_invalid_kwargs():
-    with raises(TypeError) as excinfo:
-        Container(unexisting_field="3")
-
-    assert "__init__() got an unexpected keyword argument 'unexisting_field'" == str(
-        excinfo.value
-    )
+    msg = r"__init__\(\) got an unexpected keyword argument 'unexisting_field'"
+    with raises(TypeError, match=msg):
+        Container(unexisting_field="3")  # type: ignore


def test_objecttype_container_benchmark(benchmark):
@@ -295,3 +289,21 @@ def test_objecttype_meta_with_annotations():
    schema = Schema(query=Query)
    assert schema is not None


def test_objecttype_meta_arguments():
    class MyInterface(Interface):
        foo = String()

    class MyType(ObjectType, interfaces=[MyInterface]):
        bar = String()

    assert MyType._meta.interfaces == [MyInterface]
    assert list(MyType._meta.fields.keys()) == ["foo", "bar"]


def test_objecttype_type_name():
    class MyObjectType(ObjectType, name="FooType"):
        pass

    assert MyObjectType._meta.name == "FooType"
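
The class-keyword form exercised above appears to be sugar for the nested Meta class; a hedged equivalence sketch (both type names below are illustrative):

class FooTypeKeyword(ObjectType, name="FooType"):
    pass

class FooTypeMeta(ObjectType):
    class Meta:
        name = "FooType"

assert FooTypeKeyword._meta.name == FooTypeMeta._meta.name == "FooType"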
@@ -229,11 +229,11 @@ def test_query_arguments():
    result = test_schema.execute("{ test }", None)
    assert not result.errors
-    assert result.data == {"test": '[null,{"a_str":null,"a_int":null}]'}
+    assert result.data == {"test": "[null,{}]"}

    result = test_schema.execute('{ test(aStr: "String!") }', "Source!")
    assert not result.errors
-    assert result.data == {"test": '["Source!",{"a_str":"String!","a_int":null}]'}
+    assert result.data == {"test": '["Source!",{"a_str":"String!"}]'}

    result = test_schema.execute('{ test(aInt: -123, aStr: "String!") }', "Source!")
    assert not result.errors
@@ -258,7 +258,7 @@ def test_query_input_field():
    result = test_schema.execute("{ test }", None)
    assert not result.errors
-    assert result.data == {"test": '[null,{"a_input":null}]'}
+    assert result.data == {"test": "[null,{}]"}

    result = test_schema.execute('{ test(aInput: {aField: "String!"} ) }', "Source!")
    assert not result.errors
@@ -1,4 +1,5 @@
-from ..scalars import Scalar
+from ..scalars import Scalar, Int, BigInt
+from graphql.language.ast import IntValueNode


def test_scalar():
@@ -7,3 +8,22 @@ def test_scalar():
    assert JSONScalar._meta.name == "JSONScalar"
    assert JSONScalar._meta.description == "Documentation"


def test_ints():
    assert Int.parse_value(2**31 - 1) is not None
    assert Int.parse_value("2.0") is not None
    assert Int.parse_value(2**31) is None

    assert Int.parse_literal(IntValueNode(value=str(2**31 - 1))) == 2**31 - 1
    assert Int.parse_literal(IntValueNode(value=str(2**31))) is None

    assert Int.parse_value(-(2**31)) is not None
    assert Int.parse_value(-(2**31) - 1) is None

    assert BigInt.parse_value(2**31) is not None
    assert BigInt.parse_value("2.0") is not None
    assert BigInt.parse_value(-(2**31) - 1) is not None

    assert BigInt.parse_literal(IntValueNode(value=str(2**31 - 1))) == 2**31 - 1
    assert BigInt.parse_literal(IntValueNode(value=str(2**31))) == 2**31
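
In practice these bounds mean Int rejects values outside GraphQL's 32-bit range while BigInt passes them through. A small sketch (the import path for BigInt is assumed from the test above):

from graphene import ObjectType, Schema
from graphene.types.scalars import BigInt

class Query(ObjectType):
    big = BigInt()

    def resolve_big(root, info):
        return 1 << 40  # far outside the 32-bit Int range

schema = Schema(query=Query)
result = schema.execute("{ big }")
assert not result.errors
assert result.data == {"big": 1 << 40}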
@@ -38,7 +38,7 @@ def test_serializes_output_string():
    assert String.serialize(-1.1) == "-1.1"
    assert String.serialize(True) == "true"
    assert String.serialize(False) == "false"
-    assert String.serialize(u"\U0001F601") == u"\U0001F601"
+    assert String.serialize("\U0001F601") == "\U0001F601"


def test_serializes_output_boolean():
@@ -1,7 +1,8 @@
+from textwrap import dedent
+
from pytest import raises

from graphql.type import GraphQLObjectType, GraphQLSchema
-from graphql.pyutils import dedent

from ..field import Field
from ..objecttype import ObjectType
@@ -43,7 +44,9 @@ def test_schema_get_type_error():

def test_schema_str():
    schema = Schema(Query)
-    assert str(schema) == dedent(
+    assert (
+        str(schema).strip()
+        == dedent(
        """
        type Query {
          inner: MyOtherType
@@ -53,9 +56,19 @@ def test_schema_str():
          field: String
        }
        """
+        ).strip()
    )


def test_schema_introspect():
    schema = Schema(Query)
    assert "__schema" in schema.introspect()


def test_schema_requires_query_type():
    schema = Schema()
    result = schema.execute("query {}")

    assert len(result.errors) == 1
    error = result.errors[0]
    assert error.message == "Query root type must be provided."
@@ -0,0 +1,78 @@
from pytest import mark
from graphene import ObjectType, Int, String, Schema, Field
class Query(ObjectType):
hello = String()
def resolve_hello(root, info):
return "Hello, world!"
class Subscription(ObjectType):
count_to_ten = Field(Int)
async def subscribe_count_to_ten(root, info):
for count in range(1, 11):
yield count
schema = Schema(query=Query, subscription=Subscription)
@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToTen"]
assert count == 10
@mark.asyncio
async def test_subscription_fails_with_invalid_query():
# It fails if the provided query is invalid
subscription = "subscription { "
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])
@mark.asyncio
async def test_subscription_fails_when_query_is_not_valid():
# It can't subscribe to two fields at the same time, triggering a
# validation error.
subscription = "subscription { countToTen, b: countToTen }"
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Anonymous Subscription must select only one top level field." in str(
result.errors[0]
)
@mark.asyncio
async def test_subscription_with_args():
class Query(ObjectType):
hello = String()
class Subscription(ObjectType):
count_upwards = Field(Int, limit=Int(required=True))
async def subscribe_count_upwards(root, info, limit):
count = 0
while count < limit:
count += 1
yield count
schema = Schema(query=Query, subscription=Subscription)
subscription = "subscription { countUpwards(limit: 5) }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countUpwards"]
assert count == 5
@@ -1,3 +1,4 @@
+from graphql import Undefined
from graphql.type import (
    GraphQLArgument,
    GraphQLEnumType,
@@ -6,6 +7,7 @@ from graphql.type import (
    GraphQLInputField,
    GraphQLInputObjectType,
    GraphQLInterfaceType,
+    GraphQLNonNull,
    GraphQLObjectType,
    GraphQLString,
)
@@ -94,6 +96,21 @@ def test_objecttype():
    }


def test_required_argument_with_default_value():
    class MyObjectType(ObjectType):
        foo = String(bar=String(required=True, default_value="x"))

    type_map = create_type_map([MyObjectType])

    graphql_type = type_map["MyObjectType"]
    foo_field = graphql_type.fields["foo"]

    bar_argument = foo_field.args["bar"]
    assert bar_argument.default_value == "x"
    assert isinstance(bar_argument.type, GraphQLNonNull)
    assert bar_argument.type.of_type == GraphQLString
def test_dynamic_objecttype():
    class MyObjectType(ObjectType):
        """Description"""
@@ -228,7 +245,9 @@ def test_objecttype_camelcase():
    foo_field = fields["fooBar"]
    assert isinstance(foo_field, GraphQLField)
    assert foo_field.args == {
-        "barFoo": GraphQLArgument(GraphQLString, default_value=None, out_name="bar_foo")
+        "barFoo": GraphQLArgument(
+            GraphQLString, default_value=Undefined, out_name="bar_foo"
+        )
    }
@@ -251,7 +270,7 @@ def test_objecttype_camelcase_disabled():
    assert isinstance(foo_field, GraphQLField)
    assert foo_field.args == {
        "bar_foo": GraphQLArgument(
-            GraphQLString, default_value=None, out_name="bar_foo"
+            GraphQLString, default_value=Undefined, out_name="bar_foo"
        )
    }
@@ -270,3 +289,33 @@ def test_objecttype_with_possible_types():
    assert graphql_type.is_type_of
    assert graphql_type.is_type_of({}, None) is True
    assert graphql_type.is_type_of(MyObjectType(), None) is False


def test_interface_with_interfaces():
    class FooInterface(Interface):
        foo = String()

    class BarInterface(Interface):
        class Meta:
            interfaces = [FooInterface]

        foo = String()
        bar = String()

    type_map = create_type_map([FooInterface, BarInterface])

    assert "FooInterface" in type_map
    foo_graphql_type = type_map["FooInterface"]
    assert isinstance(foo_graphql_type, GraphQLInterfaceType)
    assert foo_graphql_type.name == "FooInterface"

    assert "BarInterface" in type_map
    bar_graphql_type = type_map["BarInterface"]
    assert isinstance(bar_graphql_type, GraphQLInterfaceType)
    assert bar_graphql_type.name == "BarInterface"

    fields = bar_graphql_type.fields
    assert list(fields) == ["foo", "bar"]
    assert isinstance(fields["foo"], GraphQLField)
    assert isinstance(fields["bar"], GraphQLField)

    assert list(bar_graphql_type.interfaces) == [foo_graphql_type]
@@ -21,7 +21,7 @@ class Union(UnmountedType, BaseType):
    to determine which type is actually used when the field is resolved.

    The schema in this example can take a search text and return any of the GraphQL object types
-    indicated: Human, Droid or Startship.
+    indicated: Human, Droid or Starship.

    Ambiguous return types can be resolved on each ObjectType through ``Meta.possible_types``
    attribute or ``is_type_of`` method. Or by implementing ``resolve_type`` class method on the
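
The scenario the docstring describes, sketched concretely (Human, Droid and Starship are the usual illustrative types, not defined in this diff):

from graphene import ObjectType, String, Union

class Human(ObjectType):
    name = String()

class Droid(ObjectType):
    primary_function = String()

class Starship(ObjectType):
    length = String()

class SearchResult(Union):
    class Meta:
        types = (Human, Droid, Starship)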
@@ -41,3 +41,10 @@ def get_type(_type):
    if inspect.isfunction(_type) or isinstance(_type, partial):
        return _type()
    return _type


def get_underlying_type(_type):
    """Get the underlying type even if it is wrapped in structures like NonNull"""
    while hasattr(_type, "of_type"):
        _type = _type.of_type
    return _type
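
A quick check of the unwrapping behaviour against graphql-core's wrapper types:

from graphql import GraphQLList, GraphQLNonNull, GraphQLString

wrapped = GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLString)))
assert get_underlying_type(wrapped) is GraphQLString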
@@ -21,7 +21,7 @@ class UUID(Scalar):
        return str(uuid)

    @staticmethod
-    def parse_literal(node):
+    def parse_literal(node, _variables=None):
        if isinstance(node, StringValueNode):
            return _UUID(node.value)
@@ -0,0 +1,6 @@
def is_introspection_key(key):
    # from: https://spec.graphql.org/June2018/#sec-Schema
    # > All types and directives defined within a schema must not have a name which
    # > begins with "__" (two underscores), as this is used exclusively
    # > by GraphQL's introspection system.
    return str(key).startswith("__")
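
For example:

assert is_introspection_key("__schema")
assert is_introspection_key("__typename")
assert not is_introspection_key("users")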
@@ -27,7 +27,6 @@ def import_string(dotted_path, dotted_attributes=None):
    if not dotted_attributes:
        return result
-    else:
    attributes = dotted_attributes.split(".")
    traveled_attributes = []
    try:
@@ -36,4 +36,4 @@ class OrderedType:
        return NotImplemented

    def __hash__(self):
-        return hash((self.creation_counter))
+        return hash(self.creation_counter)
@@ -1,5 +1,4 @@
import re

-from unidecode import unidecode


# Adapted from this response in Stackoverflow
@@ -16,7 +15,3 @@ def to_camel_case(snake_str):

def to_snake_case(name):
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
-
-
-def to_const(string):
-    return re.sub(r"[\W|^]+", "_", unidecode(string)).upper()
@@ -94,6 +94,7 @@ TEST_DATA = {
    ],
    "movies": {
        "1198359": {
+            "id": "1198359",
            "name": "King Arthur: Legend of the Sword",
            "synopsis": (
                "When the child Arthur's father is murdered, Vortigern, "
@@ -159,7 +160,7 @@ def test_example_end_to_end():
                    "date": "2017-05-19",
                    "movie": {
                        "__typename": "Movie",
-                        "id": "TW92aWU6Tm9uZQ==",
+                        "id": "TW92aWU6MTE5ODM1OQ==",
                        "name": "King Arthur: Legend of the Sword",
                        "synopsis": (
                            "When the child Arthur's father is murdered, Vortigern, "
@@ -172,7 +173,7 @@ def test_example_end_to_end():
                    "__typename": "Event",
                    "id": "RXZlbnQ6MjM0",
                    "date": "2017-05-20",
-                    "movie": {"__typename": "Movie", "id": "TW92aWU6Tm9uZQ=="},
+                    "movie": {"__typename": "Movie", "id": "TW92aWU6MTE5ODM1OQ=="},
                },
            ]
        }
@@ -38,4 +38,4 @@ def test_orderedtype_non_orderabletypes():
    assert one.__lt__(1) == NotImplemented
    assert one.__gt__(1) == NotImplemented

-    assert not one == 1
+    assert one != 1
@@ -1,5 +1,5 @@
# coding: utf-8
-from ..str_converters import to_camel_case, to_const, to_snake_case
+from ..str_converters import to_camel_case, to_snake_case


def test_snake_case():
@@ -17,11 +17,3 @@ def test_camel_case():
    assert to_camel_case("snakes_on_a__plane") == "snakesOnA_Plane"
    assert to_camel_case("i_phone_hysteria") == "iPhoneHysteria"
    assert to_camel_case("field_i18n") == "fieldI18n"
-
-
-def test_to_const():
-    assert to_const('snakes $1. on a "#plane') == "SNAKES_1_ON_A_PLANE"
-
-
-def test_to_const_unicode():
-    assert to_const("Skoða þetta unicode stöff") == "SKODA_THETTA_UNICODE_STOFF"
@@ -0,0 +1,5 @@
from .depth_limit import depth_limit_validator
from .disable_introspection import DisableIntrospection

__all__ = ["DisableIntrospection", "depth_limit_validator"]
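
A hedged sketch of how both rules plug into graphql-core's validate() (the Query type here is illustrative; the real test setups appear further below):

from graphql import parse, validate
from graphene import ObjectType, Schema, String
from graphene.validation import DisableIntrospection, depth_limit_validator

class Query(ObjectType):
    name = String()

schema = Schema(query=Query)
document = parse("{ __schema { queryType { name } } }")
errors = validate(
    schema.graphql_schema,
    document,
    rules=(DisableIntrospection, depth_limit_validator(max_depth=10)),
)
print([error.message for error in errors])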
@@ -0,0 +1,195 @@
# This is a Python port of https://github.com/stems/graphql-depth-limit
# which is licensed under the terms of the MIT license, reproduced below.
#
# -----------
#
# MIT License
#
# Copyright (c) 2017 Stem
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
try:
from re import Pattern
except ImportError:
# backwards compatibility for v3.6
from typing import Pattern
from typing import Callable, Dict, List, Optional, Union
from graphql import GraphQLError
from graphql.validation import ValidationContext, ValidationRule
from graphql.language import (
DefinitionNode,
FieldNode,
FragmentDefinitionNode,
FragmentSpreadNode,
InlineFragmentNode,
Node,
OperationDefinitionNode,
)
from ..utils.is_introspection_key import is_introspection_key
IgnoreType = Union[Callable[[str], bool], Pattern, str]
def depth_limit_validator(
max_depth: int,
ignore: Optional[List[IgnoreType]] = None,
callback: Optional[Callable[[Dict[str, int]], None]] = None,
):
class DepthLimitValidator(ValidationRule):
def __init__(self, validation_context: ValidationContext):
document = validation_context.document
definitions = document.definitions
fragments = get_fragments(definitions)
queries = get_queries_and_mutations(definitions)
query_depths = {}
for name in queries:
query_depths[name] = determine_depth(
node=queries[name],
fragments=fragments,
depth_so_far=0,
max_depth=max_depth,
context=validation_context,
operation_name=name,
ignore=ignore,
)
if callable(callback):
callback(query_depths)
super().__init__(validation_context)
return DepthLimitValidator
def get_fragments(
definitions: List[DefinitionNode],
) -> Dict[str, FragmentDefinitionNode]:
fragments = {}
for definition in definitions:
if isinstance(definition, FragmentDefinitionNode):
fragments[definition.name.value] = definition
return fragments
# This will actually get both queries and mutations.
# We can basically treat those the same
def get_queries_and_mutations(
definitions: List[DefinitionNode],
) -> Dict[str, OperationDefinitionNode]:
operations = {}
for definition in definitions:
if isinstance(definition, OperationDefinitionNode):
operation = definition.name.value if definition.name else "anonymous"
operations[operation] = definition
return operations
def determine_depth(
node: Node,
fragments: Dict[str, FragmentDefinitionNode],
depth_so_far: int,
max_depth: int,
context: ValidationContext,
operation_name: str,
ignore: Optional[List[IgnoreType]] = None,
) -> int:
if depth_so_far > max_depth:
context.report_error(
GraphQLError(
f"'{operation_name}' exceeds maximum operation depth of {max_depth}.",
[node],
)
)
return depth_so_far
if isinstance(node, FieldNode):
should_ignore = is_introspection_key(node.name.value) or is_ignored(
node, ignore
)
if should_ignore or not node.selection_set:
return 0
return 1 + max(
map(
lambda selection: determine_depth(
node=selection,
fragments=fragments,
depth_so_far=depth_so_far + 1,
max_depth=max_depth,
context=context,
operation_name=operation_name,
ignore=ignore,
),
node.selection_set.selections,
)
)
elif isinstance(node, FragmentSpreadNode):
return determine_depth(
node=fragments[node.name.value],
fragments=fragments,
depth_so_far=depth_so_far,
max_depth=max_depth,
context=context,
operation_name=operation_name,
ignore=ignore,
)
elif isinstance(
node, (InlineFragmentNode, FragmentDefinitionNode, OperationDefinitionNode)
):
return max(
map(
lambda selection: determine_depth(
node=selection,
fragments=fragments,
depth_so_far=depth_so_far,
max_depth=max_depth,
context=context,
operation_name=operation_name,
ignore=ignore,
),
node.selection_set.selections,
)
)
else:
raise Exception(
f"Depth crawler cannot handle: {node.kind}."
) # pragma: no cover
def is_ignored(node: FieldNode, ignore: Optional[List[IgnoreType]] = None) -> bool:
if ignore is None:
return False
for rule in ignore:
field_name = node.name.value
if isinstance(rule, str):
if field_name == rule:
return True
elif isinstance(rule, Pattern):
if rule.match(field_name):
return True
elif callable(rule):
if rule(field_name):
return True
else:
raise ValueError(f"Invalid ignore option: {rule}.")
return False
@@ -0,0 +1,16 @@
from graphql import GraphQLError
from graphql.language import FieldNode
from graphql.validation import ValidationRule

from ..utils.is_introspection_key import is_introspection_key


class DisableIntrospection(ValidationRule):
    def enter_field(self, node: FieldNode, *_args):
        field_name = node.name.value
        if is_introspection_key(field_name):
            self.report_error(
                GraphQLError(
                    f"Cannot query '{field_name}': introspection is disabled.", node
                )
            )
@@ -0,0 +1,254 @@
import re
from pytest import raises
from graphql import parse, get_introspection_query, validate
from ...types import Schema, ObjectType, Interface
from ...types import String, Int, List, Field
from ..depth_limit import depth_limit_validator
class PetType(Interface):
name = String(required=True)
class meta:
name = "Pet"
class CatType(ObjectType):
class meta:
name = "Cat"
interfaces = (PetType,)
class DogType(ObjectType):
class meta:
name = "Dog"
interfaces = (PetType,)
class AddressType(ObjectType):
street = String(required=True)
number = Int(required=True)
city = String(required=True)
country = String(required=True)
class Meta:
name = "Address"
class HumanType(ObjectType):
name = String(required=True)
email = String(required=True)
address = Field(AddressType, required=True)
pets = List(PetType, required=True)
class Meta:
name = "Human"
class Query(ObjectType):
user = Field(HumanType, required=True, name=String())
version = String(required=True)
user1 = Field(HumanType, required=True)
user2 = Field(HumanType, required=True)
user3 = Field(HumanType, required=True)
@staticmethod
def resolve_user(root, info, name=None):
pass
schema = Schema(query=Query)
def run_query(query: str, max_depth: int, ignore=None):
document = parse(query)
result = None
def callback(query_depths):
nonlocal result
result = query_depths
errors = validate(
schema=schema.graphql_schema,
document_ast=document,
rules=(
depth_limit_validator(
max_depth=max_depth, ignore=ignore, callback=callback
),
),
)
return errors, result
def test_should_count_depth_without_fragment():
query = """
query read0 {
version
}
query read1 {
version
user {
name
}
}
query read2 {
matt: user(name: "matt") {
email
}
andy: user(name: "andy") {
email
address {
city
}
}
}
query read3 {
matt: user(name: "matt") {
email
}
andy: user(name: "andy") {
email
address {
city
}
pets {
name
owner {
name
}
}
}
}
"""
expected = {"read0": 0, "read1": 1, "read2": 2, "read3": 3}
errors, result = run_query(query, 10)
assert not errors
assert result == expected
def test_should_count_with_fragments():
query = """
query read0 {
... on Query {
version
}
}
query read1 {
version
user {
... on Human {
name
}
}
}
fragment humanInfo on Human {
email
}
fragment petInfo on Pet {
name
owner {
name
}
}
query read2 {
matt: user(name: "matt") {
...humanInfo
}
andy: user(name: "andy") {
...humanInfo
address {
city
}
}
}
query read3 {
matt: user(name: "matt") {
...humanInfo
}
andy: user(name: "andy") {
... on Human {
email
}
address {
city
}
pets {
...petInfo
}
}
}
"""
expected = {"read0": 0, "read1": 1, "read2": 2, "read3": 3}
errors, result = run_query(query, 10)
assert not errors
assert result == expected
def test_should_ignore_the_introspection_query():
errors, result = run_query(get_introspection_query(), 10)
assert not errors
assert result == {"IntrospectionQuery": 0}
def test_should_catch_very_deep_query():
query = """{
user {
pets {
owner {
pets {
owner {
pets {
name
}
}
}
}
}
}
}
"""
errors, result = run_query(query, 4)
assert len(errors) == 1
assert errors[0].message == "'anonymous' exceeds maximum operation depth of 4."
def test_should_ignore_field():
query = """
query read1 {
user { address { city } }
}
query read2 {
user1 { address { city } }
user2 { address { city } }
user3 { address { city } }
}
"""
errors, result = run_query(
query,
10,
ignore=["user1", re.compile("user2"), lambda field_name: field_name == "user3"],
)
expected = {"read1": 2, "read2": 0}
assert not errors
assert result == expected
def test_should_raise_invalid_ignore():
query = """
query read1 {
user { address { city } }
}
"""
with raises(ValueError, match="Invalid ignore option:"):
run_query(query, 10, ignore=[True])
@@ -0,0 +1,37 @@
from graphql import parse, validate

from ...types import Schema, ObjectType, String
from ..disable_introspection import DisableIntrospection


class Query(ObjectType):
    name = String(required=True)

    @staticmethod
    def resolve_name(root, info):
        return "Hello world!"


schema = Schema(query=Query)


def run_query(query: str):
    document = parse(query)

    return validate(
        schema=schema.graphql_schema,
        document_ast=document,
        rules=(DisableIntrospection,),
    )


def test_disallows_introspection_queries():
    errors = run_query("{ __schema { queryType { name } } }")

    assert len(errors) == 1
    assert errors[0].message == "Cannot query '__schema': introspection is disabled."


def test_allows_non_introspection_queries():
    errors = run_query("{ name }")
    assert len(errors) == 0
Some files were not shown because too many files have changed in this diff.