Merge branch 'master' into sdr/subclass_mutations

This commit is contained in:
Syrus Akbary 2018-07-09 18:49:07 -07:00 committed by GitHub
commit 319605bfaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
110 changed files with 2640 additions and 2468 deletions

View File

@ -2,30 +2,23 @@ repos:
- repo: git://github.com/pre-commit/pre-commit-hooks - repo: git://github.com/pre-commit/pre-commit-hooks
rev: v1.3.0 rev: v1.3.0
hooks: hooks:
- id: autopep8-wrapper
args:
- -i
- --ignore=E128,E309,E501
exclude: ^docs/.*$
- id: check-json - id: check-json
- id: check-yaml - id: check-yaml
- id: debug-statements - id: debug-statements
- id: end-of-file-fixer - id: end-of-file-fixer
exclude: ^docs/.*$ exclude: ^docs/.*$
- id: trailing-whitespace - id: trailing-whitespace
exclude: README.md
- id: pretty-format-json - id: pretty-format-json
args: args:
- --autofix - --autofix
- id: flake8 - id: flake8
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v1.2.0 rev: v1.4.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
- repo: https://github.com/asottile/seed-isort-config - repo: https://github.com/ambv/black
rev: v1.0.0 rev: 18.6b4
hooks: hooks:
- id: seed-isort-config - id: black
- repo: https://github.com/pre-commit/mirrors-isort language_version: python3.6
rev: v4.3.4
hooks:
- id: isort

View File

@ -1,64 +1,28 @@
language: python language: python
sudo: false
python:
- 2.7
- 3.5
- 3.6
# - "pypy-5.3.1"
before_install:
- |
if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then
export PYENV_ROOT="$HOME/.pyenv"
if [ -f "$PYENV_ROOT/bin/pyenv" ]; then
cd "$PYENV_ROOT" && git pull
else
rm -rf "$PYENV_ROOT" && git clone --depth 1 https://github.com/yyuu/pyenv.git "$PYENV_ROOT"
fi
export PYPY_VERSION="4.0.1"
"$PYENV_ROOT/bin/pyenv" install "pypy-$PYPY_VERSION"
virtualenv --python="$PYENV_ROOT/versions/pypy-$PYPY_VERSION/bin/python" "$HOME/virtualenvs/pypy-$PYPY_VERSION"
source "$HOME/virtualenvs/pypy-$PYPY_VERSION/bin/activate"
fi
install:
- |
if [ "$TEST_TYPE" = build ]; then
pip install -e .[test]
python setup.py develop
elif [ "$TEST_TYPE" = lint ]; then
pip install flake8
elif [ "$TEST_TYPE" = mypy ]; then
pip install mypy
fi
script:
- |
if [ "$TEST_TYPE" = lint ]; then
echo "Checking Python code lint."
flake8 graphene
exit
elif [ "$TEST_TYPE" = mypy ]; then
echo "Checking Python types."
mypy graphene
exit
elif [ "$TEST_TYPE" = build ]; then
py.test --cov=graphene graphene examples
fi
after_success:
- |
if [ "$TEST_TYPE" = build ]; then
coveralls
fi
env:
matrix: matrix:
- TEST_TYPE=build
global:
secure: SQC0eCWCWw8bZxbLE8vQn+UjJOp3Z1m779s9SMK3lCLwJxro/VCLBZ7hj4xsrq1MtcFO2U2Kqf068symw4Hr/0amYI3HFTCFiwXAC3PAKXeURca03eNO2heku+FtnQcOjBanExTsIBQRLDXMOaUkf3MIztpLJ4LHqMfUupKmw9YSB0v40jDbSN8khBnndFykmOnVVHznFp8USoN5F0CiPpnfEvHnJkaX76lNf7Kc9XNShBTTtJsnsHMhuYQeInt0vg9HSjoIYC38Tv2hmMj1myNdzyrHF+LgRjI6ceGi50ApAnGepXC/DNRhXROfECKez+LON/ZSqBGdJhUILqC8A4WmWmIjNcwitVFp3JGBqO7LULS0BI96EtSLe8rD1rkkdTbjivajkbykM1Q0Tnmg1adzGwLxRUbTq9tJQlTTkHBCuXIkpKb1mAtb/TY7A6BqfnPi2xTc/++qEawUG7ePhscdTj0IBrUfZsUNUYZqD8E8XbSWKIuS3SHE+cZ+s/kdAsm4q+FFAlpZKOYGxIkwvgyfu4/Plfol4b7X6iAP9J3r1Kv0DgBVFst5CXEwzZs19/g0CgokQbCXf1N+xeNnUELl6/fImaR3RKP22EaABoil4z8vzl4EqxqVoH1nfhE+WlpryXsuSaF/1R+WklR7aQ1FwoCk8V8HxM2zrj4tI8k=
matrix:
fast_finish: true
include: include:
- python: '2.7' - env: TOXENV=py27
env: TEST_TYPE=lint python: 2.7
- python: '3.6' - env: TOXENV=py34
env: TEST_TYPE=mypy python: 3.4
- env: TOXENV=py35
python: 3.5
- env: TOXENV=py36
python: 3.6
- env: TOXENV=pypy
python: pypy-5.7.1
- env: TOXENV=pre-commit
python: 3.6
- env: TOXENV=mypy
python: 3.6
install:
- pip install coveralls tox
script: tox
after_success: coveralls
cache:
directories:
- $HOME/.cache/pip
- $HOME/.cache/pre-commit
deploy: deploy:
provider: pypi provider: pypi
user: syrusakbary user: syrusakbary

View File

@ -2,7 +2,7 @@ import os
import sphinx_graphene_theme import sphinx_graphene_theme
on_rtd = os.environ.get('READTHEDOCS', None) == 'True' on_rtd = os.environ.get("READTHEDOCS", None) == "True"
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
@ -36,46 +36,44 @@ on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx.ext.todo', "sphinx.ext.todo",
'sphinx.ext.coverage', "sphinx.ext.coverage",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
] ]
if not on_rtd: if not on_rtd:
extensions += [ extensions += ["sphinx.ext.githubpages"]
'sphinx.ext.githubpages',
]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix(es) of source filenames. # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # You can specify multiple suffix as a list of string:
# #
# source_suffix = ['.rst', '.md'] # source_suffix = ['.rst', '.md']
source_suffix = '.rst' source_suffix = ".rst"
# The encoding of source files. # The encoding of source files.
# #
# source_encoding = 'utf-8-sig' # source_encoding = 'utf-8-sig'
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = u'Graphene' project = u"Graphene"
copyright = u'Graphene 2016' copyright = u"Graphene 2016"
author = u'Syrus Akbary' author = u"Syrus Akbary"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = u'1.0' version = u"1.0"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = u'1.0' release = u"1.0"
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
@ -96,7 +94,7 @@ language = None
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path # This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The reST default role (used for this markup: `text`) to use for all # The reST default role (used for this markup: `text`) to use for all
# documents. # documents.
@ -118,7 +116,7 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# show_authors = False # show_authors = False
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting. # A list of ignored prefixes for module index sorting.
# modindex_common_prefix = [] # modindex_common_prefix = []
@ -175,7 +173,7 @@ html_theme_path = [sphinx_graphene_theme.get_html_theme_path()]
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or # Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied # .htaccess) here, relative to this directory. These files are copied
@ -255,7 +253,7 @@ html_static_path = ['_static']
# html_search_scorer = 'scorer.js' # html_search_scorer = 'scorer.js'
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'Graphenedoc' htmlhelp_basename = "Graphenedoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
@ -263,15 +261,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
# #
# 'papersize': 'letterpaper', # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt'). # The font size ('10pt', '11pt' or '12pt').
# #
# 'pointsize': '10pt', # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# 'preamble': '', # 'preamble': '',
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# 'figure_align': 'htbp', # 'figure_align': 'htbp',
@ -281,8 +276,7 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, 'Graphene.tex', u'Graphene Documentation', (master_doc, "Graphene.tex", u"Graphene Documentation", u"Syrus Akbary", "manual")
u'Syrus Akbary', 'manual'),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
@ -322,10 +316,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [(master_doc, "graphene", u"Graphene Documentation", [author], 1)]
(master_doc, 'graphene', u'Graphene Documentation',
[author], 1)
]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
# #
@ -338,9 +329,15 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, 'Graphene', u'Graphene Documentation', (
author, 'Graphene', 'One line description of project.', master_doc,
'Miscellaneous'), "Graphene",
u"Graphene Documentation",
author,
"Graphene",
"One line description of project.",
"Miscellaneous",
)
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
@ -414,7 +411,7 @@ epub_copyright = copyright
# epub_post_files = [] # epub_post_files = []
# A list of files that should not be packed into the epub file. # A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html'] epub_exclude_files = ["search.html"]
# The depth of the table of contents in toc.ncx. # The depth of the table of contents in toc.ncx.
# #
@ -447,9 +444,15 @@ epub_exclude_files = ['search.html']
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = { intersphinx_mapping = {
'https://docs.python.org/': None, "https://docs.python.org/": None,
'python': ('https://docs.python.org/', None), "python": ("https://docs.python.org/", None),
'graphene_django': ('http://docs.graphene-python.org/projects/django/en/latest/', None), "graphene_django": (
'graphene_sqlalchemy': ('http://docs.graphene-python.org/projects/sqlalchemy/en/latest/', None), "http://docs.graphene-python.org/projects/django/en/latest/",
'graphene_gae': ('http://docs.graphene-python.org/projects/gae/en/latest/', None), None,
),
"graphene_sqlalchemy": (
"http://docs.graphene-python.org/projects/sqlalchemy/en/latest/",
None,
),
"graphene_gae": ("http://docs.graphene-python.org/projects/gae/en/latest/", None),
} }

View File

@ -22,7 +22,6 @@ class Query(graphene.ObjectType):
class CreateAddress(graphene.Mutation): class CreateAddress(graphene.Mutation):
class Arguments: class Arguments:
geo = GeoInput(required=True) geo = GeoInput(required=True)
@ -37,42 +36,34 @@ class Mutation(graphene.ObjectType):
schema = graphene.Schema(query=Query, mutation=Mutation) schema = graphene.Schema(query=Query, mutation=Mutation)
query = ''' query = """
query something{ query something{
address(geo: {lat:32.2, lng:12}) { address(geo: {lat:32.2, lng:12}) {
latlng latlng
} }
} }
''' """
mutation = ''' mutation = """
mutation addAddress{ mutation addAddress{
createAddress(geo: {lat:32.2, lng:12}) { createAddress(geo: {lat:32.2, lng:12}) {
latlng latlng
} }
} }
''' """
def test_query(): def test_query():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"address": {"latlng": "(32.2,12.0)"}}
'address': {
'latlng': "(32.2,12.0)",
}
}
def test_mutation(): def test_mutation():
result = schema.execute(mutation) result = schema.execute(mutation)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"createAddress": {"latlng": "(32.2,12.0)"}}
'createAddress': {
'latlng': "(32.2,12.0)",
}
}
if __name__ == '__main__': if __name__ == "__main__":
result = schema.execute(query) result = schema.execute(query)
print(result.data['address']['latlng']) print(result.data["address"]["latlng"])

View File

@ -10,31 +10,26 @@ class Query(graphene.ObjectType):
me = graphene.Field(User) me = graphene.Field(User)
def resolve_me(self, info): def resolve_me(self, info):
return info.context['user'] return info.context["user"]
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
query = ''' query = """
query something{ query something{
me { me {
id id
name name
} }
} }
''' """
def test_query(): def test_query():
result = schema.execute(query, context_value={'user': User(id='1', name='Syrus')}) result = schema.execute(query, context_value={"user": User(id="1", name="Syrus")})
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"me": {"id": "1", "name": "Syrus"}}
'me': {
'id': '1',
'name': 'Syrus',
}
}
if __name__ == '__main__': if __name__ == "__main__":
result = schema.execute(query, context_value={'user': User(id='X', name='Console')}) result = schema.execute(query, context_value={"user": User(id="X", name="Console")})
print(result.data['me']) print(result.data["me"])

View File

@ -12,11 +12,11 @@ class Query(graphene.ObjectType):
patron = graphene.Field(Patron) patron = graphene.Field(Patron)
def resolve_patron(self, info): def resolve_patron(self, info):
return Patron(id=1, name='Syrus', age=27) return Patron(id=1, name="Syrus", age=27)
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
query = ''' query = """
query something{ query something{
patron { patron {
id id
@ -24,21 +24,15 @@ query = '''
age age
} }
} }
''' """
def test_query(): def test_query():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"patron": {"id": "1", "name": "Syrus", "age": 27}}
'patron': {
'id': '1',
'name': 'Syrus',
'age': 27,
}
}
if __name__ == '__main__': if __name__ == "__main__":
result = schema.execute(query) result = schema.execute(query)
print(result.data['patron']) print(result.data["patron"])

View File

@ -4,75 +4,73 @@ droid_data = {}
def setup(): def setup():
from .schema import Human, Droid from .schema import Human, Droid
global human_data, droid_data global human_data, droid_data
luke = Human( luke = Human(
id='1000', id="1000",
name='Luke Skywalker', name="Luke Skywalker",
friends=['1002', '1003', '2000', '2001'], friends=["1002", "1003", "2000", "2001"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
home_planet='Tatooine', home_planet="Tatooine",
) )
vader = Human( vader = Human(
id='1001', id="1001",
name='Darth Vader', name="Darth Vader",
friends=['1004'], friends=["1004"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
home_planet='Tatooine', home_planet="Tatooine",
) )
han = Human( han = Human(
id='1002', id="1002",
name='Han Solo', name="Han Solo",
friends=['1000', '1003', '2001'], friends=["1000", "1003", "2001"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
home_planet=None, home_planet=None,
) )
leia = Human( leia = Human(
id='1003', id="1003",
name='Leia Organa', name="Leia Organa",
friends=['1000', '1002', '2000', '2001'], friends=["1000", "1002", "2000", "2001"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
home_planet='Alderaan', home_planet="Alderaan",
) )
tarkin = Human( tarkin = Human(
id='1004', id="1004",
name='Wilhuff Tarkin', name="Wilhuff Tarkin",
friends=['1001'], friends=["1001"],
appears_in=[4], appears_in=[4],
home_planet=None, home_planet=None,
) )
human_data = { human_data = {
'1000': luke, "1000": luke,
'1001': vader, "1001": vader,
'1002': han, "1002": han,
'1003': leia, "1003": leia,
'1004': tarkin, "1004": tarkin,
} }
c3po = Droid( c3po = Droid(
id='2000', id="2000",
name='C-3PO', name="C-3PO",
friends=['1000', '1002', '1003', '2001'], friends=["1000", "1002", "1003", "2001"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
primary_function='Protocol', primary_function="Protocol",
) )
r2d2 = Droid( r2d2 = Droid(
id='2001', id="2001",
name='R2-D2', name="R2-D2",
friends=['1000', '1002', '1003'], friends=["1000", "1002", "1003"],
appears_in=[4, 5, 6], appears_in=[4, 5, 6],
primary_function='Astromech', primary_function="Astromech",
) )
droid_data = { droid_data = {"2000": c3po, "2001": r2d2}
'2000': c3po,
'2001': r2d2,
}
def get_character(id): def get_character(id):
@ -85,8 +83,8 @@ def get_friends(character):
def get_hero(episode): def get_hero(episode):
if episode == 5: if episode == 5:
return human_data['1000'] return human_data["1000"]
return droid_data['2001'] return droid_data["2001"]
def get_human(id): def get_human(id):

View File

@ -21,29 +21,23 @@ class Character(graphene.Interface):
class Human(graphene.ObjectType): class Human(graphene.ObjectType):
class Meta: class Meta:
interfaces = (Character,) interfaces = (Character,)
home_planet = graphene.String() home_planet = graphene.String()
class Droid(graphene.ObjectType): class Droid(graphene.ObjectType):
class Meta: class Meta:
interfaces = (Character,) interfaces = (Character,)
primary_function = graphene.String() primary_function = graphene.String()
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
hero = graphene.Field(Character, hero = graphene.Field(Character, episode=Episode())
episode=Episode() human = graphene.Field(Human, id=graphene.String())
) droid = graphene.Field(Droid, id=graphene.String())
human = graphene.Field(Human,
id=graphene.String()
)
droid = graphene.Field(Droid,
id=graphene.String()
)
def resolve_hero(self, info, episode=None): def resolve_hero(self, info, episode=None):
return get_hero(episode) return get_hero(episode)

View File

@ -6,196 +6,95 @@ from snapshottest import Snapshot
snapshots = Snapshot() snapshots = Snapshot()
snapshots['test_hero_name_query 1'] = { snapshots["test_hero_name_query 1"] = {"data": {"hero": {"name": "R2-D2"}}}
'data': {
'hero': {
'name': 'R2-D2'
}
}
}
snapshots['test_hero_name_and_friends_query 1'] = { snapshots["test_hero_name_and_friends_query 1"] = {
'data': { "data": {
'hero': { "hero": {
'id': '2001', "id": "2001",
'name': 'R2-D2', "name": "R2-D2",
'friends': [ "friends": [
{ {"name": "Luke Skywalker"},
'name': 'Luke Skywalker' {"name": "Han Solo"},
}, {"name": "Leia Organa"},
{
'name': 'Han Solo'
},
{
'name': 'Leia Organa'
}
]
}
}
}
snapshots['test_nested_query 1'] = {
'data': {
'hero': {
'name': 'R2-D2',
'friends': [
{
'name': 'Luke Skywalker',
'appearsIn': [
'NEWHOPE',
'EMPIRE',
'JEDI'
], ],
'friends': [
{
'name': 'Han Solo'
},
{
'name': 'Leia Organa'
},
{
'name': 'C-3PO'
},
{
'name': 'R2-D2'
} }
] }
}, }
snapshots["test_nested_query 1"] = {
"data": {
"hero": {
"name": "R2-D2",
"friends": [
{ {
'name': 'Han Solo', "name": "Luke Skywalker",
'appearsIn': [ "appearsIn": ["NEWHOPE", "EMPIRE", "JEDI"],
'NEWHOPE', "friends": [
'EMPIRE', {"name": "Han Solo"},
'JEDI' {"name": "Leia Organa"},
{"name": "C-3PO"},
{"name": "R2-D2"},
], ],
'friends': [
{
'name': 'Luke Skywalker'
}, },
{ {
'name': 'Leia Organa' "name": "Han Solo",
}, "appearsIn": ["NEWHOPE", "EMPIRE", "JEDI"],
{ "friends": [
'name': 'R2-D2' {"name": "Luke Skywalker"},
} {"name": "Leia Organa"},
] {"name": "R2-D2"},
},
{
'name': 'Leia Organa',
'appearsIn': [
'NEWHOPE',
'EMPIRE',
'JEDI'
], ],
'friends': [
{
'name': 'Luke Skywalker'
}, },
{ {
'name': 'Han Solo' "name": "Leia Organa",
"appearsIn": ["NEWHOPE", "EMPIRE", "JEDI"],
"friends": [
{"name": "Luke Skywalker"},
{"name": "Han Solo"},
{"name": "C-3PO"},
{"name": "R2-D2"},
],
}, },
{ ],
'name': 'C-3PO'
},
{
'name': 'R2-D2'
}
]
}
]
} }
} }
} }
snapshots['test_fetch_luke_query 1'] = { snapshots["test_fetch_luke_query 1"] = {"data": {"human": {"name": "Luke Skywalker"}}}
'data': {
'human': { snapshots["test_fetch_some_id_query 1"] = {
'name': 'Luke Skywalker' "data": {"human": {"name": "Luke Skywalker"}}
} }
snapshots["test_fetch_some_id_query2 1"] = {"data": {"human": {"name": "Han Solo"}}}
snapshots["test_invalid_id_query 1"] = {"data": {"human": None}}
snapshots["test_fetch_luke_aliased 1"] = {"data": {"luke": {"name": "Luke Skywalker"}}}
snapshots["test_fetch_luke_and_leia_aliased 1"] = {
"data": {"luke": {"name": "Luke Skywalker"}, "leia": {"name": "Leia Organa"}}
}
snapshots["test_duplicate_fields 1"] = {
"data": {
"luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"},
"leia": {"name": "Leia Organa", "homePlanet": "Alderaan"},
} }
} }
snapshots['test_fetch_some_id_query 1'] = { snapshots["test_use_fragment 1"] = {
'data': { "data": {
'human': { "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"},
'name': 'Luke Skywalker' "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"},
}
}
}
snapshots['test_fetch_some_id_query2 1'] = {
'data': {
'human': {
'name': 'Han Solo'
}
}
}
snapshots['test_invalid_id_query 1'] = {
'data': {
'human': None
} }
} }
snapshots['test_fetch_luke_aliased 1'] = { snapshots["test_check_type_of_r2 1"] = {
'data': { "data": {"hero": {"__typename": "Droid", "name": "R2-D2"}}
'luke': {
'name': 'Luke Skywalker'
}
}
} }
snapshots['test_fetch_luke_and_leia_aliased 1'] = { snapshots["test_check_type_of_luke 1"] = {
'data': { "data": {"hero": {"__typename": "Human", "name": "Luke Skywalker"}}
'luke': {
'name': 'Luke Skywalker'
},
'leia': {
'name': 'Leia Organa'
}
}
}
snapshots['test_duplicate_fields 1'] = {
'data': {
'luke': {
'name': 'Luke Skywalker',
'homePlanet': 'Tatooine'
},
'leia': {
'name': 'Leia Organa',
'homePlanet': 'Alderaan'
}
}
}
snapshots['test_use_fragment 1'] = {
'data': {
'luke': {
'name': 'Luke Skywalker',
'homePlanet': 'Tatooine'
},
'leia': {
'name': 'Leia Organa',
'homePlanet': 'Alderaan'
}
}
}
snapshots['test_check_type_of_r2 1'] = {
'data': {
'hero': {
'__typename': 'Droid',
'name': 'R2-D2'
}
}
}
snapshots['test_check_type_of_luke 1'] = {
'data': {
'hero': {
'__typename': 'Human',
'name': 'Luke Skywalker'
}
}
} }

View File

@ -9,18 +9,18 @@ client = Client(schema)
def test_hero_name_query(snapshot): def test_hero_name_query(snapshot):
query = ''' query = """
query HeroNameQuery { query HeroNameQuery {
hero { hero {
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_hero_name_and_friends_query(snapshot): def test_hero_name_and_friends_query(snapshot):
query = ''' query = """
query HeroNameAndFriendsQuery { query HeroNameAndFriendsQuery {
hero { hero {
id id
@ -30,12 +30,12 @@ def test_hero_name_and_friends_query(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_nested_query(snapshot): def test_nested_query(snapshot):
query = ''' query = """
query NestedQuery { query NestedQuery {
hero { hero {
name name
@ -48,76 +48,70 @@ def test_nested_query(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_fetch_luke_query(snapshot): def test_fetch_luke_query(snapshot):
query = ''' query = """
query FetchLukeQuery { query FetchLukeQuery {
human(id: "1000") { human(id: "1000") {
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_fetch_some_id_query(snapshot): def test_fetch_some_id_query(snapshot):
query = ''' query = """
query FetchSomeIDQuery($someId: String!) { query FetchSomeIDQuery($someId: String!) {
human(id: $someId) { human(id: $someId) {
name name
} }
} }
''' """
params = { params = {"someId": "1000"}
'someId': '1000',
}
snapshot.assert_match(client.execute(query, variable_values=params)) snapshot.assert_match(client.execute(query, variable_values=params))
def test_fetch_some_id_query2(snapshot): def test_fetch_some_id_query2(snapshot):
query = ''' query = """
query FetchSomeIDQuery($someId: String!) { query FetchSomeIDQuery($someId: String!) {
human(id: $someId) { human(id: $someId) {
name name
} }
} }
''' """
params = { params = {"someId": "1002"}
'someId': '1002',
}
snapshot.assert_match(client.execute(query, variable_values=params)) snapshot.assert_match(client.execute(query, variable_values=params))
def test_invalid_id_query(snapshot): def test_invalid_id_query(snapshot):
query = ''' query = """
query humanQuery($id: String!) { query humanQuery($id: String!) {
human(id: $id) { human(id: $id) {
name name
} }
} }
''' """
params = { params = {"id": "not a valid id"}
'id': 'not a valid id',
}
snapshot.assert_match(client.execute(query, variable_values=params)) snapshot.assert_match(client.execute(query, variable_values=params))
def test_fetch_luke_aliased(snapshot): def test_fetch_luke_aliased(snapshot):
query = ''' query = """
query FetchLukeAliased { query FetchLukeAliased {
luke: human(id: "1000") { luke: human(id: "1000") {
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_fetch_luke_and_leia_aliased(snapshot): def test_fetch_luke_and_leia_aliased(snapshot):
query = ''' query = """
query FetchLukeAndLeiaAliased { query FetchLukeAndLeiaAliased {
luke: human(id: "1000") { luke: human(id: "1000") {
name name
@ -126,12 +120,12 @@ def test_fetch_luke_and_leia_aliased(snapshot):
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_duplicate_fields(snapshot): def test_duplicate_fields(snapshot):
query = ''' query = """
query DuplicateFields { query DuplicateFields {
luke: human(id: "1000") { luke: human(id: "1000") {
name name
@ -142,12 +136,12 @@ def test_duplicate_fields(snapshot):
homePlanet homePlanet
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_use_fragment(snapshot): def test_use_fragment(snapshot):
query = ''' query = """
query UseFragment { query UseFragment {
luke: human(id: "1000") { luke: human(id: "1000") {
...HumanFragment ...HumanFragment
@ -160,29 +154,29 @@ def test_use_fragment(snapshot):
name name
homePlanet homePlanet
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_check_type_of_r2(snapshot): def test_check_type_of_r2(snapshot):
query = ''' query = """
query CheckTypeOfR2 { query CheckTypeOfR2 {
hero { hero {
__typename __typename
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_check_type_of_luke(snapshot): def test_check_type_of_luke(snapshot):
query = ''' query = """
query CheckTypeOfLuke { query CheckTypeOfLuke {
hero(episode: EMPIRE) { hero(episode: EMPIRE) {
__typename __typename
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))

View File

@ -5,101 +5,67 @@ def setup():
global data global data
from .schema import Ship, Faction from .schema import Ship, Faction
xwing = Ship(
id='1',
name='X-Wing',
)
ywing = Ship( xwing = Ship(id="1", name="X-Wing")
id='2',
name='Y-Wing',
)
awing = Ship( ywing = Ship(id="2", name="Y-Wing")
id='3',
name='A-Wing', awing = Ship(id="3", name="A-Wing")
)
# Yeah, technically it's Corellian. But it flew in the service of the rebels, # Yeah, technically it's Corellian. But it flew in the service of the rebels,
# so for the purposes of this demo it's a rebel ship. # so for the purposes of this demo it's a rebel ship.
falcon = Ship( falcon = Ship(id="4", name="Millenium Falcon")
id='4',
name='Millenium Falcon',
)
homeOne = Ship( homeOne = Ship(id="5", name="Home One")
id='5',
name='Home One',
)
tieFighter = Ship( tieFighter = Ship(id="6", name="TIE Fighter")
id='6',
name='TIE Fighter',
)
tieInterceptor = Ship( tieInterceptor = Ship(id="7", name="TIE Interceptor")
id='7',
name='TIE Interceptor',
)
executor = Ship( executor = Ship(id="8", name="Executor")
id='8',
name='Executor',
)
rebels = Faction( rebels = Faction(
id='1', id="1", name="Alliance to Restore the Republic", ships=["1", "2", "3", "4", "5"]
name='Alliance to Restore the Republic',
ships=['1', '2', '3', '4', '5']
) )
empire = Faction( empire = Faction(id="2", name="Galactic Empire", ships=["6", "7", "8"])
id='2',
name='Galactic Empire',
ships=['6', '7', '8']
)
data = { data = {
'Faction': { "Faction": {"1": rebels, "2": empire},
'1': rebels, "Ship": {
'2': empire "1": xwing,
"2": ywing,
"3": awing,
"4": falcon,
"5": homeOne,
"6": tieFighter,
"7": tieInterceptor,
"8": executor,
}, },
'Ship': {
'1': xwing,
'2': ywing,
'3': awing,
'4': falcon,
'5': homeOne,
'6': tieFighter,
'7': tieInterceptor,
'8': executor
}
} }
def create_ship(ship_name, faction_id): def create_ship(ship_name, faction_id):
from .schema import Ship from .schema import Ship
next_ship = len(data['Ship'].keys()) + 1
new_ship = Ship( next_ship = len(data["Ship"].keys()) + 1
id=str(next_ship), new_ship = Ship(id=str(next_ship), name=ship_name)
name=ship_name data["Ship"][new_ship.id] = new_ship
) data["Faction"][faction_id].ships.append(new_ship.id)
data['Ship'][new_ship.id] = new_ship
data['Faction'][faction_id].ships.append(new_ship.id)
return new_ship return new_ship
def get_ship(_id): def get_ship(_id):
return data['Ship'][_id] return data["Ship"][_id]
def get_faction(_id): def get_faction(_id):
return data['Faction'][_id] return data["Faction"][_id]
def get_rebels(): def get_rebels():
return get_faction('1') return get_faction("1")
def get_empire(): def get_empire():
return get_faction('2') return get_faction("2")

View File

@ -5,12 +5,12 @@ from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
class Ship(graphene.ObjectType): class Ship(graphene.ObjectType):
'''A ship in the Star Wars saga''' """A ship in the Star Wars saga"""
class Meta: class Meta:
interfaces = (relay.Node,) interfaces = (relay.Node,)
name = graphene.String(description='The name of the ship.') name = graphene.String(description="The name of the ship.")
@classmethod @classmethod
def get_node(cls, info, id): def get_node(cls, info, id):
@ -18,19 +18,20 @@ class Ship(graphene.ObjectType):
class ShipConnection(relay.Connection): class ShipConnection(relay.Connection):
class Meta: class Meta:
node = Ship node = Ship
class Faction(graphene.ObjectType): class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga''' """A faction in the Star Wars saga"""
class Meta: class Meta:
interfaces = (relay.Node,) interfaces = (relay.Node,)
name = graphene.String(description='The name of the faction.') name = graphene.String(description="The name of the faction.")
ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.') ships = relay.ConnectionField(
ShipConnection, description="The ships used by the faction."
)
def resolve_ships(self, info, **args): def resolve_ships(self, info, **args):
# Transform the instance ship_ids into real instances # Transform the instance ship_ids into real instances
@ -42,7 +43,6 @@ class Faction(graphene.ObjectType):
class IntroduceShip(relay.ClientIDMutation): class IntroduceShip(relay.ClientIDMutation):
class Input: class Input:
ship_name = graphene.String(required=True) ship_name = graphene.String(required=True)
faction_id = graphene.String(required=True) faction_id = graphene.String(required=True)
@ -51,7 +51,9 @@ class IntroduceShip(relay.ClientIDMutation):
faction = graphene.Field(Faction) faction = graphene.Field(Faction)
@classmethod @classmethod
def mutate_and_get_payload(cls, root, info, ship_name, faction_id, client_mutation_id=None): def mutate_and_get_payload(
cls, root, info, ship_name, faction_id, client_mutation_id=None
):
ship = create_ship(ship_name, faction_id) ship = create_ship(ship_name, faction_id)
faction = get_faction(faction_id) faction = get_faction(faction_id)
return IntroduceShip(ship=ship, faction=faction) return IntroduceShip(ship=ship, faction=faction)

View File

@ -6,26 +6,21 @@ from snapshottest import Snapshot
snapshots = Snapshot() snapshots = Snapshot()
snapshots['test_correct_fetch_first_ship_rebels 1'] = { snapshots["test_correct_fetch_first_ship_rebels 1"] = {
'data': { "data": {
'rebels': { "rebels": {
'name': 'Alliance to Restore the Republic', "name": "Alliance to Restore the Republic",
'ships': { "ships": {
'pageInfo': { "pageInfo": {
'startCursor': 'YXJyYXljb25uZWN0aW9uOjA=', "startCursor": "YXJyYXljb25uZWN0aW9uOjA=",
'endCursor': 'YXJyYXljb25uZWN0aW9uOjA=', "endCursor": "YXJyYXljb25uZWN0aW9uOjA=",
'hasNextPage': True, "hasNextPage": True,
'hasPreviousPage': False "hasPreviousPage": False,
},
"edges": [
{"cursor": "YXJyYXljb25uZWN0aW9uOjA=", "node": {"name": "X-Wing"}}
],
}, },
'edges': [
{
'cursor': 'YXJyYXljb25uZWN0aW9uOjA=',
'node': {
'name': 'X-Wing'
}
}
]
}
} }
} }
} }

View File

@ -6,56 +6,23 @@ from snapshottest import Snapshot
snapshots = Snapshot() snapshots = Snapshot()
snapshots['test_mutations 1'] = { snapshots["test_mutations 1"] = {
'data': { "data": {
'introduceShip': { "introduceShip": {
'ship': { "ship": {"id": "U2hpcDo5", "name": "Peter"},
'id': 'U2hpcDo5', "faction": {
'name': 'Peter' "name": "Alliance to Restore the Republic",
}, "ships": {
'faction': { "edges": [
'name': 'Alliance to Restore the Republic', {"node": {"id": "U2hpcDox", "name": "X-Wing"}},
'ships': { {"node": {"id": "U2hpcDoy", "name": "Y-Wing"}},
'edges': [ {"node": {"id": "U2hpcDoz", "name": "A-Wing"}},
{ {"node": {"id": "U2hpcDo0", "name": "Millenium Falcon"}},
'node': { {"node": {"id": "U2hpcDo1", "name": "Home One"}},
'id': 'U2hpcDox', {"node": {"id": "U2hpcDo5", "name": "Peter"}},
'name': 'X-Wing'
}
},
{
'node': {
'id': 'U2hpcDoy',
'name': 'Y-Wing'
}
},
{
'node': {
'id': 'U2hpcDoz',
'name': 'A-Wing'
}
},
{
'node': {
'id': 'U2hpcDo0',
'name': 'Millenium Falcon'
}
},
{
'node': {
'id': 'U2hpcDo1',
'name': 'Home One'
}
},
{
'node': {
'id': 'U2hpcDo5',
'name': 'Peter'
}
}
] ]
} },
} },
} }
} }
} }

View File

@ -6,52 +6,31 @@ from snapshottest import Snapshot
snapshots = Snapshot() snapshots = Snapshot()
snapshots['test_correctly_fetches_id_name_rebels 1'] = { snapshots["test_correctly_fetches_id_name_rebels 1"] = {
'data': { "data": {
'rebels': { "rebels": {"id": "RmFjdGlvbjox", "name": "Alliance to Restore the Republic"}
'id': 'RmFjdGlvbjox',
'name': 'Alliance to Restore the Republic'
}
} }
} }
snapshots['test_correctly_refetches_rebels 1'] = { snapshots["test_correctly_refetches_rebels 1"] = {
'data': { "data": {"node": {"id": "RmFjdGlvbjox", "name": "Alliance to Restore the Republic"}}
'node': {
'id': 'RmFjdGlvbjox',
'name': 'Alliance to Restore the Republic'
}
}
} }
snapshots['test_correctly_fetches_id_name_empire 1'] = { snapshots["test_correctly_fetches_id_name_empire 1"] = {
'data': { "data": {"empire": {"id": "RmFjdGlvbjoy", "name": "Galactic Empire"}}
'empire': {
'id': 'RmFjdGlvbjoy',
'name': 'Galactic Empire'
}
}
} }
snapshots['test_correctly_refetches_empire 1'] = { snapshots["test_correctly_refetches_empire 1"] = {
'data': { "data": {"node": {"id": "RmFjdGlvbjoy", "name": "Galactic Empire"}}
'node': {
'id': 'RmFjdGlvbjoy',
'name': 'Galactic Empire'
}
}
} }
snapshots['test_correctly_refetches_xwing 1'] = { snapshots["test_correctly_refetches_xwing 1"] = {
'data': { "data": {"node": {"id": "U2hpcDox", "name": "X-Wing"}}
'node': {
'id': 'U2hpcDox',
'name': 'X-Wing'
}
}
} }
snapshots['test_str_schema 1'] = '''schema { snapshots[
"test_str_schema 1"
] = """schema {
query: Query query: Query
mutation: Mutation mutation: Mutation
} }
@ -109,4 +88,4 @@ type ShipEdge {
node: Ship node: Ship
cursor: String! cursor: String!
} }
''' """

View File

@ -9,7 +9,7 @@ client = Client(schema)
def test_correct_fetch_first_ship_rebels(snapshot): def test_correct_fetch_first_ship_rebels(snapshot):
query = ''' query = """
query RebelsShipsQuery { query RebelsShipsQuery {
rebels { rebels {
name, name,
@ -29,5 +29,5 @@ def test_correct_fetch_first_ship_rebels(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))

View File

@ -9,7 +9,7 @@ client = Client(schema)
def test_mutations(snapshot): def test_mutations(snapshot):
query = ''' query = """
mutation MyMutation { mutation MyMutation {
introduceShip(input:{clientMutationId:"abc", shipName: "Peter", factionId: "1"}) { introduceShip(input:{clientMutationId:"abc", shipName: "Peter", factionId: "1"}) {
ship { ship {
@ -29,5 +29,5 @@ def test_mutations(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))

View File

@ -13,19 +13,19 @@ def test_str_schema(snapshot):
def test_correctly_fetches_id_name_rebels(snapshot): def test_correctly_fetches_id_name_rebels(snapshot):
query = ''' query = """
query RebelsQuery { query RebelsQuery {
rebels { rebels {
id id
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_correctly_refetches_rebels(snapshot): def test_correctly_refetches_rebels(snapshot):
query = ''' query = """
query RebelsRefetchQuery { query RebelsRefetchQuery {
node(id: "RmFjdGlvbjox") { node(id: "RmFjdGlvbjox") {
id id
@ -34,24 +34,24 @@ def test_correctly_refetches_rebels(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_correctly_fetches_id_name_empire(snapshot): def test_correctly_fetches_id_name_empire(snapshot):
query = ''' query = """
query EmpireQuery { query EmpireQuery {
empire { empire {
id id
name name
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_correctly_refetches_empire(snapshot): def test_correctly_refetches_empire(snapshot):
query = ''' query = """
query EmpireRefetchQuery { query EmpireRefetchQuery {
node(id: "RmFjdGlvbjoy") { node(id: "RmFjdGlvbjoy") {
id id
@ -60,12 +60,12 @@ def test_correctly_refetches_empire(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))
def test_correctly_refetches_xwing(snapshot): def test_correctly_refetches_xwing(snapshot):
query = ''' query = """
query XWingRefetchQuery { query XWingRefetchQuery {
node(id: "U2hpcDox") { node(id: "U2hpcDox") {
id id
@ -74,5 +74,5 @@ def test_correctly_refetches_xwing(snapshot):
} }
} }
} }
''' """
snapshot.assert_match(client.execute(query)) snapshot.assert_match(client.execute(query))

View File

@ -10,17 +10,24 @@ from .types import (
InputField, InputField,
Schema, Schema,
Scalar, Scalar,
String, ID, Int, Float, Boolean, String,
Date, DateTime, Time, ID,
Int,
Float,
Boolean,
Date,
DateTime,
Time,
JSONString, JSONString,
UUID, UUID,
List, NonNull, List,
NonNull,
Enum, Enum,
Argument, Argument,
Dynamic, Dynamic,
Union, Union,
Context, Context,
ResolveInfo ResolveInfo,
) )
from .relay import ( from .relay import (
Node, Node,
@ -29,54 +36,53 @@ from .relay import (
ClientIDMutation, ClientIDMutation,
Connection, Connection,
ConnectionField, ConnectionField,
PageInfo PageInfo,
) )
from .utils.resolve_only_args import resolve_only_args from .utils.resolve_only_args import resolve_only_args
from .utils.module_loading import lazy_import from .utils.module_loading import lazy_import
VERSION = (2, 1, 2, 'final', 0) VERSION = (2, 1, 2, "final", 0)
__version__ = get_version(VERSION) __version__ = get_version(VERSION)
__all__ = [ __all__ = [
'__version__', "__version__",
'ObjectType', "ObjectType",
'InputObjectType', "InputObjectType",
'Interface', "Interface",
'Mutation', "Mutation",
'Field', "Field",
'InputField', "InputField",
'Schema', "Schema",
'Scalar', "Scalar",
'String', "String",
'ID', "ID",
'Int', "Int",
'Float', "Float",
'Enum', "Enum",
'Boolean', "Boolean",
'Date', "Date",
'DateTime', "DateTime",
'Time', "Time",
'JSONString', "JSONString",
'UUID', "UUID",
'List', "List",
'NonNull', "NonNull",
'Argument', "Argument",
'Dynamic', "Dynamic",
'Union', "Union",
'resolve_only_args', "resolve_only_args",
'Node', "Node",
'is_node', "is_node",
'GlobalID', "GlobalID",
'ClientIDMutation', "ClientIDMutation",
'Connection', "Connection",
'ConnectionField', "ConnectionField",
'PageInfo', "PageInfo",
'lazy_import', "lazy_import",
'Context', "Context",
'ResolveInfo', "ResolveInfo",
# Deprecated # Deprecated
'AbstractType', "AbstractType",
] ]

View File

@ -13,8 +13,12 @@ except ImportError:
from .signature import signature from .signature import signature
if six.PY2: if six.PY2:
def func_name(func): def func_name(func):
return func.func_name return func.func_name
else: else:
def func_name(func): def func_name(func):
return func.__name__ return func.__name__

View File

@ -2,21 +2,23 @@
import sys as _sys import sys as _sys
__all__ = ['Enum', 'IntEnum', 'unique'] __all__ = ["Enum", "IntEnum", "unique"]
version = 1, 1, 6 version = 1, 1, 6
pyver = float('%s.%s' % _sys.version_info[:2]) pyver = float("%s.%s" % _sys.version_info[:2])
try: try:
any any
except NameError: except NameError:
def any(iterable): def any(iterable):
for element in iterable: for element in iterable:
if element: if element:
return True return True
return False return False
try: try:
from collections import OrderedDict from collections import OrderedDict
except ImportError: except ImportError:
@ -64,34 +66,38 @@ class _RouteClassAttributeToGetattr(object):
def _is_descriptor(obj): def _is_descriptor(obj):
"""Returns True if obj is a descriptor, False otherwise.""" """Returns True if obj is a descriptor, False otherwise."""
return ( return (
hasattr(obj, '__get__') or hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
hasattr(obj, '__set__') or )
hasattr(obj, '__delete__'))
def _is_dunder(name): def _is_dunder(name):
"""Returns True if a __dunder__ name, False otherwise.""" """Returns True if a __dunder__ name, False otherwise."""
return (len(name) > 4 and return (
name[:2] == name[-2:] == '__' and len(name) > 4
name[2:3] != '_' and and name[:2] == name[-2:] == "__"
name[-3:-2] != '_') and name[2:3] != "_"
and name[-3:-2] != "_"
)
def _is_sunder(name): def _is_sunder(name):
"""Returns True if a _sunder_ name, False otherwise.""" """Returns True if a _sunder_ name, False otherwise."""
return (len(name) > 2 and return (
name[0] == name[-1] == '_' and len(name) > 2
name[1:2] != '_' and and name[0] == name[-1] == "_"
name[-2:-1] != '_') and name[1:2] != "_"
and name[-2:-1] != "_"
)
def _make_class_unpicklable(cls): def _make_class_unpicklable(cls):
"""Make the given class un-picklable.""" """Make the given class un-picklable."""
def _break_on_call_reduce(self, protocol=None): def _break_on_call_reduce(self, protocol=None):
raise TypeError('%r cannot be pickled' % self) raise TypeError("%r cannot be pickled" % self)
cls.__reduce_ex__ = _break_on_call_reduce cls.__reduce_ex__ = _break_on_call_reduce
cls.__module__ = '<unknown>' cls.__module__ = "<unknown>"
class _EnumDict(OrderedDict): class _EnumDict(OrderedDict):
@ -122,22 +128,22 @@ class _EnumDict(OrderedDict):
leftover from 2.x leftover from 2.x
""" """
if pyver >= 3.0 and key in ('_order_', '__order__'): if pyver >= 3.0 and key in ("_order_", "__order__"):
return return
elif key == '__order__': elif key == "__order__":
key = '_order_' key = "_order_"
if _is_sunder(key): if _is_sunder(key):
if key != '_order_': if key != "_order_":
raise ValueError('_names_ are reserved for future Enum use') raise ValueError("_names_ are reserved for future Enum use")
elif _is_dunder(key): elif _is_dunder(key):
pass pass
elif key in self._member_names: elif key in self._member_names:
# descriptor overwriting an enum? # descriptor overwriting an enum?
raise TypeError('Attempted to reuse key: %r' % key) raise TypeError("Attempted to reuse key: %r" % key)
elif not _is_descriptor(value): elif not _is_descriptor(value):
if key in self: if key in self:
# enum overwriting a descriptor? # enum overwriting a descriptor?
raise TypeError('Key already defined as: %r' % self[key]) raise TypeError("Key already defined as: %r" % self[key])
self._member_names.append(key) self._member_names.append(key)
super(_EnumDict, self).__setitem__(key, value) super(_EnumDict, self).__setitem__(key, value)
@ -150,6 +156,7 @@ Enum = None
class EnumMeta(type): class EnumMeta(type):
"""Metaclass for Enum""" """Metaclass for Enum"""
@classmethod @classmethod
def __prepare__(metacls, cls, bases): def __prepare__(metacls, cls, bases):
return _EnumDict() return _EnumDict()
@ -166,8 +173,9 @@ class EnumMeta(type):
classdict[k] = v classdict[k] = v
member_type, first_enum = metacls._get_mixins_(bases) member_type, first_enum = metacls._get_mixins_(bases)
__new__, save_new, use_args = metacls._find_new_(classdict, member_type, __new__, save_new, use_args = metacls._find_new_(
first_enum) classdict, member_type, first_enum
)
# save enum items into separate mapping so they don't get baked into # save enum items into separate mapping so they don't get baked into
# the new class # the new class
members = {k: classdict[k] for k in classdict._member_names} members = {k: classdict[k] for k in classdict._member_names}
@ -175,27 +183,33 @@ class EnumMeta(type):
del classdict[name] del classdict[name]
# py2 support for definition order # py2 support for definition order
_order_ = classdict.get('_order_') _order_ = classdict.get("_order_")
if _order_ is None: if _order_ is None:
if pyver < 3.0: if pyver < 3.0:
try: try:
_order_ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] _order_ = [
name
for (name, value) in sorted(
members.items(), key=lambda item: item[1]
)
]
except TypeError: except TypeError:
_order_ = [name for name in sorted(members.keys())] _order_ = [name for name in sorted(members.keys())]
else: else:
_order_ = classdict._member_names _order_ = classdict._member_names
else: else:
del classdict['_order_'] del classdict["_order_"]
if pyver < 3.0: if pyver < 3.0:
_order_ = _order_.replace(',', ' ').split() _order_ = _order_.replace(",", " ").split()
aliases = [name for name in members if name not in _order_] aliases = [name for name in members if name not in _order_]
_order_ += aliases _order_ += aliases
# check for illegal enum names (any others?) # check for illegal enum names (any others?)
invalid_names = set(members) & {'mro'} invalid_names = set(members) & {"mro"}
if invalid_names: if invalid_names:
raise ValueError('Invalid enum member name(s): %s' % ( raise ValueError(
', '.join(invalid_names), )) "Invalid enum member name(s): {}".format(", ".join(invalid_names))
)
# save attributes from super classes so we know if we can take # save attributes from super classes so we know if we can take
# the shortcut of storing members in the class dict # the shortcut of storing members in the class dict
@ -228,11 +242,11 @@ class EnumMeta(type):
args = (args,) # wrap it one more time args = (args,) # wrap it one more time
if not use_args or not args: if not use_args or not args:
enum_member = __new__(enum_class) enum_member = __new__(enum_class)
if not hasattr(enum_member, '_value_'): if not hasattr(enum_member, "_value_"):
enum_member._value_ = value enum_member._value_ = value
else: else:
enum_member = __new__(enum_class, *args) enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, '_value_'): if not hasattr(enum_member, "_value_"):
enum_member._value_ = member_type(*args) enum_member._value_ = member_type(*args)
value = enum_member._value_ value = enum_member._value_
enum_member._name_ = member_name enum_member._name_ = member_name
@ -272,22 +286,26 @@ class EnumMeta(type):
# __reduce_ex__ instead of any of the others as it is preferred by # __reduce_ex__ instead of any of the others as it is preferred by
# pickle over __reduce__, and it handles all pickle protocols. # pickle over __reduce__, and it handles all pickle protocols.
unpicklable = False unpicklable = False
if '__reduce_ex__' not in classdict: if "__reduce_ex__" not in classdict:
if member_type is not object: if member_type is not object:
methods = ('__getnewargs_ex__', '__getnewargs__', methods = (
'__reduce_ex__', '__reduce__') "__getnewargs_ex__",
"__getnewargs__",
"__reduce_ex__",
"__reduce__",
)
if not any(m in member_type.__dict__ for m in methods): if not any(m in member_type.__dict__ for m in methods):
_make_class_unpicklable(enum_class) _make_class_unpicklable(enum_class)
unpicklable = True unpicklable = True
# double check that repr and friends are not the mixin's or various # double check that repr and friends are not the mixin's or various
# things break (such as pickle) # things break (such as pickle)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"):
class_method = getattr(enum_class, name) class_method = getattr(enum_class, name)
getattr(member_type, name, None) getattr(member_type, name, None)
enum_method = getattr(first_enum, name, None) enum_method = getattr(first_enum, name, None)
if name not in classdict and class_method is not enum_method: if name not in classdict and class_method is not enum_method:
if name == '__reduce_ex__' and unpicklable: if name == "__reduce_ex__" and unpicklable:
continue continue
setattr(enum_class, name, enum_method) setattr(enum_class, name, enum_method)
@ -297,19 +315,19 @@ class EnumMeta(type):
if pyver < 2.6: if pyver < 2.6:
if issubclass(enum_class, int): if issubclass(enum_class, int):
setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) setattr(enum_class, "__cmp__", getattr(int, "__cmp__"))
elif pyver < 3.0: elif pyver < 3.0:
if issubclass(enum_class, int): if issubclass(enum_class, int):
for method in ( for method in (
'__le__', "__le__",
'__lt__', "__lt__",
'__gt__', "__gt__",
'__ge__', "__ge__",
'__eq__', "__eq__",
'__ne__', "__ne__",
'__hash__', "__hash__",
): ):
setattr(enum_class, method, getattr(int, method)) setattr(enum_class, method, getattr(int, method))
@ -319,8 +337,8 @@ class EnumMeta(type):
# if the user defined their own __new__, save it before it gets # if the user defined their own __new__, save it before it gets
# clobbered in case they subclass later # clobbered in case they subclass later
if save_new: if save_new:
setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) setattr(enum_class, "__member_new__", enum_class.__dict__["__new__"])
setattr(enum_class, '__new__', Enum.__dict__['__new__']) setattr(enum_class, "__new__", Enum.__dict__["__new__"])
return enum_class return enum_class
def __bool__(cls): def __bool__(cls):
@ -357,13 +375,16 @@ class EnumMeta(type):
# nicer error message when someone tries to delete an attribute # nicer error message when someone tries to delete an attribute
# (see issue19025). # (see issue19025).
if attr in cls._member_map_: if attr in cls._member_map_:
raise AttributeError( raise AttributeError("%s: cannot delete Enum member." % cls.__name__)
"%s: cannot delete Enum member." % cls.__name__)
super(EnumMeta, cls).__delattr__(attr) super(EnumMeta, cls).__delattr__(attr)
def __dir__(self): def __dir__(self):
return (['__class__', '__doc__', '__members__', '__module__'] + return [
self._member_names_) "__class__",
"__doc__",
"__members__",
"__module__",
] + self._member_names_
@property @property
def __members__(cls): def __members__(cls):
@ -416,9 +437,9 @@ class EnumMeta(type):
resulting in an inconsistent Enumeration. resulting in an inconsistent Enumeration.
""" """
member_map = cls.__dict__.get('_member_map_', {}) member_map = cls.__dict__.get("_member_map_", {})
if name in member_map: if name in member_map:
raise AttributeError('Cannot reassign members.') raise AttributeError("Cannot reassign members.")
super(EnumMeta, cls).__setattr__(name, value) super(EnumMeta, cls).__setattr__(name, value)
def _create_(cls, class_name, names=None, module=None, type=None, start=1): def _create_(cls, class_name, names=None, module=None, type=None, start=1):
@ -437,9 +458,9 @@ class EnumMeta(type):
# if class_name is unicode, attempt a conversion to ASCII # if class_name is unicode, attempt a conversion to ASCII
if isinstance(class_name, unicode): if isinstance(class_name, unicode):
try: try:
class_name = class_name.encode('ascii') class_name = class_name.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
raise TypeError('%r is not representable in ASCII' % class_name) raise TypeError("%r is not representable in ASCII" % class_name)
metacls = cls.__class__ metacls = cls.__class__
if type is None: if type is None:
bases = (cls,) bases = (cls,)
@ -450,7 +471,7 @@ class EnumMeta(type):
# special processing needed for names? # special processing needed for names?
if isinstance(names, basestring): if isinstance(names, basestring):
names = names.replace(',', ' ').split() names = names.replace(",", " ").split()
if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): if isinstance(names, (tuple, list)) and isinstance(names[0], basestring):
names = [(e, i + start) for (i, e) in enumerate(names)] names = [(e, i + start) for (i, e) in enumerate(names)]
@ -465,14 +486,14 @@ class EnumMeta(type):
_order_.append(member_name) _order_.append(member_name)
# only set _order_ in classdict if name/value was not from a mapping # only set _order_ in classdict if name/value was not from a mapping
if not isinstance(item, basestring): if not isinstance(item, basestring):
classdict['_order_'] = ' '.join(_order_) classdict["_order_"] = " ".join(_order_)
enum_class = metacls.__new__(metacls, class_name, bases, classdict) enum_class = metacls.__new__(metacls, class_name, bases, classdict)
# TODO: replace the frame hack if a blessed way to know the calling # TODO: replace the frame hack if a blessed way to know the calling
# module is ever developed # module is ever developed
if module is None: if module is None:
try: try:
module = _sys._getframe(2).f_globals['__name__'] module = _sys._getframe(2).f_globals["__name__"]
except (AttributeError, ValueError): except (AttributeError, ValueError):
pass pass
if module is None: if module is None:
@ -498,14 +519,14 @@ class EnumMeta(type):
# type has been mixed in so we can use the correct __new__ # type has been mixed in so we can use the correct __new__
member_type = first_enum = None member_type = first_enum = None
for base in bases: for base in bases:
if (base is not Enum and if base is not Enum and issubclass(base, Enum) and base._member_names_:
issubclass(base, Enum) and
base._member_names_):
raise TypeError("Cannot extend enumerations") raise TypeError("Cannot extend enumerations")
# base is now the last base in bases # base is now the last base in bases
if not issubclass(base, Enum): if not issubclass(base, Enum):
raise TypeError("new enumerations must be created as " raise TypeError(
"`ClassName([mixin_type,] enum_type)`") "new enumerations must be created as "
"`ClassName([mixin_type,] enum_type)`"
)
# get correct mix-in type (either mix-in type of Enum subclass, or # get correct mix-in type (either mix-in type of Enum subclass, or
# first base if last base is Enum) # first base if last base is Enum)
@ -528,6 +549,7 @@ class EnumMeta(type):
return member_type, first_enum return member_type, first_enum
if pyver < 3.0: if pyver < 3.0:
@staticmethod @staticmethod
def _find_new_(classdict, member_type, first_enum): def _find_new_(classdict, member_type, first_enum):
"""Returns the __new__ to be used for creating the enum members. """Returns the __new__ to be used for creating the enum members.
@ -540,32 +562,27 @@ class EnumMeta(type):
# now find the correct __new__, checking to see of one was defined # now find the correct __new__, checking to see of one was defined
# by the user; also check earlier enum classes in case a __new__ was # by the user; also check earlier enum classes in case a __new__ was
# saved as __member_new__ # saved as __member_new__
__new__ = classdict.get('__new__', None) __new__ = classdict.get("__new__", None)
if __new__: if __new__:
return None, True, True # __new__, save_new, use_args return None, True, True # __new__, save_new, use_args
N__new__ = getattr(None, '__new__') N__new__ = getattr(None, "__new__")
O__new__ = getattr(object, '__new__') O__new__ = getattr(object, "__new__")
if Enum is None: if Enum is None:
E__new__ = N__new__ E__new__ = N__new__
else: else:
E__new__ = Enum.__dict__['__new__'] E__new__ = Enum.__dict__["__new__"]
# check all possibles for __member_new__ before falling back to # check all possibles for __member_new__ before falling back to
# __new__ # __new__
for method in ('__member_new__', '__new__'): for method in ("__member_new__", "__new__"):
for possible in (member_type, first_enum): for possible in (member_type, first_enum):
try: try:
target = possible.__dict__[method] target = possible.__dict__[method]
except (AttributeError, KeyError): except (AttributeError, KeyError):
target = getattr(possible, method, None) target = getattr(possible, method, None)
if target not in [ if target not in [None, N__new__, O__new__, E__new__]:
None, if method == "__member_new__":
N__new__, classdict["__new__"] = target
O__new__,
E__new__,
]:
if method == '__member_new__':
classdict['__new__'] = target
return None, False, True return None, False, True
if isinstance(target, staticmethod): if isinstance(target, staticmethod):
target = target.__get__(member_type) target = target.__get__(member_type)
@ -585,7 +602,9 @@ class EnumMeta(type):
use_args = True use_args = True
return __new__, False, use_args return __new__, False, use_args
else: else:
@staticmethod @staticmethod
def _find_new_(classdict, member_type, first_enum): def _find_new_(classdict, member_type, first_enum):
"""Returns the __new__ to be used for creating the enum members. """Returns the __new__ to be used for creating the enum members.
@ -598,7 +617,7 @@ class EnumMeta(type):
# now find the correct __new__, checking to see of one was defined # now find the correct __new__, checking to see of one was defined
# by the user; also check earlier enum classes in case a __new__ was # by the user; also check earlier enum classes in case a __new__ was
# saved as __member_new__ # saved as __member_new__
__new__ = classdict.get('__new__', None) __new__ = classdict.get("__new__", None)
# should __new__ be saved as __member_new__ later? # should __new__ be saved as __member_new__ later?
save_new = __new__ is not None save_new = __new__ is not None
@ -606,7 +625,7 @@ class EnumMeta(type):
if __new__ is None: if __new__ is None:
# check all possibles for __member_new__ before falling back to # check all possibles for __member_new__ before falling back to
# __new__ # __new__
for method in ('__member_new__', '__new__'): for method in ("__member_new__", "__new__"):
for possible in (member_type, first_enum): for possible in (member_type, first_enum):
target = getattr(possible, method, None) target = getattr(possible, method, None)
if target not in ( if target not in (
@ -640,7 +659,9 @@ class EnumMeta(type):
# create the class. # create the class.
######################################################## ########################################################
temp_enum_dict = {} temp_enum_dict = {}
temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" temp_enum_dict[
"__doc__"
] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n"
def __new__(cls, value): def __new__(cls, value):
@ -661,39 +682,40 @@ def __new__(cls, value):
for member in cls._member_map_.values(): for member in cls._member_map_.values():
if member.value == value: if member.value == value:
return member return member
raise ValueError("%s is not a valid %s" % (value, cls.__name__)) raise ValueError("{} is not a valid {}".format(value, cls.__name__))
temp_enum_dict['__new__'] = __new__ temp_enum_dict["__new__"] = __new__
del __new__ del __new__
def __repr__(self): def __repr__(self):
return "<%s.%s: %r>" % ( return "<{}.{}: {!r}>".format(self.__class__.__name__, self._name_, self._value_)
self.__class__.__name__, self._name_, self._value_)
temp_enum_dict['__repr__'] = __repr__ temp_enum_dict["__repr__"] = __repr__
del __repr__ del __repr__
def __str__(self): def __str__(self):
return "%s.%s" % (self.__class__.__name__, self._name_) return "{}.{}".format(self.__class__.__name__, self._name_)
temp_enum_dict['__str__'] = __str__ temp_enum_dict["__str__"] = __str__
del __str__ del __str__
if pyver >= 3.0: if pyver >= 3.0:
def __dir__(self): def __dir__(self):
added_behavior = [ added_behavior = [
m m
for cls in self.__class__.mro() for cls in self.__class__.mro()
for m in cls.__dict__ for m in cls.__dict__
if m[0] != '_' and m not in self._member_map_ if m[0] != "_" and m not in self._member_map_
] ]
return (['__class__', '__doc__', '__module__', ] + added_behavior) return ["__class__", "__doc__", "__module__"] + added_behavior
temp_enum_dict['__dir__'] = __dir__
temp_enum_dict["__dir__"] = __dir__
del __dir__ del __dir__
@ -713,7 +735,7 @@ def __format__(self, format_spec):
return cls.__format__(val, format_spec) return cls.__format__(val, format_spec)
temp_enum_dict['__format__'] = __format__ temp_enum_dict["__format__"] = __format__
del __format__ del __format__
@ -728,30 +750,50 @@ if pyver < 2.6:
return 0 return 0
return -1 return -1
return NotImplemented return NotImplemented
raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) raise TypeError(
temp_enum_dict['__cmp__'] = __cmp__ "unorderable types: %s() and %s()"
% (self.__class__.__name__, other.__class__.__name__)
)
temp_enum_dict["__cmp__"] = __cmp__
del __cmp__ del __cmp__
else: else:
def __le__(self, other): def __le__(self, other):
raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) raise TypeError(
temp_enum_dict['__le__'] = __le__ "unorderable types: %s() <= %s()"
% (self.__class__.__name__, other.__class__.__name__)
)
temp_enum_dict["__le__"] = __le__
del __le__ del __le__
def __lt__(self, other): def __lt__(self, other):
raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) raise TypeError(
temp_enum_dict['__lt__'] = __lt__ "unorderable types: %s() < %s()"
% (self.__class__.__name__, other.__class__.__name__)
)
temp_enum_dict["__lt__"] = __lt__
del __lt__ del __lt__
def __ge__(self, other): def __ge__(self, other):
raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) raise TypeError(
temp_enum_dict['__ge__'] = __ge__ "unorderable types: %s() >= %s()"
% (self.__class__.__name__, other.__class__.__name__)
)
temp_enum_dict["__ge__"] = __ge__
del __ge__ del __ge__
def __gt__(self, other): def __gt__(self, other):
raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) raise TypeError(
temp_enum_dict['__gt__'] = __gt__ "unorderable types: %s() > %s()"
% (self.__class__.__name__, other.__class__.__name__)
)
temp_enum_dict["__gt__"] = __gt__
del __gt__ del __gt__
@ -761,7 +803,7 @@ def __eq__(self, other):
return NotImplemented return NotImplemented
temp_enum_dict['__eq__'] = __eq__ temp_enum_dict["__eq__"] = __eq__
del __eq__ del __eq__
@ -771,7 +813,7 @@ def __ne__(self, other):
return NotImplemented return NotImplemented
temp_enum_dict['__ne__'] = __ne__ temp_enum_dict["__ne__"] = __ne__
del __ne__ del __ne__
@ -779,7 +821,7 @@ def __hash__(self):
return hash(self._name_) return hash(self._name_)
temp_enum_dict['__hash__'] = __hash__ temp_enum_dict["__hash__"] = __hash__
del __hash__ del __hash__
@ -787,7 +829,7 @@ def __reduce_ex__(self, proto):
return self.__class__, (self._value_,) return self.__class__, (self._value_,)
temp_enum_dict['__reduce_ex__'] = __reduce_ex__ temp_enum_dict["__reduce_ex__"] = __reduce_ex__
del __reduce_ex__ del __reduce_ex__
# _RouteClassAttributeToGetattr is used to provide access to the `name` # _RouteClassAttributeToGetattr is used to provide access to the `name`
@ -803,7 +845,7 @@ def name(self):
return self._name_ return self._name_
temp_enum_dict['name'] = name temp_enum_dict["name"] = name
del name del name
@ -812,7 +854,7 @@ def value(self):
return self._value_ return self._value_
temp_enum_dict['value'] = value temp_enum_dict["value"] = value
del value del value
@ -839,10 +881,10 @@ def _convert(cls, name, module, filter, source=None):
return cls return cls
temp_enum_dict['_convert'] = _convert temp_enum_dict["_convert"] = _convert
del _convert del _convert
Enum = EnumMeta('Enum', (object, ), temp_enum_dict) Enum = EnumMeta("Enum", (object,), temp_enum_dict)
del temp_enum_dict del temp_enum_dict
# Enum has now been created # Enum has now been created
@ -864,10 +906,10 @@ def unique(enumeration):
if name != member.name: if name != member.name:
duplicates.append((name, member.name)) duplicates.append((name, member.name))
if duplicates: if duplicates:
duplicate_names = ', '.join( duplicate_names = ", ".join(
["%s -> %s" % (alias, name) for (alias, name) in duplicates] ["{} -> {}".format(alias, name) for (alias, name) in duplicates]
) )
raise ValueError('duplicate names found in %r: %s' % raise ValueError(
(enumeration, duplicate_names) "duplicate names found in {!r}: {}".format(enumeration, duplicate_names)
) )
return enumeration return enumeration

View File

@ -1,19 +1,23 @@
is_init_subclass_available = hasattr(object, '__init_subclass__') is_init_subclass_available = hasattr(object, "__init_subclass__")
if not is_init_subclass_available: if not is_init_subclass_available:
class InitSubclassMeta(type): class InitSubclassMeta(type):
"""Metaclass that implements PEP 487 protocol""" """Metaclass that implements PEP 487 protocol"""
def __new__(cls, name, bases, ns, **kwargs): def __new__(cls, name, bases, ns, **kwargs):
__init_subclass__ = ns.pop('__init_subclass__', None) __init_subclass__ = ns.pop("__init_subclass__", None)
if __init_subclass__: if __init_subclass__:
__init_subclass__ = classmethod(__init_subclass__) __init_subclass__ = classmethod(__init_subclass__)
ns['__init_subclass__'] = __init_subclass__ ns["__init_subclass__"] = __init_subclass__
return super(InitSubclassMeta, cls).__new__(cls, name, bases, ns, **kwargs) return super(InitSubclassMeta, cls).__new__(cls, name, bases, ns, **kwargs)
def __init__(cls, name, bases, ns, **kwargs): def __init__(cls, name, bases, ns, **kwargs):
super(InitSubclassMeta, cls).__init__(name, bases, ns) super(InitSubclassMeta, cls).__init__(name, bases, ns)
super_class = super(cls, cls) super_class = super(cls, cls)
if hasattr(super_class, '__init_subclass__'): if hasattr(super_class, "__init_subclass__"):
super_class.__init_subclass__.__func__(cls, **kwargs) super_class.__init_subclass__.__func__(cls, **kwargs)
else: else:
InitSubclassMeta = type # type: ignore InitSubclassMeta = type # type: ignore

View File

@ -13,22 +13,24 @@ from collections import OrderedDict
__version__ = "0.4" __version__ = "0.4"
__all__ = ['BoundArguments', 'Parameter', 'Signature', 'signature'] __all__ = ["BoundArguments", "Parameter", "Signature", "signature"]
_WrapperDescriptor = type(type.__call__) _WrapperDescriptor = type(type.__call__)
_MethodWrapper = type(all.__call__) _MethodWrapper = type(all.__call__)
_NonUserDefinedCallables = (_WrapperDescriptor, _NonUserDefinedCallables = (
_WrapperDescriptor,
_MethodWrapper, _MethodWrapper,
types.BuiltinFunctionType) types.BuiltinFunctionType,
)
def formatannotation(annotation, base_module=None): def formatannotation(annotation, base_module=None):
if isinstance(annotation, type): if isinstance(annotation, type):
if annotation.__module__ in ('builtins', '__builtin__', base_module): if annotation.__module__ in ("builtins", "__builtin__", base_module):
return annotation.__name__ return annotation.__name__
return annotation.__module__ + '.' + annotation.__name__ return annotation.__module__ + "." + annotation.__name__
return repr(annotation) return repr(annotation)
@ -49,20 +51,20 @@ def _get_user_defined_method(cls, method_name, *nested):
def signature(obj): def signature(obj):
'''Get a signature object for the passed callable.''' """Get a signature object for the passed callable."""
if not callable(obj): if not callable(obj):
raise TypeError('{!r} is not a callable object'.format(obj)) raise TypeError("{!r} is not a callable object".format(obj))
if isinstance(obj, types.MethodType): if isinstance(obj, types.MethodType):
sig = signature(obj.__func__) sig = signature(obj.__func__)
if obj.__self__ is None: if obj.__self__ is None:
# Unbound method: the first parameter becomes positional-only # Unbound method: the first parameter becomes positional-only
if sig.parameters: if sig.parameters:
first = sig.parameters.values()[0].replace( first = sig.parameters.values()[0].replace(kind=_POSITIONAL_ONLY)
kind=_POSITIONAL_ONLY)
return sig.replace( return sig.replace(
parameters=(first,) + tuple(sig.parameters.values())[1:]) parameters=(first,) + tuple(sig.parameters.values())[1:]
)
else: else:
return sig return sig
else: else:
@ -99,7 +101,7 @@ def signature(obj):
try: try:
ba = sig.bind_partial(*partial_args, **partial_keywords) ba = sig.bind_partial(*partial_args, **partial_keywords)
except TypeError as ex: except TypeError as ex:
msg = 'partial object {!r} has incorrect arguments'.format(obj) msg = "partial object {!r} has incorrect arguments".format(obj)
raise ValueError(msg) raise ValueError(msg)
for arg_name, arg_value in ba.arguments.items(): for arg_name, arg_value in ba.arguments.items():
@ -122,11 +124,14 @@ def signature(obj):
# flag. Later, in '_bind', the 'default' value of this # flag. Later, in '_bind', the 'default' value of this
# parameter will be added to 'kwargs', to simulate # parameter will be added to 'kwargs', to simulate
# the 'functools.partial' real call. # the 'functools.partial' real call.
new_params[arg_name] = param.replace(default=arg_value, new_params[arg_name] = param.replace(
_partial_kwarg=True) default=arg_value, _partial_kwarg=True
)
elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and elif (
not param._partial_kwarg): param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL)
and not param._partial_kwarg
):
new_params.pop(arg_name) new_params.pop(arg_name)
return sig.replace(parameters=new_params.values()) return sig.replace(parameters=new_params.values())
@ -137,17 +142,17 @@ def signature(obj):
# First, let's see if it has an overloaded __call__ defined # First, let's see if it has an overloaded __call__ defined
# in its metaclass # in its metaclass
call = _get_user_defined_method(type(obj), '__call__') call = _get_user_defined_method(type(obj), "__call__")
if call is not None: if call is not None:
sig = signature(call) sig = signature(call)
else: else:
# Now we check if the 'obj' class has a '__new__' method # Now we check if the 'obj' class has a '__new__' method
new = _get_user_defined_method(obj, '__new__') new = _get_user_defined_method(obj, "__new__")
if new is not None: if new is not None:
sig = signature(new) sig = signature(new)
else: else:
# Finally, we should have at least __init__ implemented # Finally, we should have at least __init__ implemented
init = _get_user_defined_method(obj, '__init__') init = _get_user_defined_method(obj, "__init__")
if init is not None: if init is not None:
sig = signature(init) sig = signature(init)
elif not isinstance(obj, _NonUserDefinedCallables): elif not isinstance(obj, _NonUserDefinedCallables):
@ -155,7 +160,7 @@ def signature(obj):
# We also check that the 'obj' is not an instance of # We also check that the 'obj' is not an instance of
# _WrapperDescriptor or _MethodWrapper to avoid # _WrapperDescriptor or _MethodWrapper to avoid
# infinite recursion (and even potential segfault) # infinite recursion (and even potential segfault)
call = _get_user_defined_method(type(obj), '__call__', 'im_func') call = _get_user_defined_method(type(obj), "__call__", "im_func")
if call is not None: if call is not None:
sig = signature(call) sig = signature(call)
@ -166,14 +171,14 @@ def signature(obj):
if isinstance(obj, types.BuiltinFunctionType): if isinstance(obj, types.BuiltinFunctionType):
# Raise a nicer error message for builtins # Raise a nicer error message for builtins
msg = 'no signature found for builtin function {!r}'.format(obj) msg = "no signature found for builtin function {!r}".format(obj)
raise ValueError(msg) raise ValueError(msg)
raise ValueError('callable {!r} is not supported by signature'.format(obj)) raise ValueError("callable {!r} is not supported by signature".format(obj))
class _void(object): class _void(object):
'''A private marker - used in Parameter & Signature''' """A private marker - used in Parameter & Signature"""
class _empty(object): class _empty(object):
@ -183,25 +188,25 @@ class _empty(object):
class _ParameterKind(int): class _ParameterKind(int):
def __new__(self, *args, **kwargs): def __new__(self, *args, **kwargs):
obj = int.__new__(self, *args) obj = int.__new__(self, *args)
obj._name = kwargs['name'] obj._name = kwargs["name"]
return obj return obj
def __str__(self): def __str__(self):
return self._name return self._name
def __repr__(self): def __repr__(self):
return '<_ParameterKind: {!r}>'.format(self._name) return "<_ParameterKind: {!r}>".format(self._name)
_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY') _POSITIONAL_ONLY = _ParameterKind(0, name="POSITIONAL_ONLY")
_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD') _POSITIONAL_OR_KEYWORD = _ParameterKind(1, name="POSITIONAL_OR_KEYWORD")
_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL') _VAR_POSITIONAL = _ParameterKind(2, name="VAR_POSITIONAL")
_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY') _KEYWORD_ONLY = _ParameterKind(3, name="KEYWORD_ONLY")
_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD') _VAR_KEYWORD = _ParameterKind(4, name="VAR_KEYWORD")
class Parameter(object): class Parameter(object):
'''Represents a parameter in a function signature. """Represents a parameter in a function signature.
Has the following public attributes: Has the following public attributes:
* name : str * name : str
The name of the parameter as a string. The name of the parameter as a string.
@ -216,9 +221,9 @@ class Parameter(object):
Possible values: `Parameter.POSITIONAL_ONLY`, Possible values: `Parameter.POSITIONAL_ONLY`,
`Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`, `Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`,
`Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`. `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
''' """
__slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg') __slots__ = ("_name", "_kind", "_default", "_annotation", "_partial_kwarg")
POSITIONAL_ONLY = _POSITIONAL_ONLY POSITIONAL_ONLY = _POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
@ -228,30 +233,37 @@ class Parameter(object):
empty = _empty empty = _empty
def __init__(self, name, kind, default=_empty, annotation=_empty, def __init__(
_partial_kwarg=False): self, name, kind, default=_empty, annotation=_empty, _partial_kwarg=False
):
if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD, if kind not in (
_VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD): _POSITIONAL_ONLY,
_POSITIONAL_OR_KEYWORD,
_VAR_POSITIONAL,
_KEYWORD_ONLY,
_VAR_KEYWORD,
):
raise ValueError("invalid value for 'Parameter.kind' attribute") raise ValueError("invalid value for 'Parameter.kind' attribute")
self._kind = kind self._kind = kind
if default is not _empty: if default is not _empty:
if kind in (_VAR_POSITIONAL, _VAR_KEYWORD): if kind in (_VAR_POSITIONAL, _VAR_KEYWORD):
msg = '{} parameters cannot have default values'.format(kind) msg = "{} parameters cannot have default values".format(kind)
raise ValueError(msg) raise ValueError(msg)
self._default = default self._default = default
self._annotation = annotation self._annotation = annotation
if name is None: if name is None:
if kind != _POSITIONAL_ONLY: if kind != _POSITIONAL_ONLY:
raise ValueError("None is not a valid name for a " raise ValueError(
"non-positional-only parameter") "None is not a valid name for a " "non-positional-only parameter"
)
self._name = name self._name = name
else: else:
name = str(name) name = str(name)
if kind != _POSITIONAL_ONLY and not re.match(r'[a-z_]\w*$', name, re.I): if kind != _POSITIONAL_ONLY and not re.match(r"[a-z_]\w*$", name, re.I):
msg = '{!r} is not a valid parameter name'.format(name) msg = "{!r} is not a valid parameter name".format(name)
raise ValueError(msg) raise ValueError(msg)
self._name = name self._name = name
@ -273,9 +285,15 @@ class Parameter(object):
def kind(self): def kind(self):
return self._kind return self._kind
def replace(self, name=_void, kind=_void, annotation=_void, def replace(
default=_void, _partial_kwarg=_void): self,
'''Creates a customized copy of the Parameter.''' name=_void,
kind=_void,
annotation=_void,
default=_void,
_partial_kwarg=_void,
):
"""Creates a customized copy of the Parameter."""
if name is _void: if name is _void:
name = self._name name = self._name
@ -292,8 +310,13 @@ class Parameter(object):
if _partial_kwarg is _void: if _partial_kwarg is _void:
_partial_kwarg = self._partial_kwarg _partial_kwarg = self._partial_kwarg
return type(self)(name, kind, default=default, annotation=annotation, return type(self)(
_partial_kwarg=_partial_kwarg) name,
kind,
default=default,
annotation=annotation,
_partial_kwarg=_partial_kwarg,
)
def __str__(self): def __str__(self):
kind = self.kind kind = self.kind
@ -301,45 +324,45 @@ class Parameter(object):
formatted = self._name formatted = self._name
if kind == _POSITIONAL_ONLY: if kind == _POSITIONAL_ONLY:
if formatted is None: if formatted is None:
formatted = '' formatted = ""
formatted = '<{}>'.format(formatted) formatted = "<{}>".format(formatted)
# Add annotation and default value # Add annotation and default value
if self._annotation is not _empty: if self._annotation is not _empty:
formatted = '{}:{}'.format(formatted, formatted = "{}:{}".format(formatted, formatannotation(self._annotation))
formatannotation(self._annotation))
if self._default is not _empty: if self._default is not _empty:
formatted = '{}={}'.format(formatted, repr(self._default)) formatted = "{}={}".format(formatted, repr(self._default))
if kind == _VAR_POSITIONAL: if kind == _VAR_POSITIONAL:
formatted = '*' + formatted formatted = "*" + formatted
elif kind == _VAR_KEYWORD: elif kind == _VAR_KEYWORD:
formatted = '**' + formatted formatted = "**" + formatted
return formatted return formatted
def __repr__(self): def __repr__(self):
return '<{} at {:#x} {!r}>'.format(self.__class__.__name__, return "<{} at {:#x} {!r}>".format(self.__class__.__name__, id(self), self.name)
id(self), self.name)
def __hash__(self): def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__) msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg) raise TypeError(msg)
def __eq__(self, other): def __eq__(self, other):
return (issubclass(other.__class__, Parameter) and return (
self._name == other._name and issubclass(other.__class__, Parameter)
self._kind == other._kind and and self._name == other._name
self._default == other._default and and self._kind == other._kind
self._annotation == other._annotation) and self._default == other._default
and self._annotation == other._annotation
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
class BoundArguments(object): class BoundArguments(object):
'''Result of `Signature.bind` call. Holds the mapping of arguments """Result of `Signature.bind` call. Holds the mapping of arguments
to the function's parameters. to the function's parameters.
Has the following public attributes: Has the following public attributes:
* arguments : OrderedDict * arguments : OrderedDict
@ -351,7 +374,7 @@ class BoundArguments(object):
Tuple of positional arguments values. Tuple of positional arguments values.
* kwargs : dict * kwargs : dict
Dict of keyword arguments values. Dict of keyword arguments values.
''' """
def __init__(self, signature, arguments): def __init__(self, signature, arguments):
self.arguments = arguments self.arguments = arguments
@ -365,8 +388,7 @@ class BoundArguments(object):
def args(self): def args(self):
args = [] args = []
for param_name, param in self._signature.parameters.items(): for param_name, param in self._signature.parameters.items():
if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
param._partial_kwarg):
# Keyword arguments mapped by 'functools.partial' # Keyword arguments mapped by 'functools.partial'
# (Parameter._partial_kwarg is True) are mapped # (Parameter._partial_kwarg is True) are mapped
# in 'BoundArguments.kwargs', along with VAR_KEYWORD & # in 'BoundArguments.kwargs', along with VAR_KEYWORD &
@ -395,8 +417,7 @@ class BoundArguments(object):
kwargs_started = False kwargs_started = False
for param_name, param in self._signature.parameters.items(): for param_name, param in self._signature.parameters.items():
if not kwargs_started: if not kwargs_started:
if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
param._partial_kwarg):
kwargs_started = True kwargs_started = True
else: else:
if param_name not in self.arguments: if param_name not in self.arguments:
@ -425,16 +446,18 @@ class BoundArguments(object):
raise TypeError(msg) raise TypeError(msg)
def __eq__(self, other): def __eq__(self, other):
return (issubclass(other.__class__, BoundArguments) and return (
self.signature == other.signature and issubclass(other.__class__, BoundArguments)
self.arguments == other.arguments) and self.signature == other.signature
and self.arguments == other.arguments
)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
class Signature(object): class Signature(object):
'''A Signature object represents the overall signature of a function. """A Signature object represents the overall signature of a function.
It stores a Parameter object for each parameter accepted by the It stores a Parameter object for each parameter accepted by the
function, as well as information specific to the function itself. function, as well as information specific to the function itself.
A Signature object has the following public attributes and methods: A Signature object has the following public attributes and methods:
@ -452,20 +475,21 @@ class Signature(object):
* bind_partial(*args, **kwargs) -> BoundArguments * bind_partial(*args, **kwargs) -> BoundArguments
Creates a partial mapping from positional and keyword arguments Creates a partial mapping from positional and keyword arguments
to parameters (simulating 'functools.partial' behavior.) to parameters (simulating 'functools.partial' behavior.)
''' """
__slots__ = ('_return_annotation', '_parameters') __slots__ = ("_return_annotation", "_parameters")
_parameter_cls = Parameter _parameter_cls = Parameter
_bound_arguments_cls = BoundArguments _bound_arguments_cls = BoundArguments
empty = _empty empty = _empty
def __init__(self, parameters=None, return_annotation=_empty, def __init__(
__validate_parameters__=True): self, parameters=None, return_annotation=_empty, __validate_parameters__=True
'''Constructs Signature from the given list of Parameter ):
"""Constructs Signature from the given list of Parameter
objects and 'return_annotation'. All arguments are optional. objects and 'return_annotation'. All arguments are optional.
''' """
if parameters is None: if parameters is None:
params = OrderedDict() params = OrderedDict()
@ -477,7 +501,7 @@ class Signature(object):
for idx, param in enumerate(parameters): for idx, param in enumerate(parameters):
kind = param.kind kind = param.kind
if kind < top_kind: if kind < top_kind:
msg = 'wrong parameter order: {0} before {1}' msg = "wrong parameter order: {0} before {1}"
msg = msg.format(top_kind, param.kind) msg = msg.format(top_kind, param.kind)
raise ValueError(msg) raise ValueError(msg)
else: else:
@ -489,22 +513,21 @@ class Signature(object):
param = param.replace(name=name) param = param.replace(name=name)
if name in params: if name in params:
msg = 'duplicate parameter name: {!r}'.format(name) msg = "duplicate parameter name: {!r}".format(name)
raise ValueError(msg) raise ValueError(msg)
params[name] = param params[name] = param
else: else:
params = OrderedDict(((param.name, param) params = OrderedDict(((param.name, param) for param in parameters))
for param in parameters))
self._parameters = params self._parameters = params
self._return_annotation = return_annotation self._return_annotation = return_annotation
@classmethod @classmethod
def from_function(cls, func): def from_function(cls, func):
'''Constructs Signature for the given python function''' """Constructs Signature for the given python function"""
if not isinstance(func, types.FunctionType): if not isinstance(func, types.FunctionType):
raise TypeError('{!r} is not a Python function'.format(func)) raise TypeError("{!r} is not a Python function".format(func))
Parameter = cls._parameter_cls Parameter = cls._parameter_cls
@ -513,11 +536,11 @@ class Signature(object):
pos_count = func_code.co_argcount pos_count = func_code.co_argcount
arg_names = func_code.co_varnames arg_names = func_code.co_varnames
positional = tuple(arg_names[:pos_count]) positional = tuple(arg_names[:pos_count])
keyword_only_count = getattr(func_code, 'co_kwonlyargcount', 0) keyword_only_count = getattr(func_code, "co_kwonlyargcount", 0)
keyword_only = arg_names[pos_count : (pos_count + keyword_only_count)] keyword_only = arg_names[pos_count : (pos_count + keyword_only_count)]
annotations = getattr(func, '__annotations__', {}) annotations = getattr(func, "__annotations__", {})
defaults = func.__defaults__ defaults = func.__defaults__
kwdefaults = getattr(func, '__kwdefaults__', None) kwdefaults = getattr(func, "__kwdefaults__", None)
if defaults: if defaults:
pos_default_count = len(defaults) pos_default_count = len(defaults)
@ -530,22 +553,29 @@ class Signature(object):
non_default_count = pos_count - pos_default_count non_default_count = pos_count - pos_default_count
for name in positional[:non_default_count]: for name in positional[:non_default_count]:
annotation = annotations.get(name, _empty) annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, parameters.append(
kind=_POSITIONAL_OR_KEYWORD)) Parameter(name, annotation=annotation, kind=_POSITIONAL_OR_KEYWORD)
)
# ... w/ defaults. # ... w/ defaults.
for offset, name in enumerate(positional[non_default_count:]): for offset, name in enumerate(positional[non_default_count:]):
annotation = annotations.get(name, _empty) annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, parameters.append(
Parameter(
name,
annotation=annotation,
kind=_POSITIONAL_OR_KEYWORD, kind=_POSITIONAL_OR_KEYWORD,
default=defaults[offset])) default=defaults[offset],
)
)
# *args # *args
if func_code.co_flags & 0x04: if func_code.co_flags & 0x04:
name = arg_names[pos_count + keyword_only_count] name = arg_names[pos_count + keyword_only_count]
annotation = annotations.get(name, _empty) annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, parameters.append(
kind=_VAR_POSITIONAL)) Parameter(name, annotation=annotation, kind=_VAR_POSITIONAL)
)
# Keyword-only parameters. # Keyword-only parameters.
for name in keyword_only: for name in keyword_only:
@ -554,9 +584,11 @@ class Signature(object):
default = kwdefaults.get(name, _empty) default = kwdefaults.get(name, _empty)
annotation = annotations.get(name, _empty) annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, parameters.append(
kind=_KEYWORD_ONLY, Parameter(
default=default)) name, annotation=annotation, kind=_KEYWORD_ONLY, default=default
)
)
# **kwargs # **kwargs
if func_code.co_flags & 0x08: if func_code.co_flags & 0x08:
index = pos_count + keyword_only_count index = pos_count + keyword_only_count
@ -565,12 +597,13 @@ class Signature(object):
name = arg_names[index] name = arg_names[index]
annotation = annotations.get(name, _empty) annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, parameters.append(Parameter(name, annotation=annotation, kind=_VAR_KEYWORD))
kind=_VAR_KEYWORD))
return cls(parameters, return cls(
return_annotation=annotations.get('return', _empty), parameters,
__validate_parameters__=False) return_annotation=annotations.get("return", _empty),
__validate_parameters__=False,
)
@property @property
def parameters(self): def parameters(self):
@ -584,10 +617,10 @@ class Signature(object):
return self._return_annotation return self._return_annotation
def replace(self, parameters=_void, return_annotation=_void): def replace(self, parameters=_void, return_annotation=_void):
'''Creates a customized copy of the Signature. """Creates a customized copy of the Signature.
Pass 'parameters' and/or 'return_annotation' arguments Pass 'parameters' and/or 'return_annotation' arguments
to override them in the new copy. to override them in the new copy.
''' """
if parameters is _void: if parameters is _void:
parameters = self.parameters.values() parameters = self.parameters.values()
@ -595,21 +628,23 @@ class Signature(object):
if return_annotation is _void: if return_annotation is _void:
return_annotation = self._return_annotation return_annotation = self._return_annotation
return type(self)(parameters, return type(self)(parameters, return_annotation=return_annotation)
return_annotation=return_annotation)
def __hash__(self): def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__) msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg) raise TypeError(msg)
def __eq__(self, other): def __eq__(self, other):
if (not issubclass(type(other), Signature) or if (
self.return_annotation != other.return_annotation or not issubclass(type(other), Signature)
len(self.parameters) != len(other.parameters)): or self.return_annotation != other.return_annotation
or len(self.parameters) != len(other.parameters)
):
return False return False
other_positions = {param: idx other_positions = {
for idx, param in enumerate(other.parameters.keys())} param: idx for idx, param in enumerate(other.parameters.keys())
}
for idx, (param_name, param) in enumerate(self.parameters.items()): for idx, (param_name, param) in enumerate(self.parameters.items()):
if param.kind == _KEYWORD_ONLY: if param.kind == _KEYWORD_ONLY:
@ -626,8 +661,7 @@ class Signature(object):
except KeyError: except KeyError:
return False return False
else: else:
if (idx != other_idx or if idx != other_idx or param != other.parameters[param_name]:
param != other.parameters[param_name]):
return False return False
return True return True
@ -636,7 +670,7 @@ class Signature(object):
return not self.__eq__(other) return not self.__eq__(other)
def _bind(self, args, kwargs, partial=False): def _bind(self, args, kwargs, partial=False):
'''Private method. Don't use directly.''' """Private method. Don't use directly."""
arguments = OrderedDict() arguments = OrderedDict()
@ -649,7 +683,7 @@ class Signature(object):
# See 'functools.partial' case in 'signature()' implementation # See 'functools.partial' case in 'signature()' implementation
# for details. # for details.
for param_name, param in self.parameters.items(): for param_name, param in self.parameters.items():
if (param._partial_kwarg and param_name not in kwargs): if param._partial_kwarg and param_name not in kwargs:
# Simulating 'functools.partial' behavior # Simulating 'functools.partial' behavior
kwargs[param_name] = param.default kwargs[param_name] = param.default
@ -673,14 +707,12 @@ class Signature(object):
break break
elif param.name in kwargs: elif param.name in kwargs:
if param.kind == _POSITIONAL_ONLY: if param.kind == _POSITIONAL_ONLY:
msg = '{arg!r} parameter is positional only, ' \ msg = "{arg!r} parameter is positional only, " "but was passed as a keyword"
'but was passed as a keyword'
msg = msg.format(arg=param.name) msg = msg.format(arg=param.name)
raise TypeError(msg) raise TypeError(msg)
parameters_ex = (param,) parameters_ex = (param,)
break break
elif (param.kind == _VAR_KEYWORD or elif param.kind == _VAR_KEYWORD or param.default is not _empty:
param.default is not _empty):
# That's fine too - we have a default value for this # That's fine too - we have a default value for this
# parameter. So, lets start parsing `kwargs`, starting # parameter. So, lets start parsing `kwargs`, starting
# with the current parameter # with the current parameter
@ -691,7 +723,7 @@ class Signature(object):
parameters_ex = (param,) parameters_ex = (param,)
break break
else: else:
msg = '{arg!r} parameter lacking default value' msg = "{arg!r} parameter lacking default value"
msg = msg.format(arg=param.name) msg = msg.format(arg=param.name)
raise TypeError(msg) raise TypeError(msg)
else: else:
@ -699,12 +731,12 @@ class Signature(object):
try: try:
param = next(parameters) param = next(parameters)
except StopIteration: except StopIteration:
raise TypeError('too many positional arguments') raise TypeError("too many positional arguments")
else: else:
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY): if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
# Looks like we have no parameter for this positional # Looks like we have no parameter for this positional
# argument # argument
raise TypeError('too many positional arguments') raise TypeError("too many positional arguments")
if param.kind == _VAR_POSITIONAL: if param.kind == _VAR_POSITIONAL:
# We have an '*args'-like argument, let's fill it with # We have an '*args'-like argument, let's fill it with
@ -716,8 +748,10 @@ class Signature(object):
break break
if param.name in kwargs: if param.name in kwargs:
raise TypeError('multiple values for argument ' raise TypeError(
'{arg!r}'.format(arg=param.name)) "multiple values for argument "
"{arg!r}".format(arg=param.name)
)
arguments[param.name] = arg_val arguments[param.name] = arg_val
@ -729,9 +763,10 @@ class Signature(object):
# This should never happen in case of a properly built # This should never happen in case of a properly built
# Signature object (but let's have this check here # Signature object (but let's have this check here
# to ensure correct behaviour just in case) # to ensure correct behaviour just in case)
raise TypeError('{arg!r} parameter is positional only, ' raise TypeError(
'but was passed as a keyword'. "{arg!r} parameter is positional only, "
format(arg=param.name)) "but was passed as a keyword".format(arg=param.name)
)
if param.kind == _VAR_KEYWORD: if param.kind == _VAR_KEYWORD:
# Memorize that we have a '**kwargs'-like parameter # Memorize that we have a '**kwargs'-like parameter
@ -746,10 +781,14 @@ class Signature(object):
# if it has a default value, or it is an '*args'-like # if it has a default value, or it is an '*args'-like
# parameter, left alone by the processing of positional # parameter, left alone by the processing of positional
# arguments. # arguments.
if (not partial and param.kind != _VAR_POSITIONAL and if (
param.default is _empty): not partial
raise TypeError('{arg!r} parameter lacking default value'. and param.kind != _VAR_POSITIONAL
format(arg=param_name)) and param.default is _empty
):
raise TypeError(
"{arg!r} parameter lacking default value".format(arg=param_name)
)
else: else:
arguments[param_name] = arg_val arguments[param_name] = arg_val
@ -759,22 +798,22 @@ class Signature(object):
# Process our '**kwargs'-like parameter # Process our '**kwargs'-like parameter
arguments[kwargs_param.name] = kwargs arguments[kwargs_param.name] = kwargs
else: else:
raise TypeError('too many keyword arguments') raise TypeError("too many keyword arguments")
return self._bound_arguments_cls(self, arguments) return self._bound_arguments_cls(self, arguments)
def bind(self, *args, **kwargs): def bind(self, *args, **kwargs):
'''Get a BoundArguments object, that maps the passed `args` """Get a BoundArguments object, that maps the passed `args`
and `kwargs` to the function's signature. Raises `TypeError` and `kwargs` to the function's signature. Raises `TypeError`
if the passed arguments can not be bound. if the passed arguments can not be bound.
''' """
return self._bind(args, kwargs) return self._bind(args, kwargs)
def bind_partial(self, *args, **kwargs): def bind_partial(self, *args, **kwargs):
'''Get a BoundArguments object, that partially maps the """Get a BoundArguments object, that partially maps the
passed `args` and `kwargs` to the function's signature. passed `args` and `kwargs` to the function's signature.
Raises `TypeError` if the passed arguments can not be bound. Raises `TypeError` if the passed arguments can not be bound.
''' """
return self._bind(args, kwargs, partial=True) return self._bind(args, kwargs, partial=True)
def __str__(self): def __str__(self):
@ -792,17 +831,17 @@ class Signature(object):
# We have a keyword-only parameter to render and we haven't # We have a keyword-only parameter to render and we haven't
# rendered an '*args'-like parameter before, so add a '*' # rendered an '*args'-like parameter before, so add a '*'
# separator to the parameters list ("foo(arg1, *, arg2)" case) # separator to the parameters list ("foo(arg1, *, arg2)" case)
result.append('*') result.append("*")
# This condition should be only triggered once, so # This condition should be only triggered once, so
# reset the flag # reset the flag
render_kw_only_separator = False render_kw_only_separator = False
result.append(formatted) result.append(formatted)
rendered = '({})'.format(', '.join(result)) rendered = "({})".format(", ".join(result))
if self.return_annotation is not _empty: if self.return_annotation is not _empty:
anno = formatannotation(self.return_annotation) anno = formatannotation(self.return_annotation)
rendered += ' -> {}'.format(anno) rendered += " -> {}".format(anno)
return rendered return rendered

View File

@ -2,18 +2,8 @@ from ..enum import _is_dunder, _is_sunder
def test__is_dunder(): def test__is_dunder():
dunder_names = [ dunder_names = ["__i__", "__test__"]
'__i__', non_dunder_names = ["test", "__test", "_test", "_test_", "test__", ""]
'__test__',
]
non_dunder_names = [
'test',
'__test',
'_test',
'_test_',
'test__',
'',
]
for name in dunder_names: for name in dunder_names:
assert _is_dunder(name) is True assert _is_dunder(name) is True
@ -23,17 +13,9 @@ def test__is_dunder():
def test__is_sunder(): def test__is_sunder():
sunder_names = [ sunder_names = ["_i_", "_test_"]
'_i_',
'_test_',
]
non_sunder_names = [ non_sunder_names = ["__i__", "_i__", "__i_", ""]
'__i__',
'_i__',
'__i_',
'',
]
for name in sunder_names: for name in sunder_names:
assert _is_sunder(name) is True assert _is_sunder(name) is True

View File

@ -16,15 +16,15 @@ def get_version(version=None):
main = get_main_version(version) main = get_main_version(version)
sub = '' sub = ""
if version[3] == 'alpha' and version[4] == 0: if version[3] == "alpha" and version[4] == 0:
git_changeset = get_git_changeset() git_changeset = get_git_changeset()
if git_changeset: if git_changeset:
sub = '.dev%s' % git_changeset sub = ".dev%s" % git_changeset
else: else:
sub = '.dev' sub = ".dev"
elif version[3] != 'final': elif version[3] != "final":
mapping = {'alpha': 'a', 'beta': 'b', 'rc': 'rc'} mapping = {"alpha": "a", "beta": "b", "rc": "rc"}
sub = mapping[version[3]] + str(version[4]) sub = mapping[version[3]] + str(version[4])
return str(main + sub) return str(main + sub)
@ -34,7 +34,7 @@ def get_main_version(version=None):
"Returns main version (X.Y[.Z]) from VERSION." "Returns main version (X.Y[.Z]) from VERSION."
version = get_complete_version(version) version = get_complete_version(version)
parts = 2 if version[2] == 0 else 3 parts = 2 if version[2] == 0 else 3
return '.'.join(str(x) for x in version[:parts]) return ".".join(str(x) for x in version[:parts])
def get_complete_version(version=None): def get_complete_version(version=None):
@ -45,17 +45,17 @@ def get_complete_version(version=None):
from graphene import VERSION as version from graphene import VERSION as version
else: else:
assert len(version) == 5 assert len(version) == 5
assert version[3] in ('alpha', 'beta', 'rc', 'final') assert version[3] in ("alpha", "beta", "rc", "final")
return version return version
def get_docs_version(version=None): def get_docs_version(version=None):
version = get_complete_version(version) version = get_complete_version(version)
if version[3] != 'final': if version[3] != "final":
return 'dev' return "dev"
else: else:
return '%d.%d' % version[:2] return "%d.%d" % version[:2]
def get_git_changeset(): def get_git_changeset():
@ -67,12 +67,15 @@ def get_git_changeset():
repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
try: try:
git_log = subprocess.Popen( git_log = subprocess.Popen(
'git log --pretty=format:%ct --quiet -1 HEAD', "git log --pretty=format:%ct --quiet -1 HEAD",
stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE,
shell=True, cwd=repo_dir, universal_newlines=True, stderr=subprocess.PIPE,
shell=True,
cwd=repo_dir,
universal_newlines=True,
) )
timestamp = git_log.communicate()[0] timestamp = git_log.communicate()[0]
timestamp = datetime.datetime.utcfromtimestamp(int(timestamp)) timestamp = datetime.datetime.utcfromtimestamp(int(timestamp))
except: except:
return None return None
return timestamp.strftime('%Y%m%d%H%M%S') return timestamp.strftime("%Y%m%d%H%M%S")

View File

@ -3,11 +3,11 @@ from .mutation import ClientIDMutation
from .connection import Connection, ConnectionField, PageInfo from .connection import Connection, ConnectionField, PageInfo
__all__ = [ __all__ = [
'Node', "Node",
'is_node', "is_node",
'GlobalID', "GlobalID",
'ClientIDMutation', "ClientIDMutation",
'Connection', "Connection",
'ConnectionField', "ConnectionField",
'PageInfo', "PageInfo",
] ]

View File

@ -5,8 +5,7 @@ from functools import partial
from graphql_relay import connection_from_list from graphql_relay import connection_from_list
from promise import Promise, is_thenable from promise import Promise, is_thenable
from ..types import (Boolean, Enum, Int, Interface, List, NonNull, Scalar, from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union
String, Union)
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeOptions from ..types.objecttype import ObjectType, ObjectTypeOptions
from .node import is_node from .node import is_node
@ -15,24 +14,24 @@ from .node import is_node
class PageInfo(ObjectType): class PageInfo(ObjectType):
has_next_page = Boolean( has_next_page = Boolean(
required=True, required=True,
name='hasNextPage', name="hasNextPage",
description='When paginating forwards, are there more items?', description="When paginating forwards, are there more items?",
) )
has_previous_page = Boolean( has_previous_page = Boolean(
required=True, required=True,
name='hasPreviousPage', name="hasPreviousPage",
description='When paginating backwards, are there more items?', description="When paginating backwards, are there more items?",
) )
start_cursor = String( start_cursor = String(
name='startCursor', name="startCursor",
description='When paginating backwards, the cursor to continue.', description="When paginating backwards, the cursor to continue.",
) )
end_cursor = String( end_cursor = String(
name='endCursor', name="endCursor",
description='When paginating forwards, the cursor to continue.', description="When paginating forwards, the cursor to continue.",
) )
@ -41,59 +40,59 @@ class ConnectionOptions(ObjectTypeOptions):
class Connection(ObjectType): class Connection(ObjectType):
class Meta: class Meta:
abstract = True abstract = True
@classmethod @classmethod
def __init_subclass_with_meta__(cls, node=None, name=None, **options): def __init_subclass_with_meta__(cls, node=None, name=None, **options):
_meta = ConnectionOptions(cls) _meta = ConnectionOptions(cls)
assert node, 'You have to provide a node in {}.Meta'.format(cls.__name__) assert node, "You have to provide a node in {}.Meta".format(cls.__name__)
assert issubclass(node, (Scalar, Enum, ObjectType, Interface, Union, NonNull)), ( assert issubclass(
'Received incompatible node "{}" for Connection {}.' node, (Scalar, Enum, ObjectType, Interface, Union, NonNull)
).format(node, cls.__name__) ), ('Received incompatible node "{}" for Connection {}.').format(
node, cls.__name__
)
base_name = re.sub('Connection$', '', name or cls.__name__) or node._meta.name base_name = re.sub("Connection$", "", name or cls.__name__) or node._meta.name
if not name: if not name:
name = '{}Connection'.format(base_name) name = "{}Connection".format(base_name)
edge_class = getattr(cls, 'Edge', None) edge_class = getattr(cls, "Edge", None)
_node = node _node = node
class EdgeBase(object): class EdgeBase(object):
node = Field(_node, description='The item at the end of the edge') node = Field(_node, description="The item at the end of the edge")
cursor = String(required=True, description='A cursor for use in pagination') cursor = String(required=True, description="A cursor for use in pagination")
edge_name = '{}Edge'.format(base_name) edge_name = "{}Edge".format(base_name)
if edge_class: if edge_class:
edge_bases = (edge_class, EdgeBase, ObjectType,) edge_bases = (edge_class, EdgeBase, ObjectType)
else: else:
edge_bases = (EdgeBase, ObjectType,) edge_bases = (EdgeBase, ObjectType)
edge = type(edge_name, edge_bases, {}) edge = type(edge_name, edge_bases, {})
cls.Edge = edge cls.Edge = edge
options['name'] = name options["name"] = name
_meta.node = node _meta.node = node
_meta.fields = OrderedDict([ _meta.fields = OrderedDict(
('page_info', Field(PageInfo, name='pageInfo', required=True)), [
('edges', Field(NonNull(List(edge)))), ("page_info", Field(PageInfo, name="pageInfo", required=True)),
]) ("edges", Field(NonNull(List(edge)))),
return super(Connection, cls).__init_subclass_with_meta__(_meta=_meta, **options) ]
)
return super(Connection, cls).__init_subclass_with_meta__(
_meta=_meta, **options
)
class IterableConnectionField(Field): class IterableConnectionField(Field):
def __init__(self, type, *args, **kwargs): def __init__(self, type, *args, **kwargs):
kwargs.setdefault('before', String()) kwargs.setdefault("before", String())
kwargs.setdefault('after', String()) kwargs.setdefault("after", String())
kwargs.setdefault('first', Int()) kwargs.setdefault("first", Int())
kwargs.setdefault('last', Int()) kwargs.setdefault("last", Int())
super(IterableConnectionField, self).__init__( super(IterableConnectionField, self).__init__(type, *args, **kwargs)
type,
*args,
**kwargs
)
@property @property
def type(self): def type(self):
@ -119,7 +118,7 @@ class IterableConnectionField(Field):
return resolved return resolved
assert isinstance(resolved, Iterable), ( assert isinstance(resolved, Iterable), (
'Resolved value from the connection field have to be iterable or instance of {}. ' "Resolved value from the connection field have to be iterable or instance of {}. "
'Received "{}"' 'Received "{}"'
).format(connection_type, resolved) ).format(connection_type, resolved)
connection = connection_from_list( connection = connection_from_list(
@ -127,7 +126,7 @@ class IterableConnectionField(Field):
args, args,
connection_type=connection_type, connection_type=connection_type,
edge_type=connection_type.Edge, edge_type=connection_type.Edge,
pageinfo_type=PageInfo pageinfo_type=PageInfo,
) )
connection.iterable = resolved connection.iterable = resolved
return connection return connection

View File

@ -8,15 +8,15 @@ from ..types.mutation import Mutation
class ClientIDMutation(Mutation): class ClientIDMutation(Mutation):
class Meta: class Meta:
abstract = True abstract = True
@classmethod @classmethod
def __init_subclass_with_meta__(cls, output=None, input_fields=None, def __init_subclass_with_meta__(
arguments=None, name=None, **options): cls, output=None, input_fields=None, arguments=None, name=None, **options
input_class = getattr(cls, 'Input', None) ):
base_name = re.sub('Payload$', '', name or cls.__name__) input_class = getattr(cls, "Input", None)
base_name = re.sub("Payload$", "", name or cls.__name__)
assert not output, "Can't specify any output" assert not output, "Can't specify any output"
assert not arguments, "Can't specify any arguments" assert not arguments, "Can't specify any arguments"
@ -29,40 +29,43 @@ class ClientIDMutation(Mutation):
input_fields = {} input_fields = {}
cls.Input = type( cls.Input = type(
'{}Input'.format(base_name), "{}Input".format(base_name),
bases, bases,
OrderedDict(input_fields, client_mutation_id=String( OrderedDict(
name='clientMutationId')) input_fields, client_mutation_id=String(name="clientMutationId")
),
) )
arguments = OrderedDict( arguments = OrderedDict(
input=cls.Input(required=True) input=cls.Input(required=True)
# 'client_mutation_id': String(name='clientMutationId') # 'client_mutation_id': String(name='clientMutationId')
) )
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None) mutate_and_get_payload = getattr(cls, "mutate_and_get_payload", None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__: if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
assert mutate_and_get_payload, ( assert mutate_and_get_payload, (
"{name}.mutate_and_get_payload method is required" "{name}.mutate_and_get_payload method is required"
" in a ClientIDMutation.").format(name=name or cls.__name__) " in a ClientIDMutation."
).format(name=name or cls.__name__)
if not name: if not name:
name = '{}Payload'.format(base_name) name = "{}Payload".format(base_name)
super(ClientIDMutation, cls).__init_subclass_with_meta__( super(ClientIDMutation, cls).__init_subclass_with_meta__(
output=None, arguments=arguments, name=name, **options) output=None, arguments=arguments, name=name, **options
cls._meta.fields['client_mutation_id'] = (
Field(String, name='clientMutationId')
) )
cls._meta.fields["client_mutation_id"] = Field(String, name="clientMutationId")
@classmethod @classmethod
def mutate(cls, root, info, input): def mutate(cls, root, info, input):
def on_resolve(payload): def on_resolve(payload):
try: try:
payload.client_mutation_id = input.get('client_mutation_id') payload.client_mutation_id = input.get("client_mutation_id")
except Exception: except Exception:
raise Exception( raise Exception(
('Cannot set client_mutation_id in the payload object {}' ("Cannot set client_mutation_id in the payload object {}").format(
).format(repr(payload))) repr(payload)
)
)
return payload return payload
result = cls.mutate_and_get_payload(root, info, **input) result = cls.mutate_and_get_payload(root, info, **input)

View File

@ -10,9 +10,9 @@ from ..types.utils import get_type
def is_node(objecttype): def is_node(objecttype):
''' """
Check if the given objecttype has Node as an interface Check if the given objecttype has Node as an interface
''' """
if not isclass(objecttype): if not isclass(objecttype):
return False return False
@ -27,7 +27,6 @@ def is_node(objecttype):
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs): def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs):
super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) super(GlobalID, self).__init__(ID, required=required, *args, **kwargs)
self.node = node or Node self.node = node or Node
@ -41,15 +40,16 @@ class GlobalID(Field):
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial( return partial(
self.id_resolver, parent_resolver, self.node, parent_type_name=self.parent_type_name self.id_resolver,
parent_resolver,
self.node,
parent_type_name=self.parent_type_name,
) )
class NodeField(Field): class NodeField(Field):
def __init__(self, node, type=False, deprecation_reason=None, name=None, **kwargs):
def __init__(self, node, type=False, deprecation_reason=None, assert issubclass(node, Node), "NodeField can only operate in Nodes"
name=None, **kwargs):
assert issubclass(node, Node), 'NodeField can only operate in Nodes'
self.node_type = node self.node_type = node
self.field_type = type self.field_type = type
@ -57,8 +57,8 @@ class NodeField(Field):
# If we don's specify a type, the field type will be the node # If we don's specify a type, the field type will be the node
# interface # interface
type or node, type or node,
description='The ID of the object', description="The ID of the object",
id=ID(required=True) id=ID(required=True),
) )
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
@ -66,7 +66,6 @@ class NodeField(Field):
class AbstractNode(Interface): class AbstractNode(Interface):
class Meta: class Meta:
abstract = True abstract = True
@ -74,14 +73,13 @@ class AbstractNode(Interface):
def __init_subclass_with_meta__(cls, **options): def __init_subclass_with_meta__(cls, **options):
_meta = InterfaceOptions(cls) _meta = InterfaceOptions(cls)
_meta.fields = OrderedDict( _meta.fields = OrderedDict(
id=GlobalID(cls, description='The ID of the object.') id=GlobalID(cls, description="The ID of the object.")
) )
super(AbstractNode, cls).__init_subclass_with_meta__( super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options)
_meta=_meta, **options)
class Node(AbstractNode): class Node(AbstractNode):
'''An object with an ID''' """An object with an ID"""
@classmethod @classmethod
def Field(cls, *args, **kwargs): # noqa: N802 def Field(cls, *args, **kwargs): # noqa: N802
@ -100,15 +98,15 @@ class Node(AbstractNode):
return None return None
if only_type: if only_type:
assert graphene_type == only_type, ( assert graphene_type == only_type, ("Must receive a {} id.").format(
'Must receive a {} id.' only_type._meta.name
).format(only_type._meta.name) )
# We make sure the ObjectType implements the "Node" interface # We make sure the ObjectType implements the "Node" interface
if cls not in graphene_type._meta.interfaces: if cls not in graphene_type._meta.interfaces:
return None return None
get_node = getattr(graphene_type, 'get_node', None) get_node = getattr(graphene_type, "get_node", None)
if get_node: if get_node:
return get_node(info, _id) return get_node(info, _id)

View File

@ -1,15 +1,14 @@
import pytest import pytest
from ...types import (Argument, Field, Int, List, NonNull, ObjectType, Schema, from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String
String)
from ..connection import Connection, ConnectionField, PageInfo from ..connection import Connection, ConnectionField, PageInfo
from ..node import Node from ..node import Node
class MyObject(ObjectType): class MyObject(ObjectType):
class Meta: class Meta:
interfaces = [Node] interfaces = [Node]
field = String() field = String()
@ -23,11 +22,11 @@ def test_connection():
class Edge: class Edge:
other = String() other = String()
assert MyObjectConnection._meta.name == 'MyObjectConnection' assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['page_info', 'edges', 'extra'] assert list(fields.keys()) == ["page_info", "edges", "extra"]
edge_field = fields['edges'] edge_field = fields["edges"]
pageinfo_field = fields['page_info'] pageinfo_field = fields["page_info"]
assert isinstance(edge_field, Field) assert isinstance(edge_field, Field)
assert isinstance(edge_field.type, NonNull) assert isinstance(edge_field.type, NonNull)
@ -44,13 +43,12 @@ def test_connection_inherit_abstracttype():
extra = String() extra = String()
class MyObjectConnection(BaseConnection, Connection): class MyObjectConnection(BaseConnection, Connection):
class Meta: class Meta:
node = MyObject node = MyObject
assert MyObjectConnection._meta.name == 'MyObjectConnection' assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['page_info', 'edges', 'extra'] assert list(fields.keys()) == ["page_info", "edges", "extra"]
def test_connection_name(): def test_connection_name():
@ -60,7 +58,6 @@ def test_connection_name():
extra = String() extra = String()
class MyObjectConnection(BaseConnection, Connection): class MyObjectConnection(BaseConnection, Connection):
class Meta: class Meta:
node = MyObject node = MyObject
name = custom_name name = custom_name
@ -70,7 +67,6 @@ def test_connection_name():
def test_edge(): def test_edge():
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
class Meta: class Meta:
node = MyObject node = MyObject
@ -78,15 +74,15 @@ def test_edge():
other = String() other = String()
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor', 'other'] assert list(edge_fields.keys()) == ["node", "cursor", "other"]
assert isinstance(edge_fields['node'], Field) assert isinstance(edge_fields["node"], Field)
assert edge_fields['node'].type == MyObject assert edge_fields["node"].type == MyObject
assert isinstance(edge_fields['other'], Field) assert isinstance(edge_fields["other"], Field)
assert edge_fields['other'].type == String assert edge_fields["other"].type == String
def test_edge_with_bases(): def test_edge_with_bases():
@ -94,7 +90,6 @@ def test_edge_with_bases():
extra = String() extra = String()
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
class Meta: class Meta:
node = MyObject node = MyObject
@ -102,35 +97,39 @@ def test_edge_with_bases():
other = String() other = String()
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor', 'extra', 'other'] assert list(edge_fields.keys()) == ["node", "cursor", "extra", "other"]
assert isinstance(edge_fields['node'], Field) assert isinstance(edge_fields["node"], Field)
assert edge_fields['node'].type == MyObject assert edge_fields["node"].type == MyObject
assert isinstance(edge_fields['other'], Field) assert isinstance(edge_fields["other"], Field)
assert edge_fields['other'].type == String assert edge_fields["other"].type == String
def test_pageinfo(): def test_pageinfo():
assert PageInfo._meta.name == 'PageInfo' assert PageInfo._meta.name == "PageInfo"
fields = PageInfo._meta.fields fields = PageInfo._meta.fields
assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor'] assert list(fields.keys()) == [
"has_next_page",
"has_previous_page",
"start_cursor",
"end_cursor",
]
def test_connectionfield(): def test_connectionfield():
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
class Meta: class Meta:
node = MyObject node = MyObject
field = ConnectionField(MyObjectConnection) field = ConnectionField(MyObjectConnection)
assert field.args == { assert field.args == {
'before': Argument(String), "before": Argument(String),
'after': Argument(String), "after": Argument(String),
'first': Argument(Int), "first": Argument(Int),
'last': Argument(Int), "last": Argument(Int),
} }
@ -139,28 +138,30 @@ def test_connectionfield_node_deprecated():
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
field.type field.type
assert "ConnectionField's now need a explicit ConnectionType for Nodes." in str(exc_info.value) assert "ConnectionField's now need a explicit ConnectionType for Nodes." in str(
exc_info.value
)
def test_connectionfield_custom_args(): def test_connectionfield_custom_args():
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
class Meta: class Meta:
node = MyObject node = MyObject
field = ConnectionField(MyObjectConnection, before=String(required=True), extra=String()) field = ConnectionField(
MyObjectConnection, before=String(required=True), extra=String()
)
assert field.args == { assert field.args == {
'before': Argument(NonNull(String)), "before": Argument(NonNull(String)),
'after': Argument(String), "after": Argument(String),
'first': Argument(Int), "first": Argument(Int),
'last': Argument(Int), "last": Argument(Int),
'extra': Argument(String), "extra": Argument(String),
} }
def test_connectionfield_required(): def test_connectionfield_required():
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
class Meta: class Meta:
node = MyObject node = MyObject
@ -171,8 +172,6 @@ def test_connectionfield_required():
return [] return []
schema = Schema(query=Query) schema = Schema(query=Query)
executed = schema.execute( executed = schema.execute("{ testConnection { edges { cursor } } }")
'{ testConnection { edges { cursor } } }'
)
assert not executed.errors assert not executed.errors
assert executed.data == {'testConnection': {'edges': []}} assert executed.data == {"testConnection": {"edges": []}}

View File

@ -7,11 +7,10 @@ from ...types import ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo from ..connection import Connection, ConnectionField, PageInfo
from ..node import Node from ..node import Node
letter_chars = ['A', 'B', 'C', 'D', 'E'] letter_chars = ["A", "B", "C", "D", "E"]
class Letter(ObjectType): class Letter(ObjectType):
class Meta: class Meta:
interfaces = (Node,) interfaces = (Node,)
@ -19,7 +18,6 @@ class Letter(ObjectType):
class LetterConnection(Connection): class LetterConnection(Connection):
class Meta: class Meta:
node = Letter node = Letter
@ -39,16 +37,10 @@ class Query(ObjectType):
def resolve_connection_letters(self, info, **args): def resolve_connection_letters(self, info, **args):
return LetterConnection( return LetterConnection(
page_info=PageInfo( page_info=PageInfo(has_next_page=True, has_previous_page=False),
has_next_page=True,
has_previous_page=False
),
edges=[ edges=[
LetterConnection.Edge( LetterConnection.Edge(node=Letter(id=0, letter="A"), cursor="a-cursor")
node=Letter(id=0, letter='A'), ],
cursor='a-cursor'
),
]
) )
@ -62,11 +54,8 @@ for i, letter in enumerate(letter_chars):
def edges(selected_letters): def edges(selected_letters):
return [ return [
{ {
'node': { "node": {"id": base64("Letter:%s" % l.id), "letter": l.letter},
'id': base64('Letter:%s' % l.id), "cursor": base64("arrayconnection:%s" % l.id),
'letter': l.letter
},
'cursor': base64('arrayconnection:%s' % l.id)
} }
for l in [letters[i] for i in selected_letters] for l in [letters[i] for i in selected_letters]
] ]
@ -74,14 +63,15 @@ def edges(selected_letters):
def cursor_for(ltr): def cursor_for(ltr):
letter = letters[ltr] letter = letters[ltr]
return base64('arrayconnection:%s' % letter.id) return base64("arrayconnection:%s" % letter.id)
def execute(args=''): def execute(args=""):
if args: if args:
args = '(' + args + ')' args = "(" + args + ")"
return schema.execute(''' return schema.execute(
"""
{ {
letters%s { letters%s {
edges { edges {
@ -99,112 +89,136 @@ def execute(args=''):
} }
} }
} }
''' % args) """
% args
)
def check(args, letters, has_previous_page=False, has_next_page=False): def check(args, letters, has_previous_page=False, has_next_page=False):
result = execute(args) result = execute(args)
expected_edges = edges(letters) expected_edges = edges(letters)
expected_page_info = { expected_page_info = {
'hasPreviousPage': has_previous_page, "hasPreviousPage": has_previous_page,
'hasNextPage': has_next_page, "hasNextPage": has_next_page,
'endCursor': expected_edges[-1]['cursor'] if expected_edges else None, "endCursor": expected_edges[-1]["cursor"] if expected_edges else None,
'startCursor': expected_edges[0]['cursor'] if expected_edges else None "startCursor": expected_edges[0]["cursor"] if expected_edges else None,
} }
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
'letters': { "letters": {"edges": expected_edges, "pageInfo": expected_page_info}
'edges': expected_edges,
'pageInfo': expected_page_info
}
} }
def test_returns_all_elements_without_filters(): def test_returns_all_elements_without_filters():
check('', 'ABCDE') check("", "ABCDE")
def test_respects_a_smaller_first(): def test_respects_a_smaller_first():
check('first: 2', 'AB', has_next_page=True) check("first: 2", "AB", has_next_page=True)
def test_respects_an_overly_large_first(): def test_respects_an_overly_large_first():
check('first: 10', 'ABCDE') check("first: 10", "ABCDE")
def test_respects_a_smaller_last(): def test_respects_a_smaller_last():
check('last: 2', 'DE', has_previous_page=True) check("last: 2", "DE", has_previous_page=True)
def test_respects_an_overly_large_last(): def test_respects_an_overly_large_last():
check('last: 10', 'ABCDE') check("last: 10", "ABCDE")
def test_respects_first_and_after(): def test_respects_first_and_after():
check('first: 2, after: "{}"'.format(cursor_for('B')), 'CD', has_next_page=True) check('first: 2, after: "{}"'.format(cursor_for("B")), "CD", has_next_page=True)
def test_respects_first_and_after_with_long_first(): def test_respects_first_and_after_with_long_first():
check('first: 10, after: "{}"'.format(cursor_for('B')), 'CDE') check('first: 10, after: "{}"'.format(cursor_for("B")), "CDE")
def test_respects_last_and_before(): def test_respects_last_and_before():
check('last: 2, before: "{}"'.format(cursor_for('D')), 'BC', has_previous_page=True) check('last: 2, before: "{}"'.format(cursor_for("D")), "BC", has_previous_page=True)
def test_respects_last_and_before_with_long_last(): def test_respects_last_and_before_with_long_last():
check('last: 10, before: "{}"'.format(cursor_for('D')), 'ABC') check('last: 10, before: "{}"'.format(cursor_for("D")), "ABC")
def test_respects_first_and_after_and_before_too_few(): def test_respects_first_and_after_and_before_too_few():
check('first: 2, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), 'BC', has_next_page=True) check(
'first: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BC",
has_next_page=True,
)
def test_respects_first_and_after_and_before_too_many(): def test_respects_first_and_after_and_before_too_many():
check('first: 4, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), 'BCD') check(
'first: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_first_and_after_and_before_exactly_right(): def test_respects_first_and_after_and_before_exactly_right():
check('first: 3, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), "BCD") check(
'first: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_last_and_after_and_before_too_few(): def test_respects_last_and_after_and_before_too_few():
check('last: 2, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), 'CD', has_previous_page=True) check(
'last: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"CD",
has_previous_page=True,
)
def test_respects_last_and_after_and_before_too_many(): def test_respects_last_and_after_and_before_too_many():
check('last: 4, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), 'BCD') check(
'last: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_last_and_after_and_before_exactly_right(): def test_respects_last_and_after_and_before_exactly_right():
check('last: 3, after: "{}", before: "{}"'.format(cursor_for('A'), cursor_for('E')), 'BCD') check(
'last: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_returns_no_elements_if_first_is_0(): def test_returns_no_elements_if_first_is_0():
check('first: 0', '', has_next_page=True) check("first: 0", "", has_next_page=True)
def test_returns_all_elements_if_cursors_are_invalid(): def test_returns_all_elements_if_cursors_are_invalid():
check('before: "invalid" after: "invalid"', 'ABCDE') check('before: "invalid" after: "invalid"', "ABCDE")
def test_returns_all_elements_if_cursors_are_on_the_outside(): def test_returns_all_elements_if_cursors_are_on_the_outside():
check( check(
'before: "{}" after: "{}"'.format( 'before: "{}" after: "{}"'.format(
base64( base64("arrayconnection:%s" % 6), base64("arrayconnection:%s" % -1)
'arrayconnection:%s' % 6), ),
base64( "ABCDE",
'arrayconnection:%s' % -1)), )
'ABCDE')
def test_returns_no_elements_if_cursors_cross(): def test_returns_no_elements_if_cursors_cross():
check('before: "{}" after: "{}"'.format(base64('arrayconnection:%s' % 2), base64('arrayconnection:%s' % 4)), '') check(
'before: "{}" after: "{}"'.format(
base64("arrayconnection:%s" % 2), base64("arrayconnection:%s" % 4)
),
"",
)
def test_connection_type_nodes(): def test_connection_type_nodes():
result = schema.execute(''' result = schema.execute(
"""
{ {
connectionLetters { connectionLetters {
edges { edges {
@ -220,28 +234,23 @@ def test_connection_type_nodes():
} }
} }
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
'connectionLetters': { "connectionLetters": {
'edges': [{ "edges": [
'node': { {"node": {"id": "TGV0dGVyOjA=", "letter": "A"}, "cursor": "a-cursor"}
'id': 'TGV0dGVyOjA=', ],
'letter': 'A', "pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
},
'cursor': 'a-cursor',
}],
'pageInfo': {
'hasPreviousPage': False,
'hasNextPage': True,
}
} }
} }
def test_connection_promise(): def test_connection_promise():
result = schema.execute(''' result = schema.execute(
"""
{ {
promiseLetters(first:1) { promiseLetters(first:1) {
edges { edges {
@ -256,20 +265,13 @@ def test_connection_promise():
} }
} }
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
'promiseLetters': { "promiseLetters": {
'edges': [{ "edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}],
'node': { "pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
'id': 'TGV0dGVyOjA=',
'letter': 'A',
},
}],
'pageInfo': {
'hasPreviousPage': False,
'hasNextPage': True,
}
} }
} }

View File

@ -6,20 +6,18 @@ from ..node import GlobalID, Node
class CustomNode(Node): class CustomNode(Node):
class Meta: class Meta:
name = 'Node' name = "Node"
class User(ObjectType): class User(ObjectType):
class Meta: class Meta:
interfaces = [CustomNode] interfaces = [CustomNode]
name = String() name = String()
class Info(object): class Info(object):
def __init__(self, parent_type): def __init__(self, parent_type):
self.parent_type = GrapheneObjectType( self.parent_type = GrapheneObjectType(
graphene_type=parent_type, graphene_type=parent_type,
@ -27,7 +25,7 @@ class Info(object):
description=parent_type._meta.description, description=parent_type._meta.description,
fields=None, fields=None,
is_type_of=parent_type.is_type_of, is_type_of=parent_type.is_type_of,
interfaces=None interfaces=None,
) )
@ -45,7 +43,7 @@ def test_global_id_allows_overriding_of_node_and_required():
def test_global_id_defaults_to_info_parent_type(): def test_global_id_defaults_to_info_parent_type():
my_id = '1' my_id = "1"
gid = GlobalID() gid = GlobalID()
id_resolver = gid.get_resolver(lambda *_: my_id) id_resolver = gid.get_resolver(lambda *_: my_id)
my_global_id = id_resolver(None, Info(User)) my_global_id = id_resolver(None, Info(User))
@ -53,7 +51,7 @@ def test_global_id_defaults_to_info_parent_type():
def test_global_id_allows_setting_customer_parent_type(): def test_global_id_allows_setting_customer_parent_type():
my_id = '1' my_id = "1"
gid = GlobalID(parent_type=User) gid = GlobalID(parent_type=User)
id_resolver = gid.get_resolver(lambda *_: my_id) id_resolver = gid.get_resolver(lambda *_: my_id)
my_global_id = id_resolver(None, None) my_global_id = id_resolver(None, None)

View File

@ -1,8 +1,16 @@
import pytest import pytest
from promise import Promise from promise import Promise
from ...types import (ID, Argument, Field, InputField, InputObjectType, from ...types import (
NonNull, ObjectType, Schema) ID,
Argument,
Field,
InputField,
InputObjectType,
NonNull,
ObjectType,
Schema,
)
from ...types.scalars import String from ...types.scalars import String
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
@ -19,7 +27,6 @@ class MyNode(ObjectType):
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
what = String() what = String()
@ -31,14 +38,13 @@ class SaySomething(ClientIDMutation):
class FixedSaySomething(object): class FixedSaySomething(object):
__slots__ = 'phrase', __slots__ = ("phrase",)
def __init__(self, phrase): def __init__(self, phrase):
self.phrase = phrase self.phrase = phrase
class SaySomethingFixed(ClientIDMutation): class SaySomethingFixed(ClientIDMutation):
class Input: class Input:
what = String() what = String()
@ -50,7 +56,6 @@ class SaySomethingFixed(ClientIDMutation):
class SaySomethingPromise(ClientIDMutation): class SaySomethingPromise(ClientIDMutation):
class Input: class Input:
what = String() what = String()
@ -68,7 +73,6 @@ class MyEdge(ObjectType):
class OtherMutation(ClientIDMutation): class OtherMutation(ClientIDMutation):
class Input(SharedFields): class Input(SharedFields):
additional_field = String() additional_field = String()
@ -76,11 +80,14 @@ class OtherMutation(ClientIDMutation):
my_node_edge = Field(MyEdge) my_node_edge = Field(MyEdge)
@staticmethod @staticmethod
def mutate_and_get_payload(self, info, shared='', additional_field='', client_mutation_id=None): def mutate_and_get_payload(
self, info, shared="", additional_field="", client_mutation_id=None
):
edge_type = MyEdge edge_type = MyEdge
return OtherMutation( return OtherMutation(
name=shared + additional_field, name=shared + additional_field,
my_node_edge=edge_type(cursor='1', node=MyNode(name='name'))) my_node_edge=edge_type(cursor="1", node=MyNode(name="name")),
)
class RootQuery(ObjectType): class RootQuery(ObjectType):
@ -103,64 +110,62 @@ def test_no_mutate_and_get_payload():
class MyMutation(ClientIDMutation): class MyMutation(ClientIDMutation):
pass pass
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation." == str( assert (
excinfo.value) "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation."
== str(excinfo.value)
)
def test_mutation(): def test_mutation():
fields = SaySomething._meta.fields fields = SaySomething._meta.fields
assert list(fields.keys()) == ['phrase', 'client_mutation_id'] assert list(fields.keys()) == ["phrase", "client_mutation_id"]
assert SaySomething._meta.name == "SaySomethingPayload" assert SaySomething._meta.name == "SaySomethingPayload"
assert isinstance(fields['phrase'], Field) assert isinstance(fields["phrase"], Field)
field = SaySomething.Field() field = SaySomething.Field()
assert field.type == SaySomething assert field.type == SaySomething
assert list(field.args.keys()) == ['input'] assert list(field.args.keys()) == ["input"]
assert isinstance(field.args['input'], Argument) assert isinstance(field.args["input"], Argument)
assert isinstance(field.args['input'].type, NonNull) assert isinstance(field.args["input"].type, NonNull)
assert field.args['input'].type.of_type == SaySomething.Input assert field.args["input"].type.of_type == SaySomething.Input
assert isinstance(fields['client_mutation_id'], Field) assert isinstance(fields["client_mutation_id"], Field)
assert fields['client_mutation_id'].name == 'clientMutationId' assert fields["client_mutation_id"].name == "clientMutationId"
assert fields['client_mutation_id'].type == String assert fields["client_mutation_id"].type == String
def test_mutation_input(): def test_mutation_input():
Input = SaySomething.Input Input = SaySomething.Input
assert issubclass(Input, InputObjectType) assert issubclass(Input, InputObjectType)
fields = Input._meta.fields fields = Input._meta.fields
assert list(fields.keys()) == ['what', 'client_mutation_id'] assert list(fields.keys()) == ["what", "client_mutation_id"]
assert isinstance(fields['what'], InputField) assert isinstance(fields["what"], InputField)
assert fields['what'].type == String assert fields["what"].type == String
assert isinstance(fields['client_mutation_id'], InputField) assert isinstance(fields["client_mutation_id"], InputField)
assert fields['client_mutation_id'].type == String assert fields["client_mutation_id"].type == String
def test_subclassed_mutation(): def test_subclassed_mutation():
fields = OtherMutation._meta.fields fields = OtherMutation._meta.fields
assert list(fields.keys()) == [ assert list(fields.keys()) == ["name", "my_node_edge", "client_mutation_id"]
'name', 'my_node_edge', 'client_mutation_id' assert isinstance(fields["name"], Field)
]
assert isinstance(fields['name'], Field)
field = OtherMutation.Field() field = OtherMutation.Field()
assert field.type == OtherMutation assert field.type == OtherMutation
assert list(field.args.keys()) == ['input'] assert list(field.args.keys()) == ["input"]
assert isinstance(field.args['input'], Argument) assert isinstance(field.args["input"], Argument)
assert isinstance(field.args['input'].type, NonNull) assert isinstance(field.args["input"].type, NonNull)
assert field.args['input'].type.of_type == OtherMutation.Input assert field.args["input"].type.of_type == OtherMutation.Input
def test_subclassed_mutation_input(): def test_subclassed_mutation_input():
Input = OtherMutation.Input Input = OtherMutation.Input
assert issubclass(Input, InputObjectType) assert issubclass(Input, InputObjectType)
fields = Input._meta.fields fields = Input._meta.fields
assert list(fields.keys()) == [ assert list(fields.keys()) == ["shared", "additional_field", "client_mutation_id"]
'shared', 'additional_field', 'client_mutation_id' assert isinstance(fields["shared"], InputField)
] assert fields["shared"].type == String
assert isinstance(fields['shared'], InputField) assert isinstance(fields["additional_field"], InputField)
assert fields['shared'].type == String assert fields["additional_field"].type == String
assert isinstance(fields['additional_field'], InputField) assert isinstance(fields["client_mutation_id"], InputField)
assert fields['additional_field'].type == String assert fields["client_mutation_id"].type == String
assert isinstance(fields['client_mutation_id'], InputField)
assert fields['client_mutation_id'].type == String
def test_node_query(): def test_node_query():
@ -168,14 +173,16 @@ def test_node_query():
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }' 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'say': {'phrase': 'hello'}} assert executed.data == {"say": {"phrase": "hello"}}
def test_node_query_fixed(): def test_node_query_fixed():
executed = schema.execute( executed = schema.execute(
'mutation a { sayFixed(input: {what:"hello", clientMutationId:"1"}) { phrase } }' 'mutation a { sayFixed(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
) )
assert "Cannot set client_mutation_id in the payload object" in str(executed.errors[0]) assert "Cannot set client_mutation_id in the payload object" in str(
executed.errors[0]
)
def test_node_query_promise(): def test_node_query_promise():
@ -183,7 +190,7 @@ def test_node_query_promise():
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }' 'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'sayPromise': {'phrase': 'hello'}} assert executed.data == {"sayPromise": {"phrase": "hello"}}
def test_edge_query(): def test_edge_query():
@ -192,13 +199,8 @@ def test_edge_query():
) )
assert not executed.errors assert not executed.errors
assert dict(executed.data) == { assert dict(executed.data) == {
'other': { "other": {
'clientMutationId': '1', "clientMutationId": "1",
'myNodeEdge': { "myNodeEdge": {"cursor": "1", "node": {"name": "name"}},
'cursor': '1',
'node': {
'name': 'name'
}
}
} }
} }

View File

@ -12,13 +12,13 @@ class SharedNodeFields(object):
something_else = String() something_else = String()
def resolve_something_else(*_): def resolve_something_else(*_):
return '----' return "----"
class MyNode(ObjectType): class MyNode(ObjectType):
class Meta: class Meta:
interfaces = (Node,) interfaces = (Node,)
name = String() name = String()
@staticmethod @staticmethod
@ -33,7 +33,7 @@ class MyOtherNode(SharedNodeFields, ObjectType):
interfaces = (Node,) interfaces = (Node,)
def resolve_extra_field(self, *_): def resolve_extra_field(self, *_):
return 'extra field info.' return "extra field info."
@staticmethod @staticmethod
def get_node(info, id): def get_node(info, id):
@ -51,7 +51,7 @@ schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode])
def test_node_good(): def test_node_good():
assert 'id' in MyNode._meta.fields assert "id" in MyNode._meta.fields
assert is_node(MyNode) assert is_node(MyNode)
assert not is_node(object) assert not is_node(object)
@ -61,25 +61,33 @@ def test_node_query():
'{ node(id:"%s") { ... on MyNode { name } } }' % Node.to_global_id("MyNode", 1) '{ node(id:"%s") { ... on MyNode { name } } }' % Node.to_global_id("MyNode", 1)
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'node': {'name': '1'}} assert executed.data == {"node": {"name": "1"}}
def test_subclassed_node_query(): def test_subclassed_node_query():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyOtherNode { shared, extraField, somethingElse } } }' % '{ node(id:"%s") { ... on MyOtherNode { shared, extraField, somethingElse } } }'
to_global_id("MyOtherNode", 1)) % to_global_id("MyOtherNode", 1)
)
assert not executed.errors assert not executed.errors
assert executed.data == OrderedDict({'node': OrderedDict( assert executed.data == OrderedDict(
[('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])}) {
"node": OrderedDict(
[
("shared", "1"),
("extraField", "extra field info."),
("somethingElse", "----"),
]
)
}
)
def test_node_requesting_non_node(): def test_node_requesting_non_node():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { __typename } } ' % Node.to_global_id("RootQuery", 1) '{ node(id:"%s") { __typename } } ' % Node.to_global_id("RootQuery", 1)
) )
assert executed.data == { assert executed.data == {"node": None}
'node': None
}
def test_node_query_incorrect_id(): def test_node_query_incorrect_id():
@ -87,7 +95,7 @@ def test_node_query_incorrect_id():
'{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2"
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'node': None} assert executed.data == {"node": None}
def test_node_field(): def test_node_field():
@ -107,37 +115,42 @@ def test_node_field_only_type():
'{ onlyNode(id:"%s") { __typename, name } } ' % Node.to_global_id("MyNode", 1) '{ onlyNode(id:"%s") { __typename, name } } ' % Node.to_global_id("MyNode", 1)
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'onlyNode': {'__typename': 'MyNode', 'name': '1'}} assert executed.data == {"onlyNode": {"__typename": "MyNode", "name": "1"}}
def test_node_field_only_type_wrong(): def test_node_field_only_type_wrong():
executed = schema.execute( executed = schema.execute(
'{ onlyNode(id:"%s") { __typename, name } } ' % Node.to_global_id("MyOtherNode", 1) '{ onlyNode(id:"%s") { __typename, name } } '
% Node.to_global_id("MyOtherNode", 1)
) )
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert str(executed.errors[0]) == 'Must receive a MyNode id.' assert str(executed.errors[0]) == "Must receive a MyNode id."
assert executed.data == {'onlyNode': None} assert executed.data == {"onlyNode": None}
def test_node_field_only_lazy_type(): def test_node_field_only_lazy_type():
executed = schema.execute( executed = schema.execute(
'{ onlyNodeLazy(id:"%s") { __typename, name } } ' % Node.to_global_id("MyNode", 1) '{ onlyNodeLazy(id:"%s") { __typename, name } } '
% Node.to_global_id("MyNode", 1)
) )
assert not executed.errors assert not executed.errors
assert executed.data == {'onlyNodeLazy': {'__typename': 'MyNode', 'name': '1'}} assert executed.data == {"onlyNodeLazy": {"__typename": "MyNode", "name": "1"}}
def test_node_field_only_lazy_type_wrong(): def test_node_field_only_lazy_type_wrong():
executed = schema.execute( executed = schema.execute(
'{ onlyNodeLazy(id:"%s") { __typename, name } } ' % Node.to_global_id("MyOtherNode", 1) '{ onlyNodeLazy(id:"%s") { __typename, name } } '
% Node.to_global_id("MyOtherNode", 1)
) )
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert str(executed.errors[0]) == 'Must receive a MyNode id.' assert str(executed.errors[0]) == "Must receive a MyNode id."
assert executed.data == {'onlyNodeLazy': None} assert executed.data == {"onlyNodeLazy": None}
def test_str_schema(): def test_str_schema():
assert str(schema) == """ assert (
str(schema)
== """
schema { schema {
query: RootQuery query: RootQuery
} }
@ -165,3 +178,4 @@ type RootQuery {
onlyNodeLazy(id: ID!): MyNode onlyNodeLazy(id: ID!): MyNode
} }
""".lstrip() """.lstrip()
)

View File

@ -6,9 +6,8 @@ from ..node import Node
class CustomNode(Node): class CustomNode(Node):
class Meta: class Meta:
name = 'Node' name = "Node"
@staticmethod @staticmethod
def to_global_id(type, id): def to_global_id(type, id):
@ -28,27 +27,20 @@ class BasePhoto(Interface):
class User(ObjectType): class User(ObjectType):
class Meta: class Meta:
interfaces = [CustomNode] interfaces = [CustomNode]
name = String() name = String()
class Photo(ObjectType): class Photo(ObjectType):
class Meta: class Meta:
interfaces = [CustomNode, BasePhoto] interfaces = [CustomNode, BasePhoto]
user_data = { user_data = {"1": User(id="1", name="John Doe"), "2": User(id="2", name="Jane Smith")}
'1': User(id='1', name='John Doe'),
'2': User(id='2', name='Jane Smith'),
}
photo_data = { photo_data = {"3": Photo(id="3", width=300), "4": Photo(id="4", width=400)}
'3': Photo(id='3', width=300),
'4': Photo(id='4', width=400),
}
class RootQuery(ObjectType): class RootQuery(ObjectType):
@ -59,7 +51,9 @@ schema = Schema(query=RootQuery, types=[User, Photo])
def test_str_schema_correct(): def test_str_schema_correct():
assert str(schema) == '''schema { assert (
str(schema)
== """schema {
query: RootQuery query: RootQuery
} }
@ -84,47 +78,40 @@ type User implements Node {
id: ID! id: ID!
name: String name: String
} }
''' """
)
def test_gets_the_correct_id_for_users(): def test_gets_the_correct_id_for_users():
query = ''' query = """
{ {
node(id: "1") { node(id: "1") {
id id
} }
} }
''' """
expected = { expected = {"node": {"id": "1"}}
'node': {
'id': '1',
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_gets_the_correct_id_for_photos(): def test_gets_the_correct_id_for_photos():
query = ''' query = """
{ {
node(id: "4") { node(id: "4") {
id id
} }
} }
''' """
expected = { expected = {"node": {"id": "4"}}
'node': {
'id': '4',
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_gets_the_correct_name_for_users(): def test_gets_the_correct_name_for_users():
query = ''' query = """
{ {
node(id: "1") { node(id: "1") {
id id
@ -133,20 +120,15 @@ def test_gets_the_correct_name_for_users():
} }
} }
} }
''' """
expected = { expected = {"node": {"id": "1", "name": "John Doe"}}
'node': {
'id': '1',
'name': 'John Doe'
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_gets_the_correct_width_for_photos(): def test_gets_the_correct_width_for_photos():
query = ''' query = """
{ {
node(id: "4") { node(id: "4") {
id id
@ -155,60 +137,45 @@ def test_gets_the_correct_width_for_photos():
} }
} }
} }
''' """
expected = { expected = {"node": {"id": "4", "width": 400}}
'node': {
'id': '4',
'width': 400
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_gets_the_correct_typename_for_users(): def test_gets_the_correct_typename_for_users():
query = ''' query = """
{ {
node(id: "1") { node(id: "1") {
id id
__typename __typename
} }
} }
''' """
expected = { expected = {"node": {"id": "1", "__typename": "User"}}
'node': {
'id': '1',
'__typename': 'User'
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_gets_the_correct_typename_for_photos(): def test_gets_the_correct_typename_for_photos():
query = ''' query = """
{ {
node(id: "4") { node(id: "4") {
id id
__typename __typename
} }
} }
''' """
expected = { expected = {"node": {"id": "4", "__typename": "Photo"}}
'node': {
'id': '4',
'__typename': 'Photo'
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_ignores_photo_fragments_on_user(): def test_ignores_photo_fragments_on_user():
query = ''' query = """
{ {
node(id: "1") { node(id: "1") {
id id
@ -217,35 +184,29 @@ def test_ignores_photo_fragments_on_user():
} }
} }
} }
''' """
expected = { expected = {"node": {"id": "1"}}
'node': {
'id': '1',
}
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_returns_null_for_bad_ids(): def test_returns_null_for_bad_ids():
query = ''' query = """
{ {
node(id: "5") { node(id: "5") {
id id
} }
} }
''' """
expected = { expected = {"node": None}
'node': None
}
result = graphql(schema, query) result = graphql(schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
def test_have_correct_node_interface(): def test_have_correct_node_interface():
query = ''' query = """
{ {
__type(name: "Node") { __type(name: "Node") {
name name
@ -262,23 +223,20 @@ def test_have_correct_node_interface():
} }
} }
} }
''' """
expected = { expected = {
'__type': { "__type": {
'name': 'Node', "name": "Node",
'kind': 'INTERFACE', "kind": "INTERFACE",
'fields': [ "fields": [
{ {
'name': 'id', "name": "id",
'type': { "type": {
'kind': 'NON_NULL', "kind": "NON_NULL",
'ofType': { "ofType": {"name": "ID", "kind": "SCALAR"},
'name': 'ID', },
'kind': 'SCALAR'
} }
} ],
}
]
} }
} }
result = graphql(schema, query) result = graphql(schema, query)
@ -287,7 +245,7 @@ def test_have_correct_node_interface():
def test_has_correct_node_root_field(): def test_has_correct_node_root_field():
query = ''' query = """
{ {
__schema { __schema {
queryType { queryType {
@ -311,29 +269,23 @@ def test_has_correct_node_root_field():
} }
} }
} }
''' """
expected = { expected = {
'__schema': { "__schema": {
'queryType': { "queryType": {
'fields': [ "fields": [
{ {
'name': 'node', "name": "node",
'type': { "type": {"name": "Node", "kind": "INTERFACE"},
'name': 'Node', "args": [
'kind': 'INTERFACE' {
"name": "id",
"type": {
"kind": "NON_NULL",
"ofType": {"name": "ID", "kind": "SCALAR"},
}, },
'args': [
{
'name': 'id',
'type': {
'kind': 'NON_NULL',
'ofType': {
'name': 'ID',
'kind': 'SCALAR'
} }
} ],
}
]
} }
] ]
} }

View File

@ -10,7 +10,7 @@ def default_format_error(error):
if isinstance(error, GraphQLError): if isinstance(error, GraphQLError):
return format_graphql_error(error) return format_graphql_error(error)
return {'message': six.text_type(error)} return {"message": six.text_type(error)}
def format_execution_result(execution_result, format_error): def format_execution_result(execution_result, format_error):
@ -18,18 +18,15 @@ def format_execution_result(execution_result, format_error):
response = {} response = {}
if execution_result.errors: if execution_result.errors:
response['errors'] = [ response["errors"] = [format_error(e) for e in execution_result.errors]
format_error(e) for e in execution_result.errors
]
if not execution_result.invalid: if not execution_result.invalid:
response['data'] = execution_result.data response["data"] = execution_result.data
return response return response
class Client(object): class Client(object):
def __init__(self, schema, format_error=None, **execute_options): def __init__(self, schema, format_error=None, **execute_options):
assert isinstance(schema, Schema) assert isinstance(schema, Schema)
self.schema = schema self.schema = schema
@ -40,8 +37,7 @@ class Client(object):
return format_execution_result(result, self.format_error) return format_execution_result(result, self.format_error)
def execute(self, *args, **kwargs): def execute(self, *args, **kwargs):
executed = self.schema.execute(*args, executed = self.schema.execute(*args, **dict(self.execute_options, **kwargs))
**dict(self.execute_options, **kwargs))
if is_thenable(executed): if is_thenable(executed):
return Promise.resolve(executed).then(self.format_result) return Promise.resolve(executed).then(self.format_result)

View File

@ -16,20 +16,18 @@ class Error(graphene.ObjectType):
class CreatePostResult(graphene.Union): class CreatePostResult(graphene.Union):
class Meta: class Meta:
types = [Success, Error] types = [Success, Error]
class CreatePost(graphene.Mutation): class CreatePost(graphene.Mutation):
class Input: class Input:
text = graphene.String(required=True) text = graphene.String(required=True)
result = graphene.Field(CreatePostResult) result = graphene.Field(CreatePostResult)
def mutate(self, info, text): def mutate(self, info, text):
result = Success(yeah='yeah') result = Success(yeah="yeah")
return CreatePost(result=result) return CreatePost(result=result)
@ -37,11 +35,12 @@ class CreatePost(graphene.Mutation):
class Mutations(graphene.ObjectType): class Mutations(graphene.ObjectType):
create_post = CreatePost.Field() create_post = CreatePost.Field()
# tests.py # tests.py
def test_create_post(): def test_create_post():
query_string = ''' query_string = """
mutation { mutation {
createPost(text: "Try this out") { createPost(text: "Try this out") {
result { result {
@ -49,10 +48,10 @@ def test_create_post():
} }
} }
} }
''' """
schema = graphene.Schema(query=Query, mutation=Mutations) schema = graphene.Schema(query=Query, mutation=Mutations)
result = schema.execute(query_string) result = schema.execute(query_string)
assert not result.errors assert not result.errors
assert result.data['createPost']['result']['__typename'] == 'Success' assert result.data["createPost"]["result"]["__typename"] == "Success"

View File

@ -15,7 +15,6 @@ class SomeTypeTwo(graphene.ObjectType):
class MyUnion(graphene.Union): class MyUnion(graphene.Union):
class Meta: class Meta:
types = (SomeTypeOne, SomeTypeTwo) types = (SomeTypeOne, SomeTypeTwo)
@ -28,6 +27,6 @@ def test_issue():
graphene.Schema(query=Query) graphene.Schema(query=Query)
assert str(exc_info.value) == ( assert str(exc_info.value) == (
'IterableConnectionField type have to be a subclass of Connection. ' "IterableConnectionField type have to be a subclass of Connection. "
'Received "MyUnion".' 'Received "MyUnion".'
) )

View File

@ -12,35 +12,35 @@ class SpecialOptions(ObjectTypeOptions):
class SpecialObjectType(ObjectType): class SpecialObjectType(ObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, other_attr='default', **options): def __init_subclass_with_meta__(cls, other_attr="default", **options):
_meta = SpecialOptions(cls) _meta = SpecialOptions(cls)
_meta.other_attr = other_attr _meta.other_attr = other_attr
super(SpecialObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(SpecialObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, **options
)
def test_special_objecttype_could_be_subclassed(): def test_special_objecttype_could_be_subclassed():
class MyType(SpecialObjectType): class MyType(SpecialObjectType):
class Meta: class Meta:
other_attr = 'yeah!' other_attr = "yeah!"
assert MyType._meta.other_attr == 'yeah!' assert MyType._meta.other_attr == "yeah!"
def test_special_objecttype_could_be_subclassed_default(): def test_special_objecttype_could_be_subclassed_default():
class MyType(SpecialObjectType): class MyType(SpecialObjectType):
pass pass
assert MyType._meta.other_attr == 'default' assert MyType._meta.other_attr == "default"
def test_special_objecttype_inherit_meta_options(): def test_special_objecttype_inherit_meta_options():
class MyType(SpecialObjectType): class MyType(SpecialObjectType):
pass pass
assert MyType._meta.name == 'MyType' assert MyType._meta.name == "MyType"
assert MyType._meta.default_resolver is None assert MyType._meta.default_resolver is None
assert MyType._meta.interfaces == () assert MyType._meta.interfaces == ()
@ -51,35 +51,35 @@ class SpecialInputObjectTypeOptions(ObjectTypeOptions):
class SpecialInputObjectType(InputObjectType): class SpecialInputObjectType(InputObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, other_attr='default', **options): def __init_subclass_with_meta__(cls, other_attr="default", **options):
_meta = SpecialInputObjectTypeOptions(cls) _meta = SpecialInputObjectTypeOptions(cls)
_meta.other_attr = other_attr _meta.other_attr = other_attr
super(SpecialInputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(SpecialInputObjectType, cls).__init_subclass_with_meta__(
_meta=_meta, **options
)
def test_special_inputobjecttype_could_be_subclassed(): def test_special_inputobjecttype_could_be_subclassed():
class MyInputObjectType(SpecialInputObjectType): class MyInputObjectType(SpecialInputObjectType):
class Meta: class Meta:
other_attr = 'yeah!' other_attr = "yeah!"
assert MyInputObjectType._meta.other_attr == 'yeah!' assert MyInputObjectType._meta.other_attr == "yeah!"
def test_special_inputobjecttype_could_be_subclassed_default(): def test_special_inputobjecttype_could_be_subclassed_default():
class MyInputObjectType(SpecialInputObjectType): class MyInputObjectType(SpecialInputObjectType):
pass pass
assert MyInputObjectType._meta.other_attr == 'default' assert MyInputObjectType._meta.other_attr == "default"
def test_special_inputobjecttype_inherit_meta_options(): def test_special_inputobjecttype_inherit_meta_options():
class MyInputObjectType(SpecialInputObjectType): class MyInputObjectType(SpecialInputObjectType):
pass pass
assert MyInputObjectType._meta.name == 'MyInputObjectType' assert MyInputObjectType._meta.name == "MyInputObjectType"
# Enum # Enum
@ -88,9 +88,8 @@ class SpecialEnumOptions(EnumOptions):
class SpecialEnum(Enum): class SpecialEnum(Enum):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, other_attr='default', **options): def __init_subclass_with_meta__(cls, other_attr="default", **options):
_meta = SpecialEnumOptions(cls) _meta = SpecialEnumOptions(cls)
_meta.other_attr = other_attr _meta.other_attr = other_attr
super(SpecialEnum, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(SpecialEnum, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@ -98,22 +97,21 @@ class SpecialEnum(Enum):
def test_special_enum_could_be_subclassed(): def test_special_enum_could_be_subclassed():
class MyEnum(SpecialEnum): class MyEnum(SpecialEnum):
class Meta: class Meta:
other_attr = 'yeah!' other_attr = "yeah!"
assert MyEnum._meta.other_attr == 'yeah!' assert MyEnum._meta.other_attr == "yeah!"
def test_special_enum_could_be_subclassed_default(): def test_special_enum_could_be_subclassed_default():
class MyEnum(SpecialEnum): class MyEnum(SpecialEnum):
pass pass
assert MyEnum._meta.other_attr == 'default' assert MyEnum._meta.other_attr == "default"
def test_special_enum_inherit_meta_options(): def test_special_enum_inherit_meta_options():
class MyEnum(SpecialEnum): class MyEnum(SpecialEnum):
pass pass
assert MyEnum._meta.name == 'MyEnum' assert MyEnum._meta.name == "MyEnum"

View File

@ -11,14 +11,14 @@ class Query(graphene.ObjectType):
def test_issue(): def test_issue():
query_string = ''' query_string = """
query myQuery { query myQuery {
someField(from: "Oh") someField(from: "Oh")
} }
''' """
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query_string) result = schema.execute(query_string)
assert not result.errors assert not result.errors
assert result.data['someField'] == 'Oh' assert result.data["someField"] == "Oh"

View File

@ -7,19 +7,19 @@ import graphene
class MyInputClass(graphene.InputObjectType): class MyInputClass(graphene.InputObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__( def __init_subclass_with_meta__(
cls, container=None, _meta=None, fields=None, **options): cls, container=None, _meta=None, fields=None, **options
):
if _meta is None: if _meta is None:
_meta = graphene.types.inputobjecttype.InputObjectTypeOptions(cls) _meta = graphene.types.inputobjecttype.InputObjectTypeOptions(cls)
_meta.fields = fields _meta.fields = fields
super(MyInputClass, cls).__init_subclass_with_meta__( super(MyInputClass, cls).__init_subclass_with_meta__(
container=container, _meta=_meta, **options) container=container, _meta=_meta, **options
)
class MyInput(MyInputClass): class MyInput(MyInputClass):
class Meta: class Meta:
fields = dict(x=graphene.Field(graphene.Int)) fields = dict(x=graphene.Field(graphene.Int))
@ -28,15 +28,15 @@ class Query(graphene.ObjectType):
myField = graphene.Field(graphene.String, input=graphene.Argument(MyInput)) myField = graphene.Field(graphene.String, input=graphene.Argument(MyInput))
def resolve_myField(parent, info, input): def resolve_myField(parent, info, input):
return 'ok' return "ok"
def test_issue(): def test_issue():
query_string = ''' query_string = """
query myQuery { query myQuery {
myField(input: {x: 1}) myField(input: {x: 1})
} }
''' """
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
result = schema.execute(query_string) result = schema.execute(query_string)

View File

@ -24,33 +24,32 @@ from .abstracttype import AbstractType
__all__ = [ __all__ = [
'ObjectType', "ObjectType",
'InputObjectType', "InputObjectType",
'Interface', "Interface",
'Mutation', "Mutation",
'Enum', "Enum",
'Field', "Field",
'InputField', "InputField",
'Schema', "Schema",
'Scalar', "Scalar",
'String', "String",
'ID', "ID",
'Int', "Int",
'Float', "Float",
'Date', "Date",
'DateTime', "DateTime",
'Time', "Time",
'JSONString', "JSONString",
'UUID', "UUID",
'Boolean', "Boolean",
'List', "List",
'NonNull', "NonNull",
'Argument', "Argument",
'Dynamic', "Dynamic",
'Union', "Union",
'Context', "Context",
'ResolveInfo', "ResolveInfo",
# Deprecated # Deprecated
'AbstractType', "AbstractType",
] ]

View File

@ -3,7 +3,6 @@ from ..utils.subclass_with_meta import SubclassWithMeta
class AbstractType(SubclassWithMeta): class AbstractType(SubclassWithMeta):
def __init_subclass__(cls, *args, **kwargs): def __init_subclass__(cls, *args, **kwargs):
warn_deprecation( warn_deprecation(
"Abstract type is deprecated, please use normal object inheritance instead.\n" "Abstract type is deprecated, please use normal object inheritance instead.\n"

View File

@ -8,8 +8,15 @@ from .utils import get_type
class Argument(MountedType): class Argument(MountedType):
def __init__(
def __init__(self, type, default_value=None, description=None, name=None, required=False, _creation_counter=None): self,
type,
default_value=None,
description=None,
name=None,
required=False,
_creation_counter=None,
):
super(Argument, self).__init__(_creation_counter=_creation_counter) super(Argument, self).__init__(_creation_counter=_creation_counter)
if required: if required:
@ -26,10 +33,10 @@ class Argument(MountedType):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Argument) and ( return isinstance(other, Argument) and (
self.name == other.name and self.name == other.name
self.type == other.type and and self.type == other.type
self.default_value == other.default_value and and self.default_value == other.default_value
self.description == other.description and self.description == other.description
) )
@ -37,6 +44,7 @@ def to_arguments(args, extra_args=None):
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
from .field import Field from .field import Field
from .inputfield import InputField from .inputfield import InputField
if extra_args: if extra_args:
extra_args = sorted(extra_args.items(), key=lambda f: f[1]) extra_args = sorted(extra_args.items(), key=lambda f: f[1])
else: else:
@ -55,17 +63,21 @@ def to_arguments(args, extra_args=None):
arg = Argument.mounted(arg) arg = Argument.mounted(arg)
if isinstance(arg, (InputField, Field)): if isinstance(arg, (InputField, Field)):
raise ValueError('Expected {} to be Argument, but received {}. Try using Argument({}).'.format( raise ValueError(
default_name, "Expected {} to be Argument, but received {}. Try using Argument({}).".format(
type(arg).__name__, default_name, type(arg).__name__, arg.type
arg.type )
)) )
if not isinstance(arg, Argument): if not isinstance(arg, Argument):
raise ValueError('Unknown argument "{}".'.format(default_name)) raise ValueError('Unknown argument "{}".'.format(default_name))
arg_name = default_name or arg.name arg_name = default_name or arg.name
assert arg_name not in arguments, 'More than one Argument have same name "{}".'.format(arg_name) assert (
arg_name not in arguments
), 'More than one Argument have same name "{}".'.format(
arg_name
)
arguments[arg_name] = arg arguments[arg_name] = arg
return arguments return arguments

View File

@ -25,10 +25,9 @@ class BaseOptions(object):
class BaseType(SubclassWithMeta): class BaseType(SubclassWithMeta):
@classmethod @classmethod
def create_type(cls, class_name, **options): def create_type(cls, class_name, **options):
return type(class_name, (cls, ), {'Meta': options}) return type(class_name, (cls,), {"Meta": options})
@classmethod @classmethod
def __init_subclass_with_meta__(cls, name=None, description=None, _meta=None): def __init_subclass_with_meta__(cls, name=None, description=None, _meta=None):

View File

@ -9,19 +9,19 @@ from .scalars import Scalar
class Date(Scalar): class Date(Scalar):
''' """
The `Date` scalar type represents a Date The `Date` scalar type represents a Date
value as specified by value as specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601). [iso8601](https://en.wikipedia.org/wiki/ISO_8601).
''' """
@staticmethod @staticmethod
def serialize(date): def serialize(date):
if isinstance(date, datetime.datetime): if isinstance(date, datetime.datetime):
date = date.date() date = date.date()
assert isinstance(date, datetime.date), ( assert isinstance(
'Received not compatible date "{}"'.format(repr(date)) date, datetime.date
) ), 'Received not compatible date "{}"'.format(repr(date))
return date.isoformat() return date.isoformat()
@classmethod @classmethod
@ -38,17 +38,17 @@ class Date(Scalar):
class DateTime(Scalar): class DateTime(Scalar):
''' """
The `DateTime` scalar type represents a DateTime The `DateTime` scalar type represents a DateTime
value as specified by value as specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601). [iso8601](https://en.wikipedia.org/wiki/ISO_8601).
''' """
@staticmethod @staticmethod
def serialize(dt): def serialize(dt):
assert isinstance(dt, (datetime.datetime, datetime.date)), ( assert isinstance(
'Received not compatible datetime "{}"'.format(repr(dt)) dt, (datetime.datetime, datetime.date)
) ), 'Received not compatible datetime "{}"'.format(repr(dt))
return dt.isoformat() return dt.isoformat()
@classmethod @classmethod
@ -65,17 +65,17 @@ class DateTime(Scalar):
class Time(Scalar): class Time(Scalar):
''' """
The `Time` scalar type represents a Time value as The `Time` scalar type represents a Time value as
specified by specified by
[iso8601](https://en.wikipedia.org/wiki/ISO_8601). [iso8601](https://en.wikipedia.org/wiki/ISO_8601).
''' """
@staticmethod @staticmethod
def serialize(time): def serialize(time):
assert isinstance(time, datetime.time), ( assert isinstance(
'Received not compatible time "{}"'.format(repr(time)) time, datetime.time
) ), 'Received not compatible time "{}"'.format(repr(time))
return time.isoformat() return time.isoformat()
@classmethod @classmethod

View File

@ -1,16 +1,21 @@
from graphql import (GraphQLEnumType, GraphQLInputObjectType, from graphql import (
GraphQLInterfaceType, GraphQLObjectType, GraphQLEnumType,
GraphQLScalarType, GraphQLUnionType) GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLScalarType,
GraphQLUnionType,
)
class GrapheneGraphQLType(object): class GrapheneGraphQLType(object):
''' """
A class for extending the base GraphQLType with the related A class for extending the base GraphQLType with the related
graphene_type graphene_type
''' """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.graphene_type = kwargs.pop('graphene_type') self.graphene_type = kwargs.pop("graphene_type")
super(GrapheneGraphQLType, self).__init__(*args, **kwargs) super(GrapheneGraphQLType, self).__init__(*args, **kwargs)

View File

@ -5,10 +5,10 @@ from .mountedtype import MountedType
class Dynamic(MountedType): class Dynamic(MountedType):
''' """
A Dynamic Type let us get the type in runtime when we generate A Dynamic Type let us get the type in runtime when we generate
the schema. So we can have lazy fields. the schema. So we can have lazy fields.
''' """
def __init__(self, type, with_schema=False, _creation_counter=None): def __init__(self, type, with_schema=False, _creation_counter=None):
super(Dynamic, self).__init__(_creation_counter=_creation_counter) super(Dynamic, self).__init__(_creation_counter=_creation_counter)

View File

@ -24,14 +24,15 @@ class EnumOptions(BaseOptions):
class EnumMeta(SubclassWithMeta_Meta): class EnumMeta(SubclassWithMeta_Meta):
def __new__(cls, name, bases, classdict, **options): def __new__(cls, name, bases, classdict, **options):
enum_members = OrderedDict(classdict, __eq__=eq_enum) enum_members = OrderedDict(classdict, __eq__=eq_enum)
# We remove the Meta attribute from the class to not collide # We remove the Meta attribute from the class to not collide
# with the enum values. # with the enum values.
enum_members.pop('Meta', None) enum_members.pop("Meta", None)
enum = PyEnum(cls.__name__, enum_members) enum = PyEnum(cls.__name__, enum_members)
return SubclassWithMeta_Meta.__new__(cls, name, bases, OrderedDict(classdict, __enum__=enum), **options) return SubclassWithMeta_Meta.__new__(
cls, name, bases, OrderedDict(classdict, __enum__=enum), **options
)
def get(cls, value): def get(cls, value):
return cls._meta.enum(value) return cls._meta.enum(value)
@ -44,7 +45,7 @@ class EnumMeta(SubclassWithMeta_Meta):
def __call__(cls, *args, **kwargs): # noqa: N805 def __call__(cls, *args, **kwargs): # noqa: N805
if cls is Enum: if cls is Enum:
description = kwargs.pop('description', None) description = kwargs.pop("description", None)
return cls.from_enum(PyEnum(*args, **kwargs), description=description) return cls.from_enum(PyEnum(*args, **kwargs), description=description)
return super(EnumMeta, cls).__call__(*args, **kwargs) return super(EnumMeta, cls).__call__(*args, **kwargs)
# return cls._meta.enum(*args, **kwargs) # return cls._meta.enum(*args, **kwargs)
@ -52,22 +53,21 @@ class EnumMeta(SubclassWithMeta_Meta):
def from_enum(cls, enum, description=None, deprecation_reason=None): # noqa: N805 def from_enum(cls, enum, description=None, deprecation_reason=None): # noqa: N805
description = description or enum.__doc__ description = description or enum.__doc__
meta_dict = { meta_dict = {
'enum': enum, "enum": enum,
'description': description, "description": description,
'deprecation_reason': deprecation_reason "deprecation_reason": deprecation_reason,
} }
meta_class = type('Meta', (object,), meta_dict) meta_class = type("Meta", (object,), meta_dict)
return type(meta_class.enum.__name__, (Enum,), {'Meta': meta_class}) return type(meta_class.enum.__name__, (Enum,), {"Meta": meta_class})
class Enum(six.with_metaclass(EnumMeta, UnmountedType, BaseType)): class Enum(six.with_metaclass(EnumMeta, UnmountedType, BaseType)):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, enum=None, _meta=None, **options): def __init_subclass_with_meta__(cls, enum=None, _meta=None, **options):
if not _meta: if not _meta:
_meta = EnumOptions(cls) _meta = EnumOptions(cls)
_meta.enum = enum or cls.__enum__ _meta.enum = enum or cls.__enum__
_meta.deprecation_reason = options.pop('deprecation_reason', None) _meta.deprecation_reason = options.pop("deprecation_reason", None)
for key, value in _meta.enum.__members__.items(): for key, value in _meta.enum.__members__.items():
setattr(cls, key, value) setattr(cls, key, value)
@ -75,8 +75,8 @@ class Enum(six.with_metaclass(EnumMeta, UnmountedType, BaseType)):
@classmethod @classmethod
def get_type(cls): def get_type(cls):
''' """
This function is called when the unmounted type (Enum instance) This function is called when the unmounted type (Enum instance)
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
return cls return cls

View File

@ -19,18 +19,27 @@ def source_resolver(source, root, info, **args):
class Field(MountedType): class Field(MountedType):
def __init__(
def __init__(self, type, args=None, resolver=None, source=None, self,
deprecation_reason=None, name=None, description=None, type,
required=False, _creation_counter=None, default_value=None, args=None,
**extra_args): resolver=None,
source=None,
deprecation_reason=None,
name=None,
description=None,
required=False,
_creation_counter=None,
default_value=None,
**extra_args
):
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(_creation_counter=_creation_counter)
assert not args or isinstance(args, Mapping), ( assert not args or isinstance(args, Mapping), (
'Arguments in a field have to be a mapping, received "{}".' 'Arguments in a field have to be a mapping, received "{}".'
).format(args) ).format(args)
assert not (source and resolver), ( assert not (
'A Field cannot have a source and a resolver in at the same time.' source and resolver
) ), "A Field cannot have a source and a resolver in at the same time."
assert not callable(default_value), ( assert not callable(default_value), (
'The default value can not be a function but received "{}".' 'The default value can not be a function but received "{}".'
).format(base_type(default_value)) ).format(base_type(default_value))
@ -40,12 +49,12 @@ class Field(MountedType):
# Check if name is actually an argument of the field # Check if name is actually an argument of the field
if isinstance(name, (Argument, UnmountedType)): if isinstance(name, (Argument, UnmountedType)):
extra_args['name'] = name extra_args["name"] = name
name = None name = None
# Check if source is actually an argument of the field # Check if source is actually an argument of the field
if isinstance(source, (Argument, UnmountedType)): if isinstance(source, (Argument, UnmountedType)):
extra_args['source'] = source extra_args["source"] = source
source = None source = None
self.name = name self.name = name

View File

@ -1,7 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from graphql.language.ast import (BooleanValue, FloatValue, IntValue, from graphql.language.ast import (
ListValue, ObjectValue, StringValue) BooleanValue,
FloatValue,
IntValue,
ListValue,
ObjectValue,
StringValue,
)
from graphene.types.scalars import MAX_INT, MIN_INT from graphene.types.scalars import MAX_INT, MIN_INT
@ -35,6 +41,9 @@ class GenericScalar(Scalar):
elif isinstance(ast, ListValue): elif isinstance(ast, ListValue):
return [GenericScalar.parse_literal(value) for value in ast.values] return [GenericScalar.parse_literal(value) for value in ast.values]
elif isinstance(ast, ObjectValue): elif isinstance(ast, ObjectValue):
return {field.name.value: GenericScalar.parse_literal(field.value) for field in ast.fields} return {
field.name.value: GenericScalar.parse_literal(field.value)
for field in ast.fields
}
else: else:
return None return None

View File

@ -4,10 +4,17 @@ from .utils import get_type
class InputField(MountedType): class InputField(MountedType):
def __init__(
def __init__(self, type, name=None, default_value=None, self,
deprecation_reason=None, description=None, type,
required=False, _creation_counter=None, **extra_args): name=None,
default_value=None,
deprecation_reason=None,
description=None,
required=False,
_creation_counter=None,
**extra_args
):
super(InputField, self).__init__(_creation_counter=_creation_counter) super(InputField, self).__init__(_creation_counter=_creation_counter)
self.name = name self.name = name
if required: if required:

View File

@ -30,14 +30,14 @@ class InputObjectTypeContainer(dict, BaseType):
class InputObjectType(UnmountedType, BaseType): class InputObjectType(UnmountedType, BaseType):
''' """
Input Object Type Definition Input Object Type Definition
An input object defines a structured collection of fields which may be An input object defines a structured collection of fields which may be
supplied to a field argument. supplied to a field argument.
Using `NonNull` will ensure that a value must be provided by the query Using `NonNull` will ensure that a value must be provided by the query
''' """
@classmethod @classmethod
def __init_subclass_with_meta__(cls, container=None, _meta=None, **options): def __init_subclass_with_meta__(cls, container=None, _meta=None, **options):
@ -46,9 +46,7 @@ class InputObjectType(UnmountedType, BaseType):
fields = OrderedDict() fields = OrderedDict()
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update( fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField))
yank_fields_from_attrs(base.__dict__, _as=InputField)
)
if _meta.fields: if _meta.fields:
_meta.fields.update(fields) _meta.fields.update(fields)
@ -57,13 +55,12 @@ class InputObjectType(UnmountedType, BaseType):
if container is None: if container is None:
container = type(cls.__name__, (InputObjectTypeContainer, cls), {}) container = type(cls.__name__, (InputObjectTypeContainer, cls), {})
_meta.container = container _meta.container = container
super(InputObjectType, cls).__init_subclass_with_meta__( super(InputObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options)
_meta=_meta, **options)
@classmethod @classmethod
def get_type(cls): def get_type(cls):
''' """
This function is called when the unmounted type (InputObjectType instance) This function is called when the unmounted type (InputObjectType instance)
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
return cls return cls

View File

@ -15,14 +15,15 @@ class InterfaceOptions(BaseOptions):
class Interface(BaseType): class Interface(BaseType):
''' """
Interface Type Definition Interface Type Definition
When a field can return one of a heterogeneous set of types, a Interface type When a field can return one of a heterogeneous set of types, a Interface type
is used to describe what types are possible, what fields are in common across is used to describe what types are possible, what fields are in common across
all types, as well as a function to determine which type is actually used all types, as well as a function to determine which type is actually used
when the field is resolved. when the field is resolved.
''' """
@classmethod @classmethod
def __init_subclass_with_meta__(cls, _meta=None, **options): def __init_subclass_with_meta__(cls, _meta=None, **options):
if not _meta: if not _meta:
@ -30,9 +31,7 @@ class Interface(BaseType):
fields = OrderedDict() fields = OrderedDict()
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update( fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
yank_fields_from_attrs(base.__dict__, _as=Field)
)
if _meta.fields: if _meta.fields:
_meta.fields.update(fields) _meta.fields.update(fields)
@ -44,6 +43,7 @@ class Interface(BaseType):
@classmethod @classmethod
def resolve_type(cls, instance, info): def resolve_type(cls, instance, info):
from .objecttype import ObjectType from .objecttype import ObjectType
if isinstance(instance, ObjectType): if isinstance(instance, ObjectType):
return type(instance) return type(instance)

View File

@ -8,7 +8,7 @@ from .scalars import Scalar
class JSONString(Scalar): class JSONString(Scalar):
'''JSON String''' """JSON String"""
@staticmethod @staticmethod
def serialize(dt): def serialize(dt):

View File

@ -3,15 +3,14 @@ from .unmountedtype import UnmountedType
class MountedType(OrderedType): class MountedType(OrderedType):
@classmethod @classmethod
def mounted(cls, unmounted): # noqa: N802 def mounted(cls, unmounted): # noqa: N802
''' """
Mount the UnmountedType instance Mount the UnmountedType instance
''' """
assert isinstance(unmounted, UnmountedType), ( assert isinstance(unmounted, UnmountedType), ("{} can't mount {}").format(
'{} can\'t mount {}' cls.__name__, repr(unmounted)
).format(cls.__name__, repr(unmounted)) )
return cls( return cls(
unmounted.get_type(), unmounted.get_type(),

View File

@ -21,37 +21,39 @@ class MutationOptions(ObjectTypeOptions):
class Mutation(ObjectType): class Mutation(ObjectType):
''' """
Mutation Type Definition Mutation Type Definition
''' """
@classmethod @classmethod
def __init_subclass_with_meta__(cls, resolver=None, output=None, arguments=None, def __init_subclass_with_meta__(
_meta=None, **options): cls, resolver=None, output=None, arguments=None, _meta=None, **options
):
if not _meta: if not _meta:
_meta = MutationOptions(cls) _meta = MutationOptions(cls)
output = output or getattr(cls, 'Output', None) output = output or getattr(cls, "Output", None)
fields = {} fields = {}
if not output: if not output:
# If output is defined, we don't need to get the fields # If output is defined, we don't need to get the fields
fields = OrderedDict() fields = OrderedDict()
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update( fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
yank_fields_from_attrs(base.__dict__, _as=Field)
)
output = cls output = cls
if not arguments: if not arguments:
input_class = getattr(cls, 'Arguments', None) input_class = getattr(cls, "Arguments", None)
if not input_class: if not input_class:
input_class = getattr(cls, 'Input', None) input_class = getattr(cls, "Input", None)
if input_class: if input_class:
warn_deprecation(( warn_deprecation(
(
"Please use {name}.Arguments instead of {name}.Input." "Please use {name}.Arguments instead of {name}.Input."
"Input is now only used in ClientMutationID.\n" "Input is now only used in ClientMutationID.\n"
"Read more:" "Read more:"
" https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input" " https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"
).format(name=cls.__name__)) ).format(name=cls.__name__)
)
if input_class: if input_class:
arguments = props(input_class) arguments = props(input_class)
@ -59,8 +61,8 @@ class Mutation(ObjectType):
arguments = {} arguments = {}
if not resolver: if not resolver:
mutate = getattr(cls, 'mutate', None) mutate = getattr(cls, "mutate", None)
assert mutate, 'All mutations must define a mutate method in it' assert mutate, "All mutations must define a mutate method in it"
resolver = get_unbound_function(mutate) resolver = get_unbound_function(mutate)
if _meta.fields: if _meta.fields:
@ -72,11 +74,12 @@ class Mutation(ObjectType):
_meta.resolver = resolver _meta.resolver = resolver
_meta.arguments = arguments _meta.arguments = arguments
super(Mutation, cls).__init_subclass_with_meta__( super(Mutation, cls).__init_subclass_with_meta__(_meta=_meta, **options)
_meta=_meta, **options)
@classmethod @classmethod
def Field(cls, name=None, description=None, deprecation_reason=None, required=False): def Field(
cls, name=None, description=None, deprecation_reason=None, required=False
):
return Field( return Field(
cls._meta.output, cls._meta.output,
args=cls._meta.arguments, args=cls._meta.arguments,

View File

@ -17,17 +17,22 @@ class ObjectTypeOptions(BaseOptions):
class ObjectType(BaseType): class ObjectType(BaseType):
''' """
Object Type Definition Object Type Definition
Almost all of the GraphQL types you define will be object types. Object types Almost all of the GraphQL types you define will be object types. Object types
have a name, but most importantly describe their fields. have a name, but most importantly describe their fields.
''' """
@classmethod @classmethod
def __init_subclass_with_meta__( def __init_subclass_with_meta__(
cls, interfaces=(), cls,
interfaces=(),
possible_types=(), possible_types=(),
default_resolver=None, _meta=None, **options): default_resolver=None,
_meta=None,
**options
):
if not _meta: if not _meta:
_meta = ObjectTypeOptions(cls) _meta = ObjectTypeOptions(cls)
@ -40,13 +45,11 @@ class ObjectType(BaseType):
fields.update(interface._meta.fields) fields.update(interface._meta.fields)
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update( fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
yank_fields_from_attrs(base.__dict__, _as=Field)
)
assert not (possible_types and cls.is_type_of), ( assert not (possible_types and cls.is_type_of), (
'{name}.Meta.possible_types will cause type collision with {name}.is_type_of. ' "{name}.Meta.possible_types will cause type collision with {name}.is_type_of. "
'Please use one or other.' "Please use one or other."
).format(name=cls.__name__) ).format(name=cls.__name__)
if _meta.fields: if _meta.fields:
@ -82,8 +85,7 @@ class ObjectType(BaseType):
for name, field in fields_iter: for name, field in fields_iter:
try: try:
val = kwargs.pop( val = kwargs.pop(
name, name, field.default_value if isinstance(field, Field) else None
field.default_value if isinstance(field, Field) else None
) )
setattr(self, name, val) setattr(self, name, val)
except KeyError: except KeyError:
@ -92,14 +94,15 @@ class ObjectType(BaseType):
if kwargs: if kwargs:
for prop in list(kwargs): for prop in list(kwargs):
try: try:
if isinstance(getattr(self.__class__, prop), property) or prop.startswith('_'): if isinstance(
getattr(self.__class__, prop), property
) or prop.startswith("_"):
setattr(self, prop, kwargs.pop(prop)) setattr(self, prop, kwargs.pop(prop))
except AttributeError: except AttributeError:
pass pass
if kwargs: if kwargs:
raise TypeError( raise TypeError(
"'{}' is an invalid keyword argument for {}".format( "'{}' is an invalid keyword argument for {}".format(
list(kwargs)[0], list(kwargs)[0], self.__class__.__name__
self.__class__.__name__
) )
) )

View File

@ -11,7 +11,7 @@ default_resolver = attr_resolver
def set_default_resolver(resolver): def set_default_resolver(resolver):
global default_resolver global default_resolver
assert callable(resolver), 'Received non-callable resolver.' assert callable(resolver), "Received non-callable resolver."
default_resolver = resolver default_resolver = resolver

View File

@ -1,6 +1,5 @@
import six import six
from graphql.language.ast import (BooleanValue, FloatValue, IntValue, from graphql.language.ast import BooleanValue, FloatValue, IntValue, StringValue
StringValue)
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
@ -11,13 +10,14 @@ class ScalarOptions(BaseOptions):
class Scalar(UnmountedType, BaseType): class Scalar(UnmountedType, BaseType):
''' """
Scalar Type Definition Scalar Type Definition
The leaf values of any request and input values to arguments are The leaf values of any request and input values to arguments are
Scalars (or Enums) and are defined with a name and a series of functions Scalars (or Enums) and are defined with a name and a series of functions
used to parse input from ast or variables and to ensure validity. used to parse input from ast or variables and to ensure validity.
''' """
@classmethod @classmethod
def __init_subclass_with_meta__(cls, **options): def __init_subclass_with_meta__(cls, **options):
_meta = ScalarOptions(cls) _meta = ScalarOptions(cls)
@ -29,10 +29,10 @@ class Scalar(UnmountedType, BaseType):
@classmethod @classmethod
def get_type(cls): def get_type(cls):
''' """
This function is called when the unmounted type (Scalar instance) This function is called when the unmounted type (Scalar instance)
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
return cls return cls
@ -46,12 +46,12 @@ MIN_INT = -2147483648
class Int(Scalar): class Int(Scalar):
''' """
The `Int` scalar type represents non-fractional signed whole numeric The `Int` scalar type represents non-fractional signed whole numeric
values. Int can represent values between -(2^53 - 1) and 2^53 - 1 since values. Int can represent values between -(2^53 - 1) and 2^53 - 1 since
represented in JSON as double-precision floating point numbers specified represented in JSON as double-precision floating point numbers specified
by [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point). by [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).
''' """
@staticmethod @staticmethod
def coerce_int(value): def coerce_int(value):
@ -77,11 +77,11 @@ class Int(Scalar):
class Float(Scalar): class Float(Scalar):
''' """
The `Float` scalar type represents signed double-precision fractional The `Float` scalar type represents signed double-precision fractional
values as specified by values as specified by
[IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point). [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).
''' """
@staticmethod @staticmethod
def coerce_float(value): def coerce_float(value):
@ -101,16 +101,16 @@ class Float(Scalar):
class String(Scalar): class String(Scalar):
''' """
The `String` scalar type represents textual data, represented as UTF-8 The `String` scalar type represents textual data, represented as UTF-8
character sequences. The String type is most often used by GraphQL to character sequences. The String type is most often used by GraphQL to
represent free-form human-readable text. represent free-form human-readable text.
''' """
@staticmethod @staticmethod
def coerce_string(value): def coerce_string(value):
if isinstance(value, bool): if isinstance(value, bool):
return u'true' if value else u'false' return u"true" if value else u"false"
return six.text_type(value) return six.text_type(value)
serialize = coerce_string serialize = coerce_string
@ -123,9 +123,9 @@ class String(Scalar):
class Boolean(Scalar): class Boolean(Scalar):
''' """
The `Boolean` scalar type represents `true` or `false`. The `Boolean` scalar type represents `true` or `false`.
''' """
serialize = bool serialize = bool
parse_value = bool parse_value = bool
@ -137,13 +137,13 @@ class Boolean(Scalar):
class ID(Scalar): class ID(Scalar):
''' """
The `ID` scalar type represents a unique identifier, often used to The `ID` scalar type represents a unique identifier, often used to
refetch an object or as key for a cache. The ID type appears in a JSON refetch an object or as key for a cache. The ID type appears in a JSON
response as a String; however, it is not intended to be human-readable. response as a String; however, it is not intended to be human-readable.
When expected as an input type, any string (such as `"4"`) or integer When expected as an input type, any string (such as `"4"`) or integer
(such as `4`) input value will be accepted as an ID. (such as `4`) input value will be accepted as an ID.
''' """
serialize = str serialize = str
parse_value = str parse_value = str

View File

@ -1,8 +1,11 @@
import inspect import inspect
from graphql import GraphQLObjectType, GraphQLSchema, graphql, is_type from graphql import GraphQLObjectType, GraphQLSchema, graphql, is_type
from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective, from graphql.type.directives import (
GraphQLSkipDirective) GraphQLDirective,
GraphQLIncludeDirective,
GraphQLSkipDirective,
)
from graphql.type.introspection import IntrospectionSchema from graphql.type.introspection import IntrospectionSchema
from graphql.utils.introspection_query import introspection_query from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema from graphql.utils.schema_printer import print_schema
@ -15,8 +18,7 @@ from .typemap import TypeMap, is_graphene_type
def assert_valid_root_type(_type): def assert_valid_root_type(_type):
if _type is None: if _type is None:
return return
is_graphene_objecttype = inspect.isclass( is_graphene_objecttype = inspect.isclass(_type) and issubclass(_type, ObjectType)
_type) and issubclass(_type, ObjectType)
is_graphql_objecttype = isinstance(_type, GraphQLObjectType) is_graphql_objecttype = isinstance(_type, GraphQLObjectType)
assert is_graphene_objecttype or is_graphql_objecttype, ( assert is_graphene_objecttype or is_graphql_objecttype, (
"Type {} is not a valid ObjectType." "Type {} is not a valid ObjectType."
@ -24,20 +26,22 @@ def assert_valid_root_type(_type):
class Schema(GraphQLSchema): class Schema(GraphQLSchema):
''' """
Schema Definition Schema Definition
A Schema is created by supplying the root types of each type of operation, A Schema is created by supplying the root types of each type of operation,
query and mutation (optional). query and mutation (optional).
''' """
def __init__(self, def __init__(
self,
query=None, query=None,
mutation=None, mutation=None,
subscription=None, subscription=None,
directives=None, directives=None,
types=None, types=None,
auto_camelcase=True): auto_camelcase=True,
):
assert_valid_root_type(query) assert_valid_root_type(query)
assert_valid_root_type(mutation) assert_valid_root_type(mutation)
assert_valid_root_type(subscription) assert_valid_root_type(subscription)
@ -49,8 +53,9 @@ class Schema(GraphQLSchema):
if directives is None: if directives is None:
directives = [GraphQLIncludeDirective, GraphQLSkipDirective] directives = [GraphQLIncludeDirective, GraphQLSkipDirective]
assert all(isinstance(d, GraphQLDirective) for d in directives), \ assert all(
'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( isinstance(d, GraphQLDirective) for d in directives
), "Schema directives must be List[GraphQLDirective] if provided but got: {}.".format(
directives directives
) )
self._directives = directives self._directives = directives
@ -66,16 +71,15 @@ class Schema(GraphQLSchema):
return self.get_graphql_type(self._subscription) return self.get_graphql_type(self._subscription)
def __getattr__(self, type_name): def __getattr__(self, type_name):
''' """
This function let the developer select a type in a given schema This function let the developer select a type in a given schema
by accessing its attrs. by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema Example: using schema.Query for accessing the "Query" type in the Schema
''' """
_type = super(Schema, self).get_type(type_name) _type = super(Schema, self).get_type(type_name)
if _type is None: if _type is None:
raise AttributeError( raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
'Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType): if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type return _type.graphene_type
return _type return _type
@ -88,7 +92,8 @@ class Schema(GraphQLSchema):
if is_graphene_type(_type): if is_graphene_type(_type):
graphql_type = self.get_type(_type._meta.name) graphql_type = self.get_type(_type._meta.name)
assert graphql_type, "Type {} not found in this schema.".format( assert graphql_type, "Type {} not found in this schema.".format(
_type._meta.name) _type._meta.name
)
assert graphql_type.graphene_type == _type assert graphql_type.graphene_type == _type
return graphql_type return graphql_type
raise Exception("{} is not a valid GraphQL type.".format(_type)) raise Exception("{} is not a valid GraphQL type.".format(_type))
@ -113,12 +118,10 @@ class Schema(GraphQLSchema):
self._query, self._query,
self._mutation, self._mutation,
self._subscription, self._subscription,
IntrospectionSchema IntrospectionSchema,
] ]
if self.types: if self.types:
initial_types += self.types initial_types += self.types
self._type_map = TypeMap( self._type_map = TypeMap(
initial_types, initial_types, auto_camelcase=self.auto_camelcase, schema=self
auto_camelcase=self.auto_camelcase,
schema=self
) )

View File

@ -3,22 +3,21 @@ from .utils import get_type
class Structure(UnmountedType): class Structure(UnmountedType):
''' """
A structure is a GraphQL type instance that A structure is a GraphQL type instance that
wraps a main type with certain structure. wraps a main type with certain structure.
''' """
def __init__(self, of_type, *args, **kwargs): def __init__(self, of_type, *args, **kwargs):
super(Structure, self).__init__(*args, **kwargs) super(Structure, self).__init__(*args, **kwargs)
if not isinstance(of_type, Structure) and isinstance(of_type, UnmountedType): if not isinstance(of_type, Structure) and isinstance(of_type, UnmountedType):
cls_name = type(self).__name__ cls_name = type(self).__name__
of_type_name = type(of_type).__name__ of_type_name = type(of_type).__name__
raise Exception("{} could not have a mounted {}() as inner type. Try with {}({}).".format( raise Exception(
cls_name, "{} could not have a mounted {}() as inner type. Try with {}({}).".format(
of_type_name, cls_name, of_type_name, cls_name, of_type_name
cls_name, )
of_type_name, )
))
self._of_type = of_type self._of_type = of_type
@property @property
@ -26,35 +25,35 @@ class Structure(UnmountedType):
return get_type(self._of_type) return get_type(self._of_type)
def get_type(self): def get_type(self):
''' """
This function is called when the unmounted type (List or NonNull instance) This function is called when the unmounted type (List or NonNull instance)
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
return self return self
class List(Structure): class List(Structure):
''' """
List Modifier List Modifier
A list is a kind of type marker, a wrapping type which points to another A list is a kind of type marker, a wrapping type which points to another
type. Lists are often created within the context of defining the fields of type. Lists are often created within the context of defining the fields of
an object type. an object type.
''' """
def __str__(self): def __str__(self):
return '[{}]'.format(self.of_type) return "[{}]".format(self.of_type)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, List) and ( return isinstance(other, List) and (
self.of_type == other.of_type and self.of_type == other.of_type
self.args == other.args and and self.args == other.args
self.kwargs == other.kwargs and self.kwargs == other.kwargs
) )
class NonNull(Structure): class NonNull(Structure):
''' """
Non-Null Modifier Non-Null Modifier
A non-null is a kind of type marker, a wrapping type which points to another A non-null is a kind of type marker, a wrapping type which points to another
@ -64,20 +63,20 @@ class NonNull(Structure):
usually the id field of a database row will never be null. usually the id field of a database row will never be null.
Note: the enforcement of non-nullability occurs within the executor. Note: the enforcement of non-nullability occurs within the executor.
''' """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(NonNull, self).__init__(*args, **kwargs) super(NonNull, self).__init__(*args, **kwargs)
assert not isinstance(self._of_type, NonNull), ( assert not isinstance(self._of_type, NonNull), (
'Can only create NonNull of a Nullable GraphQLType but got: {}.' "Can only create NonNull of a Nullable GraphQLType but got: {}."
).format(self._of_type) ).format(self._of_type)
def __str__(self): def __str__(self):
return '{}!'.format(self.of_type) return "{}!".format(self.of_type)
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, NonNull) and ( return isinstance(other, NonNull) and (
self.of_type == other.of_type and self.of_type == other.of_type
self.args == other.args and and self.args == other.args
self.kwargs == other.kwargs and self.kwargs == other.kwargs
) )

View File

@ -10,13 +10,12 @@ class MyType(ObjectType):
class MyScalar(UnmountedType): class MyScalar(UnmountedType):
def get_type(self): def get_type(self):
return MyType return MyType
def test_abstract_objecttype_warn_deprecation(mocker): def test_abstract_objecttype_warn_deprecation(mocker):
mocker.patch.object(abstracttype, 'warn_deprecation') mocker.patch.object(abstracttype, "warn_deprecation")
class MyAbstractType(AbstractType): class MyAbstractType(AbstractType):
field1 = MyScalar() field1 = MyScalar()
@ -34,5 +33,5 @@ def test_generate_objecttype_inherit_abstracttype():
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ['field1', 'field2'] assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field] assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]

View File

@ -10,16 +10,16 @@ from ..structures import NonNull
def test_argument(): def test_argument():
arg = Argument(String, default_value='a', description='desc', name='b') arg = Argument(String, default_value="a", description="desc", name="b")
assert arg.type == String assert arg.type == String
assert arg.default_value == 'a' assert arg.default_value == "a"
assert arg.description == 'desc' assert arg.description == "desc"
assert arg.name == 'b' assert arg.name == "b"
def test_argument_comparasion(): def test_argument_comparasion():
arg1 = Argument(String, name='Hey', description='Desc', default_value='default') arg1 = Argument(String, name="Hey", description="Desc", default_value="default")
arg2 = Argument(String, name='Hey', description='Desc', default_value='default') arg2 = Argument(String, name="Hey", description="Desc", default_value="default")
assert arg1 == arg2 assert arg1 == arg2
assert arg1 != String() assert arg1 != String()
@ -31,43 +31,36 @@ def test_argument_required():
def test_to_arguments(): def test_to_arguments():
args = { args = {"arg_string": Argument(String), "unmounted_arg": String(required=True)}
'arg_string': Argument(String),
'unmounted_arg': String(required=True)
}
my_args = to_arguments(args) my_args = to_arguments(args)
assert my_args == { assert my_args == {
'arg_string': Argument(String), "arg_string": Argument(String),
'unmounted_arg': Argument(String, required=True) "unmounted_arg": Argument(String, required=True),
} }
def test_to_arguments_raises_if_field(): def test_to_arguments_raises_if_field():
args = { args = {"arg_string": Field(String)}
'arg_string': Field(String),
}
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
to_arguments(args) to_arguments(args)
assert str(exc_info.value) == ( assert str(exc_info.value) == (
'Expected arg_string to be Argument, but received Field. Try using ' "Expected arg_string to be Argument, but received Field. Try using "
'Argument(String).' "Argument(String)."
) )
def test_to_arguments_raises_if_inputfield(): def test_to_arguments_raises_if_inputfield():
args = { args = {"arg_string": InputField(String)}
'arg_string': InputField(String),
}
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
to_arguments(args) to_arguments(args)
assert str(exc_info.value) == ( assert str(exc_info.value) == (
'Expected arg_string to be Argument, but received InputField. Try ' "Expected arg_string to be Argument, but received InputField. Try "
'using Argument(String).' "using Argument(String)."
) )

View File

@ -23,7 +23,8 @@ def test_basetype():
def test_basetype_nones(): def test_basetype_nones():
class MyBaseType(CustomType): class MyBaseType(CustomType):
'''Documentation''' """Documentation"""
class Meta: class Meta:
name = None name = None
description = None description = None
@ -35,10 +36,11 @@ def test_basetype_nones():
def test_basetype_custom(): def test_basetype_custom():
class MyBaseType(CustomType): class MyBaseType(CustomType):
'''Documentation''' """Documentation"""
class Meta: class Meta:
name = 'Base' name = "Base"
description = 'Desc' description = "Desc"
assert isinstance(MyBaseType._meta, CustomOptions) assert isinstance(MyBaseType._meta, CustomOptions)
assert MyBaseType._meta.name == "Base" assert MyBaseType._meta.name == "Base"
@ -46,7 +48,7 @@ def test_basetype_custom():
def test_basetype_create(): def test_basetype_create():
MyBaseType = CustomType.create_type('MyBaseType') MyBaseType = CustomType.create_type("MyBaseType")
assert isinstance(MyBaseType._meta, CustomOptions) assert isinstance(MyBaseType._meta, CustomOptions)
assert MyBaseType._meta.name == "MyBaseType" assert MyBaseType._meta.name == "MyBaseType"
@ -54,7 +56,7 @@ def test_basetype_create():
def test_basetype_create_extra(): def test_basetype_create_extra():
MyBaseType = CustomType.create_type('MyBaseType', name='Base', description='Desc') MyBaseType = CustomType.create_type("MyBaseType", name="Base", description="Desc")
assert isinstance(MyBaseType._meta, CustomOptions) assert isinstance(MyBaseType._meta, CustomOptions)
assert MyBaseType._meta.name == "Base" assert MyBaseType._meta.name == "Base"

View File

@ -9,9 +9,9 @@ from ..schema import Schema
class Query(ObjectType): class Query(ObjectType):
datetime = DateTime(_in=DateTime(name='in')) datetime = DateTime(_in=DateTime(name="in"))
date = Date(_in=Date(name='in')) date = Date(_in=Date(name="in"))
time = Time(_at=Time(name='at')) time = Time(_at=Time(name="at"))
def resolve_datetime(self, info, _in=None): def resolve_datetime(self, info, _in=None):
return _in return _in
@ -30,35 +30,34 @@ def test_datetime_query():
now = datetime.datetime.now().replace(tzinfo=pytz.utc) now = datetime.datetime.now().replace(tzinfo=pytz.utc)
isoformat = now.isoformat() isoformat = now.isoformat()
result = schema.execute('''{ datetime(in: "%s") }''' % isoformat) result = schema.execute("""{ datetime(in: "%s") }""" % isoformat)
assert not result.errors assert not result.errors
assert result.data == {'datetime': isoformat} assert result.data == {"datetime": isoformat}
def test_date_query(): def test_date_query():
now = datetime.datetime.now().replace(tzinfo=pytz.utc).date() now = datetime.datetime.now().replace(tzinfo=pytz.utc).date()
isoformat = now.isoformat() isoformat = now.isoformat()
result = schema.execute('''{ date(in: "%s") }''' % isoformat) result = schema.execute("""{ date(in: "%s") }""" % isoformat)
assert not result.errors assert not result.errors
assert result.data == {'date': isoformat} assert result.data == {"date": isoformat}
def test_time_query(): def test_time_query():
now = datetime.datetime.now().replace(tzinfo=pytz.utc) now = datetime.datetime.now().replace(tzinfo=pytz.utc)
time = datetime.time(now.hour, now.minute, now.second, now.microsecond, time = datetime.time(now.hour, now.minute, now.second, now.microsecond, now.tzinfo)
now.tzinfo)
isoformat = time.isoformat() isoformat = time.isoformat()
result = schema.execute('''{ time(at: "%s") }''' % isoformat) result = schema.execute("""{ time(at: "%s") }""" % isoformat)
assert not result.errors assert not result.errors
assert result.data == {'time': isoformat} assert result.data == {"time": isoformat}
def test_bad_datetime_query(): def test_bad_datetime_query():
not_a_date = "Some string that's not a date" not_a_date = "Some string that's not a date"
result = schema.execute('''{ datetime(in: "%s") }''' % not_a_date) result = schema.execute("""{ datetime(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError) assert isinstance(result.errors[0], GraphQLError)
@ -68,7 +67,7 @@ def test_bad_datetime_query():
def test_bad_date_query(): def test_bad_date_query():
not_a_date = "Some string that's not a date" not_a_date = "Some string that's not a date"
result = schema.execute('''{ date(in: "%s") }''' % not_a_date) result = schema.execute("""{ date(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError) assert isinstance(result.errors[0], GraphQLError)
@ -78,7 +77,7 @@ def test_bad_date_query():
def test_bad_time_query(): def test_bad_time_query():
not_a_date = "Some string that's not a date" not_a_date = "Some string that's not a date"
result = schema.execute('''{ time(at: "%s") }''' % not_a_date) result = schema.execute("""{ time(at: "%s") }""" % not_a_date)
assert len(result.errors) == 1 assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError) assert isinstance(result.errors[0], GraphQLError)
@ -90,10 +89,11 @@ def test_datetime_query_variable():
isoformat = now.isoformat() isoformat = now.isoformat()
result = schema.execute( result = schema.execute(
'''query Test($date: DateTime){ datetime(in: $date) }''', """query Test($date: DateTime){ datetime(in: $date) }""",
variable_values={'date': isoformat}) variable_values={"date": isoformat},
)
assert not result.errors assert not result.errors
assert result.data == {'datetime': isoformat} assert result.data == {"datetime": isoformat}
def test_date_query_variable(): def test_date_query_variable():
@ -101,20 +101,21 @@ def test_date_query_variable():
isoformat = now.isoformat() isoformat = now.isoformat()
result = schema.execute( result = schema.execute(
'''query Test($date: Date){ date(in: $date) }''', """query Test($date: Date){ date(in: $date) }""",
variable_values={'date': isoformat}) variable_values={"date": isoformat},
)
assert not result.errors assert not result.errors
assert result.data == {'date': isoformat} assert result.data == {"date": isoformat}
def test_time_query_variable(): def test_time_query_variable():
now = datetime.datetime.now().replace(tzinfo=pytz.utc) now = datetime.datetime.now().replace(tzinfo=pytz.utc)
time = datetime.time(now.hour, now.minute, now.second, now.microsecond, time = datetime.time(now.hour, now.minute, now.second, now.microsecond, now.tzinfo)
now.tzinfo)
isoformat = time.isoformat() isoformat = time.isoformat()
result = schema.execute( result = schema.execute(
'''query Test($time: Time){ time(at: $time) }''', """query Test($time: Time){ time(at: $time) }""",
variable_values={'time': isoformat}) variable_values={"time": isoformat},
)
assert not result.errors assert not result.errors
assert result.data == {'time': isoformat} assert result.data == {"time": isoformat}

View File

@ -56,13 +56,12 @@ class MyInterface(Interface):
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
types = (Article,) types = (Article,)
class MyEnum(Enum): class MyEnum(Enum):
foo = 'foo' foo = "foo"
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
@ -74,24 +73,24 @@ def test_defines_a_query_only_schema():
assert blog_schema.get_query_type().graphene_type == Query assert blog_schema.get_query_type().graphene_type == Query
article_field = Query._meta.fields['article'] article_field = Query._meta.fields["article"]
assert article_field.type == Article assert article_field.type == Article
assert article_field.type._meta.name == 'Article' assert article_field.type._meta.name == "Article"
article_field_type = article_field.type article_field_type = article_field.type
assert issubclass(article_field_type, ObjectType) assert issubclass(article_field_type, ObjectType)
title_field = article_field_type._meta.fields['title'] title_field = article_field_type._meta.fields["title"]
assert title_field.type == String assert title_field.type == String
author_field = article_field_type._meta.fields['author'] author_field = article_field_type._meta.fields["author"]
author_field_type = author_field.type author_field_type = author_field.type
assert issubclass(author_field_type, ObjectType) assert issubclass(author_field_type, ObjectType)
recent_article_field = author_field_type._meta.fields['recent_article'] recent_article_field = author_field_type._meta.fields["recent_article"]
assert recent_article_field.type == Article assert recent_article_field.type == Article
feed_field = Query._meta.fields['feed'] feed_field = Query._meta.fields["feed"]
assert feed_field.type.of_type == Article assert feed_field.type.of_type == Article
@ -100,9 +99,9 @@ def test_defines_a_mutation_schema():
assert blog_schema.get_mutation_type().graphene_type == Mutation assert blog_schema.get_mutation_type().graphene_type == Mutation
write_mutation = Mutation._meta.fields['write_article'] write_mutation = Mutation._meta.fields["write_article"]
assert write_mutation.type == Article assert write_mutation.type == Article
assert write_mutation.type._meta.name == 'Article' assert write_mutation.type._meta.name == "Article"
def test_defines_a_subscription_schema(): def test_defines_a_subscription_schema():
@ -110,9 +109,9 @@ def test_defines_a_subscription_schema():
assert blog_schema.get_subscription_type().graphene_type == Subscription assert blog_schema.get_subscription_type().graphene_type == Subscription
subscription = Subscription._meta.fields['article_subscribe'] subscription = Subscription._meta.fields["article_subscribe"]
assert subscription.type == Article assert subscription.type == Article
assert subscription.type._meta.name == 'Article' assert subscription.type._meta.name == "Article"
def test_includes_nested_input_objects_in_the_map(): def test_includes_nested_input_objects_in_the_map():
@ -128,13 +127,9 @@ def test_includes_nested_input_objects_in_the_map():
class SomeSubscription(Mutation): class SomeSubscription(Mutation):
subscribe_to_something = Field(Article, input=Argument(SomeInputObject)) subscribe_to_something = Field(Article, input=Argument(SomeInputObject))
schema = Schema( schema = Schema(query=Query, mutation=SomeMutation, subscription=SomeSubscription)
query=Query,
mutation=SomeMutation,
subscription=SomeSubscription
)
assert schema.get_type_map()['NestedInputObject'].graphene_type is NestedInputObject assert schema.get_type_map()["NestedInputObject"].graphene_type is NestedInputObject
def test_includes_interfaces_thunk_subtypes_in_the_type_map(): def test_includes_interfaces_thunk_subtypes_in_the_type_map():
@ -142,19 +137,15 @@ def test_includes_interfaces_thunk_subtypes_in_the_type_map():
f = Int() f = Int()
class SomeSubtype(ObjectType): class SomeSubtype(ObjectType):
class Meta: class Meta:
interfaces = (SomeInterface,) interfaces = (SomeInterface,)
class Query(ObjectType): class Query(ObjectType):
iface = Field(lambda: SomeInterface) iface = Field(lambda: SomeInterface)
schema = Schema( schema = Schema(query=Query, types=[SomeSubtype])
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype
def test_includes_types_in_union(): def test_includes_types_in_union():
@ -165,19 +156,16 @@ def test_includes_types_in_union():
b = String() b = String()
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
types = (SomeType, OtherType) types = (SomeType, OtherType)
class Query(ObjectType): class Query(ObjectType):
union = Field(MyUnion) union = Field(MyUnion)
schema = Schema( schema = Schema(query=Query)
query=Query,
)
assert schema.get_type_map()['OtherType'].graphene_type is OtherType assert schema.get_type_map()["OtherType"].graphene_type is OtherType
assert schema.get_type_map()['SomeType'].graphene_type is SomeType assert schema.get_type_map()["SomeType"].graphene_type is SomeType
def test_maps_enum(): def test_maps_enum():
@ -188,19 +176,16 @@ def test_maps_enum():
b = String() b = String()
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
types = (SomeType, OtherType) types = (SomeType, OtherType)
class Query(ObjectType): class Query(ObjectType):
union = Field(MyUnion) union = Field(MyUnion)
schema = Schema( schema = Schema(query=Query)
query=Query,
)
assert schema.get_type_map()['OtherType'].graphene_type is OtherType assert schema.get_type_map()["OtherType"].graphene_type is OtherType
assert schema.get_type_map()['SomeType'].graphene_type is SomeType assert schema.get_type_map()["SomeType"].graphene_type is SomeType
def test_includes_interfaces_subtypes_in_the_type_map(): def test_includes_interfaces_subtypes_in_the_type_map():
@ -208,33 +193,29 @@ def test_includes_interfaces_subtypes_in_the_type_map():
f = Int() f = Int()
class SomeSubtype(ObjectType): class SomeSubtype(ObjectType):
class Meta: class Meta:
interfaces = (SomeInterface,) interfaces = (SomeInterface,)
class Query(ObjectType): class Query(ObjectType):
iface = Field(SomeInterface) iface = Field(SomeInterface)
schema = Schema( schema = Schema(query=Query, types=[SomeSubtype])
query=Query,
types=[SomeSubtype]
)
assert schema.get_type_map()['SomeSubtype'].graphene_type is SomeSubtype assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype
def test_stringifies_simple_types(): def test_stringifies_simple_types():
assert str(Int) == 'Int' assert str(Int) == "Int"
assert str(Article) == 'Article' assert str(Article) == "Article"
assert str(MyInterface) == 'MyInterface' assert str(MyInterface) == "MyInterface"
assert str(MyUnion) == 'MyUnion' assert str(MyUnion) == "MyUnion"
assert str(MyEnum) == 'MyEnum' assert str(MyEnum) == "MyEnum"
assert str(MyInputObjectType) == 'MyInputObjectType' assert str(MyInputObjectType) == "MyInputObjectType"
assert str(NonNull(Int)) == 'Int!' assert str(NonNull(Int)) == "Int!"
assert str(List(Int)) == '[Int]' assert str(List(Int)) == "[Int]"
assert str(NonNull(List(Int))) == '[Int]!' assert str(NonNull(List(Int))) == "[Int]!"
assert str(List(NonNull(Int))) == '[Int!]' assert str(List(NonNull(Int))) == "[Int!]"
assert str(List(List(Int))) == '[[Int]]' assert str(List(List(Int))) == "[[Int]]"
# def test_identifies_input_types(): # def test_identifies_input_types():

View File

@ -8,30 +8,31 @@ from ..structures import List, NonNull
def test_dynamic(): def test_dynamic():
dynamic = Dynamic(lambda: String) dynamic = Dynamic(lambda: String)
assert dynamic.get_type() == String assert dynamic.get_type() == String
assert str(dynamic.get_type()) == 'String' assert str(dynamic.get_type()) == "String"
def test_nonnull(): def test_nonnull():
dynamic = Dynamic(lambda: NonNull(String)) dynamic = Dynamic(lambda: NonNull(String))
assert dynamic.get_type().of_type == String assert dynamic.get_type().of_type == String
assert str(dynamic.get_type()) == 'String!' assert str(dynamic.get_type()) == "String!"
def test_list(): def test_list():
dynamic = Dynamic(lambda: List(String)) dynamic = Dynamic(lambda: List(String))
assert dynamic.get_type().of_type == String assert dynamic.get_type().of_type == String
assert str(dynamic.get_type()) == '[String]' assert str(dynamic.get_type()) == "[String]"
def test_list_non_null(): def test_list_non_null():
dynamic = Dynamic(lambda: List(NonNull(String))) dynamic = Dynamic(lambda: List(NonNull(String)))
assert dynamic.get_type().of_type.of_type == String assert dynamic.get_type().of_type.of_type == String
assert str(dynamic.get_type()) == '[String!]' assert str(dynamic.get_type()) == "[String!]"
def test_partial(): def test_partial():
def __type(_type): def __type(_type):
return _type return _type
dynamic = Dynamic(partial(__type, String)) dynamic = Dynamic(partial(__type, String))
assert dynamic.get_type() == String assert dynamic.get_type() == String
assert str(dynamic.get_type()) == 'String' assert str(dynamic.get_type()) == "String"

View File

@ -9,7 +9,8 @@ from ..schema import ObjectType, Schema
def test_enum_construction(): def test_enum_construction():
class RGB(Enum): class RGB(Enum):
'''Description''' """Description"""
RED = 1 RED = 1
GREEN = 2 GREEN = 2
BLUE = 3 BLUE = 3
@ -18,49 +19,41 @@ def test_enum_construction():
def description(self): def description(self):
return "Description {}".format(self.name) return "Description {}".format(self.name)
assert RGB._meta.name == 'RGB' assert RGB._meta.name == "RGB"
assert RGB._meta.description == 'Description' assert RGB._meta.description == "Description"
values = RGB._meta.enum.__members__.values() values = RGB._meta.enum.__members__.values()
assert sorted([v.name for v in values]) == [ assert sorted([v.name for v in values]) == ["BLUE", "GREEN", "RED"]
'BLUE',
'GREEN',
'RED'
]
assert sorted([v.description for v in values]) == [ assert sorted([v.description for v in values]) == [
'Description BLUE', "Description BLUE",
'Description GREEN', "Description GREEN",
'Description RED' "Description RED",
] ]
def test_enum_construction_meta(): def test_enum_construction_meta():
class RGB(Enum): class RGB(Enum):
class Meta: class Meta:
name = 'RGBEnum' name = "RGBEnum"
description = 'Description' description = "Description"
RED = 1 RED = 1
GREEN = 2 GREEN = 2
BLUE = 3 BLUE = 3
assert RGB._meta.name == 'RGBEnum' assert RGB._meta.name == "RGBEnum"
assert RGB._meta.description == 'Description' assert RGB._meta.description == "Description"
def test_enum_instance_construction(): def test_enum_instance_construction():
RGB = Enum('RGB', 'RED,GREEN,BLUE') RGB = Enum("RGB", "RED,GREEN,BLUE")
values = RGB._meta.enum.__members__.values() values = RGB._meta.enum.__members__.values()
assert sorted([v.name for v in values]) == [ assert sorted([v.name for v in values]) == ["BLUE", "GREEN", "RED"]
'BLUE',
'GREEN',
'RED'
]
def test_enum_from_builtin_enum(): def test_enum_from_builtin_enum():
PyRGB = PyEnum('RGB', 'RED,GREEN,BLUE') PyRGB = PyEnum("RGB", "RED,GREEN,BLUE")
RGB = Enum.from_enum(PyRGB) RGB = Enum.from_enum(PyRGB)
assert RGB._meta.enum == PyRGB assert RGB._meta.enum == PyRGB
@ -74,30 +67,51 @@ def test_enum_from_builtin_enum_accepts_lambda_description():
if not value: if not value:
return "StarWars Episodes" return "StarWars Episodes"
return 'New Hope Episode' if value == Episode.NEWHOPE else 'Other' return "New Hope Episode" if value == Episode.NEWHOPE else "Other"
def custom_deprecation_reason(value): def custom_deprecation_reason(value):
return 'meh' if value == Episode.NEWHOPE else None return "meh" if value == Episode.NEWHOPE else None
PyEpisode = PyEnum('PyEpisode', 'NEWHOPE,EMPIRE,JEDI') PyEpisode = PyEnum("PyEpisode", "NEWHOPE,EMPIRE,JEDI")
Episode = Enum.from_enum(PyEpisode, description=custom_description, Episode = Enum.from_enum(
deprecation_reason=custom_deprecation_reason) PyEpisode,
description=custom_description,
deprecation_reason=custom_deprecation_reason,
)
class Query(ObjectType): class Query(ObjectType):
foo = Episode() foo = Episode()
schema = Schema(query=Query) schema = Schema(query=Query)
GraphQLPyEpisode = schema._type_map['PyEpisode'].values GraphQLPyEpisode = schema._type_map["PyEpisode"].values
assert schema._type_map['PyEpisode'].description == "StarWars Episodes" assert schema._type_map["PyEpisode"].description == "StarWars Episodes"
assert GraphQLPyEpisode[0].name == 'NEWHOPE' and GraphQLPyEpisode[0].description == 'New Hope Episode' assert (
assert GraphQLPyEpisode[1].name == 'EMPIRE' and GraphQLPyEpisode[1].description == 'Other' GraphQLPyEpisode[0].name == "NEWHOPE"
assert GraphQLPyEpisode[2].name == 'JEDI' and GraphQLPyEpisode[2].description == 'Other' and GraphQLPyEpisode[0].description == "New Hope Episode"
)
assert (
GraphQLPyEpisode[1].name == "EMPIRE"
and GraphQLPyEpisode[1].description == "Other"
)
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].description == "Other"
)
assert GraphQLPyEpisode[0].name == 'NEWHOPE' and GraphQLPyEpisode[0].deprecation_reason == 'meh' assert (
assert GraphQLPyEpisode[1].name == 'EMPIRE' and GraphQLPyEpisode[1].deprecation_reason is None GraphQLPyEpisode[0].name == "NEWHOPE"
assert GraphQLPyEpisode[2].name == 'JEDI' and GraphQLPyEpisode[2].deprecation_reason is None and GraphQLPyEpisode[0].deprecation_reason == "meh"
)
assert (
GraphQLPyEpisode[1].name == "EMPIRE"
and GraphQLPyEpisode[1].deprecation_reason is None
)
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].deprecation_reason is None
)
def test_enum_from_python3_enum_uses_enum_doc(): def test_enum_from_python3_enum_uses_enum_doc():
@ -108,6 +122,7 @@ def test_enum_from_python3_enum_uses_enum_doc():
class Color(PyEnum): class Color(PyEnum):
"""This is the description""" """This is the description"""
RED = 1 RED = 1
GREEN = 2 GREEN = 2
BLUE = 3 BLUE = 3
@ -196,9 +211,9 @@ def test_enum_can_retrieve_members():
GREEN = 2 GREEN = 2
BLUE = 3 BLUE = 3
assert RGB['RED'] == RGB.RED assert RGB["RED"] == RGB.RED
assert RGB['GREEN'] == RGB.GREEN assert RGB["GREEN"] == RGB.GREEN
assert RGB['BLUE'] == RGB.BLUE assert RGB["BLUE"] == RGB.BLUE
def test_enum_to_enum_comparison_should_differ(): def test_enum_to_enum_comparison_should_differ():
@ -220,14 +235,14 @@ def test_enum_to_enum_comparison_should_differ():
def test_enum_skip_meta_from_members(): def test_enum_skip_meta_from_members():
class RGB1(Enum): class RGB1(Enum):
class Meta: class Meta:
name = 'RGB' name = "RGB"
RED = 1 RED = 1
GREEN = 2 GREEN = 2
BLUE = 3 BLUE = 3
assert dict(RGB1._meta.enum.__members__) == { assert dict(RGB1._meta.enum.__members__) == {
'RED': RGB1.RED, "RED": RGB1.RED,
'GREEN': RGB1.GREEN, "GREEN": RGB1.GREEN,
'BLUE': RGB1.BLUE, "BLUE": RGB1.BLUE,
} }

View File

@ -10,31 +10,33 @@ from .utils import MyLazyType
class MyInstance(object): class MyInstance(object):
value = 'value' value = "value"
value_func = staticmethod(lambda: 'value_func') value_func = staticmethod(lambda: "value_func")
def value_method(self): def value_method(self):
return 'value_method' return "value_method"
def test_field_basic(): def test_field_basic():
MyType = object() MyType = object()
args = {'my arg': Argument(True)} args = {"my arg": Argument(True)}
def resolver(): return None def resolver():
deprecation_reason = 'Deprecated now' return None
description = 'My Field'
my_default = 'something' deprecation_reason = "Deprecated now"
description = "My Field"
my_default = "something"
field = Field( field = Field(
MyType, MyType,
name='name', name="name",
args=args, args=args,
resolver=resolver, resolver=resolver,
description=description, description=description,
deprecation_reason=deprecation_reason, deprecation_reason=deprecation_reason,
default_value=my_default, default_value=my_default,
) )
assert field.name == 'name' assert field.name == "name"
assert field.args == args assert field.args == args
assert field.resolver == resolver assert field.resolver == resolver
assert field.deprecation_reason == deprecation_reason assert field.deprecation_reason == deprecation_reason
@ -55,12 +57,12 @@ def test_field_default_value_not_callable():
Field(MyType, default_value=lambda: True) Field(MyType, default_value=lambda: True)
except AssertionError as e: except AssertionError as e:
# substring comparison for py 2/3 compatibility # substring comparison for py 2/3 compatibility
assert 'The default value can not be a function but received' in str(e) assert "The default value can not be a function but received" in str(e)
def test_field_source(): def test_field_source():
MyType = object() MyType = object()
field = Field(MyType, source='value') field = Field(MyType, source="value")
assert field.resolver(MyInstance(), None) == MyInstance.value assert field.resolver(MyInstance(), None) == MyInstance.value
@ -84,46 +86,48 @@ def test_field_with_string_type():
def test_field_not_source_and_resolver(): def test_field_not_source_and_resolver():
MyType = object() MyType = object()
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
Field(MyType, source='value', resolver=lambda: None) Field(MyType, source="value", resolver=lambda: None)
assert str( assert (
exc_info.value) == 'A Field cannot have a source and a resolver in at the same time.' str(exc_info.value)
== "A Field cannot have a source and a resolver in at the same time."
)
def test_field_source_func(): def test_field_source_func():
MyType = object() MyType = object()
field = Field(MyType, source='value_func') field = Field(MyType, source="value_func")
assert field.resolver(MyInstance(), None) == MyInstance.value_func() assert field.resolver(MyInstance(), None) == MyInstance.value_func()
def test_field_source_method(): def test_field_source_method():
MyType = object() MyType = object()
field = Field(MyType, source='value_method') field = Field(MyType, source="value_method")
assert field.resolver(MyInstance(), None) == MyInstance().value_method() assert field.resolver(MyInstance(), None) == MyInstance().value_method()
def test_field_source_as_argument(): def test_field_source_as_argument():
MyType = object() MyType = object()
field = Field(MyType, source=String()) field = Field(MyType, source=String())
assert 'source' in field.args assert "source" in field.args
assert field.args['source'].type == String assert field.args["source"].type == String
def test_field_name_as_argument(): def test_field_name_as_argument():
MyType = object() MyType = object()
field = Field(MyType, name=String()) field = Field(MyType, name=String())
assert 'name' in field.args assert "name" in field.args
assert field.args['name'].type == String assert field.args["name"].type == String
def test_field_source_argument_as_kw(): def test_field_source_argument_as_kw():
MyType = object() MyType = object()
field = Field(MyType, b=NonNull(True), c=Argument(None), a=NonNull(False)) field = Field(MyType, b=NonNull(True), c=Argument(None), a=NonNull(False))
assert list(field.args.keys()) == ['b', 'c', 'a'] assert list(field.args.keys()) == ["b", "c", "a"]
assert isinstance(field.args['b'], Argument) assert isinstance(field.args["b"], Argument)
assert isinstance(field.args['b'].type, NonNull) assert isinstance(field.args["b"].type, NonNull)
assert field.args['b'].type.of_type is True assert field.args["b"].type.of_type is True
assert isinstance(field.args['c'], Argument) assert isinstance(field.args["c"], Argument)
assert field.args['c'].type is None assert field.args["c"].type is None
assert isinstance(field.args['a'], Argument) assert isinstance(field.args["a"], Argument)
assert isinstance(field.args['a'].type, NonNull) assert isinstance(field.args["a"].type, NonNull)
assert field.args['a'].type.of_type is False assert field.args["a"].type.of_type is False

View File

@ -18,44 +18,36 @@ def test_generic_query_variable():
1, 1,
1.1, 1.1,
True, True,
'str', "str",
[1, 2, 3], [1, 2, 3],
[1.1, 2.2, 3.3], [1.1, 2.2, 3.3],
[True, False], [True, False],
['str1', 'str2'], ["str1", "str2"],
{"key_a": "a", "key_b": "b"},
{ {
'key_a': 'a', "int": 1,
'key_b': 'b' "float": 1.1,
"boolean": True,
"string": "str",
"int_list": [1, 2, 3],
"float_list": [1.1, 2.2, 3.3],
"boolean_list": [True, False],
"string_list": ["str1", "str2"],
"nested_dict": {"key_a": "a", "key_b": "b"},
}, },
{ None,
'int': 1,
'float': 1.1,
'boolean': True,
'string': 'str',
'int_list': [1, 2, 3],
'float_list': [1.1, 2.2, 3.3],
'boolean_list': [True, False],
'string_list': ['str1', 'str2'],
'nested_dict': {
'key_a': 'a',
'key_b': 'b'
}
},
None
]: ]:
result = schema.execute( result = schema.execute(
'''query Test($generic: GenericScalar){ generic(input: $generic) }''', """query Test($generic: GenericScalar){ generic(input: $generic) }""",
variable_values={'generic': generic_value} variable_values={"generic": generic_value},
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"generic": generic_value}
'generic': generic_value
}
def test_generic_parse_literal_query(): def test_generic_parse_literal_query():
result = schema.execute( result = schema.execute(
''' """
query { query {
generic(input: { generic(input: {
int: 1, int: 1,
@ -73,23 +65,20 @@ def test_generic_parse_literal_query():
empty_key: undefined empty_key: undefined
}) })
} }
''' """
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
'generic': { "generic": {
'int': 1, "int": 1,
'float': 1.1, "float": 1.1,
'boolean': True, "boolean": True,
'string': 'str', "string": "str",
'int_list': [1, 2, 3], "int_list": [1, 2, 3],
'float_list': [1.1, 2.2, 3.3], "float_list": [1.1, 2.2, 3.3],
'boolean_list': [True, False], "boolean_list": [True, False],
'string_list': ['str1', 'str2'], "string_list": ["str1", "str2"],
'nested_dict': { "nested_dict": {"key_a": "a", "key_b": "b"},
'key_a': 'a', "empty_key": None,
'key_b': 'b'
},
'empty_key': None
} }
} }

View File

@ -14,14 +14,13 @@ class MyType(object):
class MyScalar(UnmountedType): class MyScalar(UnmountedType):
def get_type(self): def get_type(self):
return MyType return MyType
def test_generate_inputobjecttype(): def test_generate_inputobjecttype():
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
'''Documentation''' """Documentation"""
assert MyInputObjectType._meta.name == "MyInputObjectType" assert MyInputObjectType._meta.name == "MyInputObjectType"
assert MyInputObjectType._meta.description == "Documentation" assert MyInputObjectType._meta.description == "Documentation"
@ -30,10 +29,9 @@ def test_generate_inputobjecttype():
def test_generate_inputobjecttype_with_meta(): def test_generate_inputobjecttype_with_meta():
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
class Meta: class Meta:
name = 'MyOtherInputObjectType' name = "MyOtherInputObjectType"
description = 'Documentation' description = "Documentation"
assert MyInputObjectType._meta.name == "MyOtherInputObjectType" assert MyInputObjectType._meta.name == "MyOtherInputObjectType"
assert MyInputObjectType._meta.description == "Documentation" assert MyInputObjectType._meta.description == "Documentation"
@ -43,7 +41,7 @@ def test_generate_inputobjecttype_with_fields():
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
field = Field(MyType) field = Field(MyType)
assert 'field' in MyInputObjectType._meta.fields assert "field" in MyInputObjectType._meta.fields
def test_ordered_fields_in_inputobjecttype(): def test_ordered_fields_in_inputobjecttype():
@ -53,16 +51,15 @@ def test_ordered_fields_in_inputobjecttype():
field = MyScalar() field = MyScalar()
asa = InputField(MyType) asa = InputField(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == [ assert list(MyInputObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"]
'b', 'a', 'field', 'asa']
def test_generate_inputobjecttype_unmountedtype(): def test_generate_inputobjecttype_unmountedtype():
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
field = MyScalar(MyType) field = MyScalar(MyType)
assert 'field' in MyInputObjectType._meta.fields assert "field" in MyInputObjectType._meta.fields
assert isinstance(MyInputObjectType._meta.fields['field'], InputField) assert isinstance(MyInputObjectType._meta.fields["field"], InputField)
def test_generate_inputobjecttype_as_argument(): def test_generate_inputobjecttype_as_argument():
@ -72,13 +69,13 @@ def test_generate_inputobjecttype_as_argument():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
field = Field(MyType, input=MyInputObjectType()) field = Field(MyType, input=MyInputObjectType())
assert 'field' in MyObjectType._meta.fields assert "field" in MyObjectType._meta.fields
field = MyObjectType._meta.fields['field'] field = MyObjectType._meta.fields["field"]
assert isinstance(field, Field) assert isinstance(field, Field)
assert field.type == MyType assert field.type == MyType
assert 'input' in field.args assert "input" in field.args
assert isinstance(field.args['input'], Argument) assert isinstance(field.args["input"], Argument)
assert field.args['input'].type == MyInputObjectType assert field.args["input"].type == MyInputObjectType
def test_generate_inputobjecttype_inherit_abstracttype(): def test_generate_inputobjecttype_inherit_abstracttype():
@ -88,9 +85,11 @@ def test_generate_inputobjecttype_inherit_abstracttype():
class MyInputObjectType(InputObjectType, MyAbstractType): class MyInputObjectType(InputObjectType, MyAbstractType):
field2 = MyScalar(MyType) field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ['field1', 'field2'] assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [ assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField, InputField] InputField,
InputField,
]
def test_generate_inputobjecttype_inherit_abstracttype_reversed(): def test_generate_inputobjecttype_inherit_abstracttype_reversed():
@ -100,9 +99,11 @@ def test_generate_inputobjecttype_inherit_abstracttype_reversed():
class MyInputObjectType(MyAbstractType, InputObjectType): class MyInputObjectType(MyAbstractType, InputObjectType):
field2 = MyScalar(MyType) field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ['field1', 'field2'] assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [ assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField, InputField] InputField,
InputField,
]
def test_inputobjecttype_of_input(): def test_inputobjecttype_of_input():
@ -121,14 +122,17 @@ def test_inputobjecttype_of_input():
is_child = Boolean(parent=Parent()) is_child = Boolean(parent=Parent())
def resolve_is_child(self, info, parent): def resolve_is_child(self, info, parent):
return isinstance(parent.child, Child) and parent.child.full_name == "Peter Griffin" return (
isinstance(parent.child, Child)
and parent.child.full_name == "Peter Griffin"
)
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute('''query basequery { result = schema.execute(
"""query basequery {
isChild(parent: {child: {firstName: "Peter", lastName: "Griffin"}}) isChild(parent: {child: {firstName: "Peter", lastName: "Griffin"}})
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"isChild": True}
'isChild': True
}

View File

@ -8,14 +8,13 @@ class MyType(object):
class MyScalar(UnmountedType): class MyScalar(UnmountedType):
def get_type(self): def get_type(self):
return MyType return MyType
def test_generate_interface(): def test_generate_interface():
class MyInterface(Interface): class MyInterface(Interface):
'''Documentation''' """Documentation"""
assert MyInterface._meta.name == "MyInterface" assert MyInterface._meta.name == "MyInterface"
assert MyInterface._meta.description == "Documentation" assert MyInterface._meta.description == "Documentation"
@ -24,10 +23,9 @@ def test_generate_interface():
def test_generate_interface_with_meta(): def test_generate_interface_with_meta():
class MyInterface(Interface): class MyInterface(Interface):
class Meta: class Meta:
name = 'MyOtherInterface' name = "MyOtherInterface"
description = 'Documentation' description = "Documentation"
assert MyInterface._meta.name == "MyOtherInterface" assert MyInterface._meta.name == "MyOtherInterface"
assert MyInterface._meta.description == "Documentation" assert MyInterface._meta.description == "Documentation"
@ -37,7 +35,7 @@ def test_generate_interface_with_fields():
class MyInterface(Interface): class MyInterface(Interface):
field = Field(MyType) field = Field(MyType)
assert 'field' in MyInterface._meta.fields assert "field" in MyInterface._meta.fields
def test_ordered_fields_in_interface(): def test_ordered_fields_in_interface():
@ -47,15 +45,15 @@ def test_ordered_fields_in_interface():
field = MyScalar() field = MyScalar()
asa = Field(MyType) asa = Field(MyType)
assert list(MyInterface._meta.fields.keys()) == ['b', 'a', 'field', 'asa'] assert list(MyInterface._meta.fields.keys()) == ["b", "a", "field", "asa"]
def test_generate_interface_unmountedtype(): def test_generate_interface_unmountedtype():
class MyInterface(Interface): class MyInterface(Interface):
field = MyScalar() field = MyScalar()
assert 'field' in MyInterface._meta.fields assert "field" in MyInterface._meta.fields
assert isinstance(MyInterface._meta.fields['field'], Field) assert isinstance(MyInterface._meta.fields["field"], Field)
def test_generate_interface_inherit_abstracttype(): def test_generate_interface_inherit_abstracttype():
@ -65,7 +63,7 @@ def test_generate_interface_inherit_abstracttype():
class MyInterface(Interface, MyAbstractType): class MyInterface(Interface, MyAbstractType):
field2 = MyScalar() field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ['field1', 'field2'] assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
@ -76,8 +74,8 @@ def test_generate_interface_inherit_interface():
class MyInterface(MyBaseInterface): class MyInterface(MyBaseInterface):
field2 = MyScalar() field2 = MyScalar()
assert MyInterface._meta.name == 'MyInterface' assert MyInterface._meta.name == "MyInterface"
assert list(MyInterface._meta.fields.keys()) == ['field1', 'field2'] assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
@ -88,5 +86,5 @@ def test_generate_interface_inherit_abstracttype_reversed():
class MyInterface(MyAbstractType, Interface): class MyInterface(MyAbstractType, Interface):
field2 = MyScalar() field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ['field1', 'field2'] assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]

View File

@ -18,21 +18,17 @@ def test_jsonstring_query():
json_value = '{"key": "value"}' json_value = '{"key": "value"}'
json_value_quoted = json_value.replace('"', '\\"') json_value_quoted = json_value.replace('"', '\\"')
result = schema.execute('''{ json(input: "%s") }''' % json_value_quoted) result = schema.execute("""{ json(input: "%s") }""" % json_value_quoted)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"json": json_value}
'json': json_value
}
def test_jsonstring_query_variable(): def test_jsonstring_query_variable():
json_value = '{"key": "value"}' json_value = '{"key": "value"}'
result = schema.execute( result = schema.execute(
'''query Test($json: JSONString){ json(input: $json) }''', """query Test($json: JSONString){ json(input: $json) }""",
variable_values={'json': json_value} variable_values={"json": json_value},
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"json": json_value}
'json': json_value
}

View File

@ -4,9 +4,8 @@ from ..scalars import String
class CustomField(Field): class CustomField(Field):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.metadata = kwargs.pop('metadata', None) self.metadata = kwargs.pop("metadata", None)
super(CustomField, self).__init__(*args, **kwargs) super(CustomField, self).__init__(*args, **kwargs)
@ -18,8 +17,8 @@ def test_mounted_type():
def test_mounted_type_custom(): def test_mounted_type_custom():
unmounted = String(metadata={'hey': 'yo!'}) unmounted = String(metadata={"hey": "yo!"})
mounted = CustomField.mounted(unmounted) mounted = CustomField.mounted(unmounted)
assert isinstance(mounted, CustomField) assert isinstance(mounted, CustomField)
assert mounted.type == String assert mounted.type == String
assert mounted.metadata == {'hey': 'yo!'} assert mounted.metadata == {"hey": "yo!"}

View File

@ -11,7 +11,7 @@ from ..structures import NonNull
def test_generate_mutation_no_args(): def test_generate_mutation_no_args():
class MyMutation(Mutation): class MyMutation(Mutation):
'''Documentation''' """Documentation"""
def mutate(self, info, **args): def mutate(self, info, **args):
return args return args
@ -19,24 +19,23 @@ def test_generate_mutation_no_args():
assert issubclass(MyMutation, ObjectType) assert issubclass(MyMutation, ObjectType)
assert MyMutation._meta.name == "MyMutation" assert MyMutation._meta.name == "MyMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
resolved = MyMutation.Field().resolver(None, None, name='Peter') resolved = MyMutation.Field().resolver(None, None, name="Peter")
assert resolved == {'name': 'Peter'} assert resolved == {"name": "Peter"}
def test_generate_mutation_with_meta(): def test_generate_mutation_with_meta():
class MyMutation(Mutation): class MyMutation(Mutation):
class Meta: class Meta:
name = 'MyOtherMutation' name = "MyOtherMutation"
description = 'Documentation' description = "Documentation"
def mutate(self, info, **args): def mutate(self, info, **args):
return args return args
assert MyMutation._meta.name == "MyOtherMutation" assert MyMutation._meta.name == "MyOtherMutation"
assert MyMutation._meta.description == "Documentation" assert MyMutation._meta.description == "Documentation"
resolved = MyMutation.Field().resolver(None, None, name='Peter') resolved = MyMutation.Field().resolver(None, None, name="Peter")
assert resolved == {'name': 'Peter'} assert resolved == {"name": "Peter"}
def test_mutation_raises_exception_if_no_mutate(): def test_mutation_raises_exception_if_no_mutate():
@ -45,8 +44,7 @@ def test_mutation_raises_exception_if_no_mutate():
class MyMutation(Mutation): class MyMutation(Mutation):
pass pass
assert "All mutations must define a mutate method in it" == str( assert "All mutations must define a mutate method in it" == str(excinfo.value)
excinfo.value)
def test_mutation_custom_output_type(): def test_mutation_custom_output_type():
@ -54,7 +52,6 @@ def test_mutation_custom_output_type():
name = String() name = String()
class CreateUser(Mutation): class CreateUser(Mutation):
class Arguments: class Arguments:
name = String() name = String()
@ -65,15 +62,14 @@ def test_mutation_custom_output_type():
field = CreateUser.Field() field = CreateUser.Field()
assert field.type == User assert field.type == User
assert field.args == {'name': Argument(String)} assert field.args == {"name": Argument(String)}
resolved = field.resolver(None, None, name='Peter') resolved = field.resolver(None, None, name="Peter")
assert isinstance(resolved, User) assert isinstance(resolved, User)
assert resolved.name == 'Peter' assert resolved.name == "Peter"
def test_mutation_execution(): def test_mutation_execution():
class CreateUser(Mutation): class CreateUser(Mutation):
class Arguments: class Arguments:
name = String() name = String()
dynamic = Dynamic(lambda: String()) dynamic = Dynamic(lambda: String())
@ -92,20 +88,17 @@ def test_mutation_execution():
create_user = CreateUser.Field() create_user = CreateUser.Field()
schema = Schema(query=Query, mutation=MyMutation) schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(''' mutation mymutation { result = schema.execute(
""" mutation mymutation {
createUser(name:"Peter", dynamic: "dynamic") { createUser(name:"Peter", dynamic: "dynamic") {
name name
dynamic dynamic
} }
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"createUser": {"name": "Peter", "dynamic": "dynamic"}}
'createUser': {
'name': 'Peter',
'dynamic': 'dynamic',
}
}
def test_mutation_no_fields_output(): def test_mutation_no_fields_output():
@ -122,23 +115,20 @@ def test_mutation_no_fields_output():
create_user = CreateUser.Field() create_user = CreateUser.Field()
schema = Schema(query=Query, mutation=MyMutation) schema = Schema(query=Query, mutation=MyMutation)
result = schema.execute(''' mutation mymutation { result = schema.execute(
""" mutation mymutation {
createUser { createUser {
name name
} }
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"createUser": {"name": None}}
'createUser': {
'name': None,
}
}
def test_mutation_allow_to_have_custom_args(): def test_mutation_allow_to_have_custom_args():
class CreateUser(Mutation): class CreateUser(Mutation):
class Arguments: class Arguments:
name = String() name = String()
@ -149,14 +139,14 @@ def test_mutation_allow_to_have_custom_args():
class MyMutation(ObjectType): class MyMutation(ObjectType):
create_user = CreateUser.Field( create_user = CreateUser.Field(
description='Create a user', description="Create a user",
deprecation_reason='Is deprecated', deprecation_reason="Is deprecated",
required=True required=True,
) )
field = MyMutation._meta.fields['create_user'] field = MyMutation._meta.fields["create_user"]
assert field.description == 'Create a user' assert field.description == "Create a user"
assert field.deprecation_reason == 'Is deprecated' assert field.deprecation_reason == "Is deprecated"
assert field.type == NonNull(CreateUser) assert field.type == NonNull(CreateUser)

View File

@ -23,37 +23,37 @@ class MyInterface(Interface):
class ContainerWithInterface(ObjectType): class ContainerWithInterface(ObjectType):
class Meta: class Meta:
interfaces = (MyInterface,) interfaces = (MyInterface,)
field1 = Field(MyType) field1 = Field(MyType)
field2 = Field(MyType) field2 = Field(MyType)
class MyScalar(UnmountedType): class MyScalar(UnmountedType):
def get_type(self): def get_type(self):
return MyType return MyType
def test_generate_objecttype(): def test_generate_objecttype():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Documentation''' """Documentation"""
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert MyObjectType._meta.description == "Documentation" assert MyObjectType._meta.description == "Documentation"
assert MyObjectType._meta.interfaces == tuple() assert MyObjectType._meta.interfaces == tuple()
assert MyObjectType._meta.fields == {} assert MyObjectType._meta.fields == {}
assert repr( assert (
MyObjectType) == "<MyObjectType meta=<ObjectTypeOptions name='MyObjectType'>>" repr(MyObjectType)
== "<MyObjectType meta=<ObjectTypeOptions name='MyObjectType'>>"
)
def test_generate_objecttype_with_meta(): def test_generate_objecttype_with_meta():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
class Meta: class Meta:
name = 'MyOtherObjectType' name = "MyOtherObjectType"
description = 'Documentation' description = "Documentation"
interfaces = (MyType,) interfaces = (MyType,)
assert MyObjectType._meta.name == "MyOtherObjectType" assert MyObjectType._meta.name == "MyOtherObjectType"
@ -69,7 +69,7 @@ def test_generate_lazy_objecttype():
field = Field(MyType) field = Field(MyType)
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
example_field = MyObjectType._meta.fields['example'] example_field = MyObjectType._meta.fields["example"]
assert isinstance(example_field.type, NonNull) assert isinstance(example_field.type, NonNull)
assert example_field.type.of_type == InnerObjectType assert example_field.type.of_type == InnerObjectType
@ -78,21 +78,21 @@ def test_generate_objecttype_with_fields():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
field = Field(MyType) field = Field(MyType)
assert 'field' in MyObjectType._meta.fields assert "field" in MyObjectType._meta.fields
def test_generate_objecttype_with_private_attributes(): def test_generate_objecttype_with_private_attributes():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
_private_state = None _private_state = None
assert '_private_state' not in MyObjectType._meta.fields assert "_private_state" not in MyObjectType._meta.fields
assert hasattr(MyObjectType, '_private_state') assert hasattr(MyObjectType, "_private_state")
m = MyObjectType(_private_state='custom') m = MyObjectType(_private_state="custom")
assert m._private_state == 'custom' assert m._private_state == "custom"
with pytest.raises(TypeError): with pytest.raises(TypeError):
MyObjectType(_other_private_state='Wrong') MyObjectType(_other_private_state="Wrong")
def test_ordered_fields_in_objecttype(): def test_ordered_fields_in_objecttype():
@ -102,7 +102,7 @@ def test_ordered_fields_in_objecttype():
field = MyScalar() field = MyScalar()
asa = Field(MyType) asa = Field(MyType)
assert list(MyObjectType._meta.fields.keys()) == ['b', 'a', 'field', 'asa'] assert list(MyObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"]
def test_generate_objecttype_inherit_abstracttype(): def test_generate_objecttype_inherit_abstracttype():
@ -115,9 +115,8 @@ def test_generate_objecttype_inherit_abstracttype():
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ['field1', 'field2'] assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [ assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
Field, Field]
def test_generate_objecttype_inherit_abstracttype_reversed(): def test_generate_objecttype_inherit_abstracttype_reversed():
@ -130,26 +129,28 @@ def test_generate_objecttype_inherit_abstracttype_reversed():
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ['field1', 'field2'] assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [ assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
Field, Field]
def test_generate_objecttype_unmountedtype(): def test_generate_objecttype_unmountedtype():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
field = MyScalar() field = MyScalar()
assert 'field' in MyObjectType._meta.fields assert "field" in MyObjectType._meta.fields
assert isinstance(MyObjectType._meta.fields['field'], Field) assert isinstance(MyObjectType._meta.fields["field"], Field)
def test_parent_container_get_fields(): def test_parent_container_get_fields():
assert list(Container._meta.fields.keys()) == ['field1', 'field2'] assert list(Container._meta.fields.keys()) == ["field1", "field2"]
def test_parent_container_interface_get_fields(): def test_parent_container_interface_get_fields():
assert list(ContainerWithInterface._meta.fields.keys()) == [ assert list(ContainerWithInterface._meta.fields.keys()) == [
'ifield', 'field1', 'field2'] "ifield",
"field1",
"field2",
]
def test_objecttype_as_container_only_args(): def test_objecttype_as_container_only_args():
@ -187,29 +188,29 @@ def test_objecttype_as_container_invalid_kwargs():
Container(unexisting_field="3") Container(unexisting_field="3")
assert "'unexisting_field' is an invalid keyword argument for Container" == str( assert "'unexisting_field' is an invalid keyword argument for Container" == str(
excinfo.value) excinfo.value
)
def test_objecttype_container_benchmark(benchmark): def test_objecttype_container_benchmark(benchmark):
@benchmark @benchmark
def create_objecttype(): def create_objecttype():
Container(field1='field1', field2='field2') Container(field1="field1", field2="field2")
def test_generate_objecttype_description(): def test_generate_objecttype_description():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
''' """
Documentation Documentation
Documentation line 2 Documentation line 2
''' """
assert MyObjectType._meta.description == "Documentation\n\nDocumentation line 2" assert MyObjectType._meta.description == "Documentation\n\nDocumentation line 2"
def test_objecttype_with_possible_types(): def test_objecttype_with_possible_types():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
class Meta: class Meta:
possible_types = (dict,) possible_types = (dict,)
@ -218,8 +219,8 @@ def test_objecttype_with_possible_types():
def test_objecttype_with_possible_types_and_is_type_of_should_raise(): def test_objecttype_with_possible_types_and_is_type_of_should_raise():
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
class MyObjectType(ObjectType):
class MyObjectType(ObjectType):
class Meta: class Meta:
possible_types = (dict,) possible_types = (dict,)
@ -228,8 +229,8 @@ def test_objecttype_with_possible_types_and_is_type_of_should_raise():
return False return False
assert str(excinfo.value) == ( assert str(excinfo.value) == (
'MyObjectType.Meta.possible_types will cause type collision with ' "MyObjectType.Meta.possible_types will cause type collision with "
'MyObjectType.is_type_of. Please use one or other.' "MyObjectType.is_type_of. Please use one or other."
) )
@ -244,24 +245,23 @@ def test_objecttype_no_fields_output():
return User() return User()
schema = Schema(query=Query) schema = Schema(query=Query)
result = schema.execute(''' query basequery { result = schema.execute(
""" query basequery {
user { user {
name name
} }
} }
''') """
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"user": {"name": None}}
'user': {
'name': None,
}
}
def test_abstract_objecttype_can_str(): def test_abstract_objecttype_can_str():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
class Meta: class Meta:
abstract = True abstract = True
field = MyScalar() field = MyScalar()
assert str(MyObjectType) == "MyObjectType" assert str(MyObjectType) == "MyObjectType"

View File

@ -18,13 +18,13 @@ from ..union import Union
def test_query(): def test_query():
class Query(ObjectType): class Query(ObjectType):
hello = String(resolver=lambda *_: 'World') hello = String(resolver=lambda *_: "World")
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello }') executed = hello_schema.execute("{ hello }")
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': 'World'} assert executed.data == {"hello": "World"}
def test_query_source(): def test_query_source():
@ -39,9 +39,9 @@ def test_query_source():
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello }', Root()) executed = hello_schema.execute("{ hello }", Root())
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': 'World'} assert executed.data == {"hello": "World"}
def test_query_union(): def test_query_union():
@ -66,7 +66,6 @@ def test_query_union():
return isinstance(root, two_object) return isinstance(root, two_object)
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
types = (One, Two) types = (One, Two)
@ -78,15 +77,9 @@ def test_query_union():
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ unions { __typename } }') executed = hello_schema.execute("{ unions { __typename } }")
assert not executed.errors assert not executed.errors
assert executed.data == { assert executed.data == {"unions": [{"__typename": "One"}, {"__typename": "Two"}]}
'unions': [{
'__typename': 'One'
}, {
'__typename': 'Two'
}]
}
def test_query_interface(): def test_query_interface():
@ -100,7 +93,6 @@ def test_query_interface():
base = String() base = String()
class One(ObjectType): class One(ObjectType):
class Meta: class Meta:
interfaces = (MyInterface,) interfaces = (MyInterface,)
@ -111,7 +103,6 @@ def test_query_interface():
return isinstance(root, one_object) return isinstance(root, one_object)
class Two(ObjectType): class Two(ObjectType):
class Meta: class Meta:
interfaces = (MyInterface,) interfaces = (MyInterface,)
@ -129,30 +120,28 @@ def test_query_interface():
hello_schema = Schema(Query, types=[One, Two]) hello_schema = Schema(Query, types=[One, Two])
executed = hello_schema.execute('{ interfaces { __typename } }') executed = hello_schema.execute("{ interfaces { __typename } }")
assert not executed.errors assert not executed.errors
assert executed.data == { assert executed.data == {
'interfaces': [{ "interfaces": [{"__typename": "One"}, {"__typename": "Two"}]
'__typename': 'One'
}, {
'__typename': 'Two'
}]
} }
def test_query_dynamic(): def test_query_dynamic():
class Query(ObjectType): class Query(ObjectType):
hello = Dynamic(lambda: String(resolver=lambda *_: 'World')) hello = Dynamic(lambda: String(resolver=lambda *_: "World"))
hellos = Dynamic(lambda: List(String, resolver=lambda *_: ['Worlds'])) hellos = Dynamic(lambda: List(String, resolver=lambda *_: ["Worlds"]))
hello_field = Dynamic(lambda: Field( hello_field = Dynamic(lambda: Field(String, resolver=lambda *_: "Field World"))
String, resolver=lambda *_: 'Field World'))
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello hellos helloField }') executed = hello_schema.execute("{ hello hellos helloField }")
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': 'World', 'hellos': [ assert executed.data == {
'Worlds'], 'helloField': 'Field World'} "hello": "World",
"hellos": ["Worlds"],
"helloField": "Field World",
}
def test_query_default_value(): def test_query_default_value():
@ -160,13 +149,13 @@ def test_query_default_value():
field = String() field = String()
class Query(ObjectType): class Query(ObjectType):
hello = Field(MyType, default_value=MyType(field='something else!')) hello = Field(MyType, default_value=MyType(field="something else!"))
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }') executed = hello_schema.execute("{ hello { field } }")
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': {'field': 'something else!'}} assert executed.data == {"hello": {"field": "something else!"}}
def test_query_wrong_default_value(): def test_query_wrong_default_value():
@ -178,15 +167,17 @@ def test_query_wrong_default_value():
return isinstance(root, MyType) return isinstance(root, MyType)
class Query(ObjectType): class Query(ObjectType):
hello = Field(MyType, default_value='hello') hello = Field(MyType, default_value="hello")
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }') executed = hello_schema.execute("{ hello { field } }")
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert executed.errors[0].message == GraphQLError( assert (
'Expected value of type "MyType" but got: str.').message executed.errors[0].message
assert executed.data == {'hello': None} == GraphQLError('Expected value of type "MyType" but got: str.').message
)
assert executed.data == {"hello": None}
def test_query_default_value_ignored_by_resolver(): def test_query_default_value_ignored_by_resolver():
@ -194,14 +185,17 @@ def test_query_default_value_ignored_by_resolver():
field = String() field = String()
class Query(ObjectType): class Query(ObjectType):
hello = Field(MyType, default_value='hello', hello = Field(
resolver=lambda *_: MyType(field='no default.')) MyType,
default_value="hello",
resolver=lambda *_: MyType(field="no default."),
)
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello { field } }') executed = hello_schema.execute("{ hello { field } }")
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': {'field': 'no default.'}} assert executed.data == {"hello": {"field": "no default."}}
def test_query_resolve_function(): def test_query_resolve_function():
@ -209,13 +203,13 @@ def test_query_resolve_function():
hello = String() hello = String()
def resolve_hello(self, info): def resolve_hello(self, info):
return 'World' return "World"
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute('{ hello }') executed = hello_schema.execute("{ hello }")
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': 'World'} assert executed.data == {"hello": "World"}
def test_query_arguments(): def test_query_arguments():
@ -223,24 +217,23 @@ def test_query_arguments():
test = String(a_str=String(), a_int=Int()) test = String(a_str=String(), a_int=Int())
def resolve_test(self, info, **args): def resolve_test(self, info, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(",", ":"))
test_schema = Schema(Query) test_schema = Schema(Query)
result = test_schema.execute('{ test }', None) result = test_schema.execute("{ test }", None)
assert not result.errors assert not result.errors
assert result.data == {'test': '[null,{}]'} assert result.data == {"test": "[null,{}]"}
result = test_schema.execute('{ test(aStr: "String!") }', 'Source!') result = test_schema.execute('{ test(aStr: "String!") }', "Source!")
assert not result.errors assert not result.errors
assert result.data == {'test': '["Source!",{"a_str":"String!"}]'} assert result.data == {"test": '["Source!",{"a_str":"String!"}]'}
result = test_schema.execute( result = test_schema.execute('{ test(aInt: -123, aStr: "String!") }', "Source!")
'{ test(aInt: -123, aStr: "String!") }', 'Source!')
assert not result.errors assert not result.errors
assert result.data in [ assert result.data in [
{'test': '["Source!",{"a_str":"String!","a_int":-123}]'}, {"test": '["Source!",{"a_str":"String!","a_int":-123}]'},
{'test': '["Source!",{"a_int":-123,"a_str":"String!"}]'} {"test": '["Source!",{"a_int":-123,"a_str":"String!"}]'},
] ]
@ -253,25 +246,25 @@ def test_query_input_field():
test = String(a_input=Input()) test = String(a_input=Input())
def resolve_test(self, info, **args): def resolve_test(self, info, **args):
return json.dumps([self, args], separators=(',', ':')) return json.dumps([self, args], separators=(",", ":"))
test_schema = Schema(Query) test_schema = Schema(Query)
result = test_schema.execute('{ test }', None) result = test_schema.execute("{ test }", None)
assert not result.errors assert not result.errors
assert result.data == {'test': '[null,{}]'} assert result.data == {"test": "[null,{}]"}
result = test_schema.execute('{ test(aInput: {aField: "String!"} ) }', "Source!")
assert not result.errors
assert result.data == {"test": '["Source!",{"a_input":{"a_field":"String!"}}]'}
result = test_schema.execute( result = test_schema.execute(
'{ test(aInput: {aField: "String!"} ) }', 'Source!') '{ test(aInput: {recursiveField: {aField: "String!"}}) }', "Source!"
)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
'test': '["Source!",{"a_input":{"a_field":"String!"}}]'} "test": '["Source!",{"a_input":{"recursive_field":{"a_field":"String!"}}}]'
}
result = test_schema.execute(
'{ test(aInput: {recursiveField: {aField: "String!"}}) }', 'Source!')
assert not result.errors
assert result.data == {
'test': '["Source!",{"a_input":{"recursive_field":{"a_field":"String!"}}}]'}
def test_query_middlewares(): def test_query_middlewares():
@ -280,10 +273,10 @@ def test_query_middlewares():
other = String() other = String()
def resolve_hello(self, info): def resolve_hello(self, info):
return 'World' return "World"
def resolve_other(self, info): def resolve_other(self, info):
return 'other' return "other"
def reversed_middleware(next, *args, **kwargs): def reversed_middleware(next, *args, **kwargs):
p = next(*args, **kwargs) p = next(*args, **kwargs)
@ -292,14 +285,14 @@ def test_query_middlewares():
hello_schema = Schema(Query) hello_schema = Schema(Query)
executed = hello_schema.execute( executed = hello_schema.execute(
'{ hello, other }', middleware=[reversed_middleware]) "{ hello, other }", middleware=[reversed_middleware]
)
assert not executed.errors assert not executed.errors
assert executed.data == {'hello': 'dlroW', 'other': 'rehto'} assert executed.data == {"hello": "dlroW", "other": "rehto"}
def test_objecttype_on_instances(): def test_objecttype_on_instances():
class Ship: class Ship:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
@ -314,12 +307,12 @@ def test_objecttype_on_instances():
ship = Field(ShipType) ship = Field(ShipType)
def resolve_ship(self, info): def resolve_ship(self, info):
return Ship(name='xwing') return Ship(name="xwing")
schema = Schema(query=Query) schema = Schema(query=Query)
executed = schema.execute('{ ship { name } }') executed = schema.execute("{ ship { name } }")
assert not executed.errors assert not executed.errors
assert executed.data == {'ship': {'name': 'xwing'}} assert executed.data == {"ship": {"name": "xwing"}}
def test_big_list_query_benchmark(benchmark): def test_big_list_query_benchmark(benchmark):
@ -333,10 +326,10 @@ def test_big_list_query_benchmark(benchmark):
hello_schema = Schema(Query) hello_schema = Schema(Query)
big_list_query = partial(hello_schema.execute, '{ allInts }') big_list_query = partial(hello_schema.execute, "{ allInts }")
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {'allInts': list(big_list)} assert result.data == {"allInts": list(big_list)}
def test_big_list_query_compiled_query_benchmark(benchmark): def test_big_list_query_compiled_query_benchmark(benchmark):
@ -349,13 +342,13 @@ def test_big_list_query_compiled_query_benchmark(benchmark):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
source = Source('{ allInts }') source = Source("{ allInts }")
query_ast = parse(source) query_ast = parse(source)
big_list_query = partial(execute, hello_schema, query_ast) big_list_query = partial(execute, hello_schema, query_ast)
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {'allInts': list(big_list)} assert result.data == {"allInts": list(big_list)}
def test_big_list_of_containers_query_benchmark(benchmark): def test_big_list_of_containers_query_benchmark(benchmark):
@ -372,11 +365,10 @@ def test_big_list_of_containers_query_benchmark(benchmark):
hello_schema = Schema(Query) hello_schema = Schema(Query)
big_list_query = partial(hello_schema.execute, '{ allContainers { x } }') big_list_query = partial(hello_schema.execute, "{ allContainers { x } }")
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {'allContainers': [ assert result.data == {"allContainers": [{"x": c.x} for c in big_container_list]}
{'x': c.x} for c in big_container_list]}
def test_big_list_of_containers_multiple_fields_query_benchmark(benchmark): def test_big_list_of_containers_multiple_fields_query_benchmark(benchmark):
@ -396,15 +388,19 @@ def test_big_list_of_containers_multiple_fields_query_benchmark(benchmark):
hello_schema = Schema(Query) hello_schema = Schema(Query)
big_list_query = partial(hello_schema.execute, big_list_query = partial(hello_schema.execute, "{ allContainers { x, y, z, o } }")
'{ allContainers { x, y, z, o } }')
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {'allContainers': [ assert result.data == {
{'x': c.x, 'y': c.y, 'z': c.z, 'o': c.o} for c in big_container_list]} "allContainers": [
{"x": c.x, "y": c.y, "z": c.z, "o": c.o} for c in big_container_list
]
}
def test_big_list_of_containers_multiple_fields_custom_resolvers_query_benchmark(benchmark): def test_big_list_of_containers_multiple_fields_custom_resolvers_query_benchmark(
benchmark
):
class Container(ObjectType): class Container(ObjectType):
x = Int() x = Int()
y = Int() y = Int()
@ -433,12 +429,14 @@ def test_big_list_of_containers_multiple_fields_custom_resolvers_query_benchmark
hello_schema = Schema(Query) hello_schema = Schema(Query)
big_list_query = partial(hello_schema.execute, big_list_query = partial(hello_schema.execute, "{ allContainers { x, y, z, o } }")
'{ allContainers { x, y, z, o } }')
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {'allContainers': [ assert result.data == {
{'x': c.x, 'y': c.y, 'z': c.z, 'o': c.o} for c in big_container_list]} "allContainers": [
{"x": c.x, "y": c.y, "z": c.z, "o": c.o} for c in big_container_list
]
}
def test_query_annotated_resolvers(): def test_query_annotated_resolvers():
@ -464,15 +462,15 @@ def test_query_annotated_resolvers():
result = test_schema.execute('{ annotated(id:"self") }', "base") result = test_schema.execute('{ annotated(id:"self") }', "base")
assert not result.errors assert not result.errors
assert result.data == {'annotated': 'base-self'} assert result.data == {"annotated": "base-self"}
result = test_schema.execute('{ context }', "base", context_value=context) result = test_schema.execute("{ context }", "base", context_value=context)
assert not result.errors assert not result.errors
assert result.data == {'context': 'base-context'} assert result.data == {"context": "base-context"}
result = test_schema.execute('{ info }', "base") result = test_schema.execute("{ info }", "base")
assert not result.errors assert not result.errors
assert result.data == {'info': 'base-info'} assert result.data == {"info": "base-info"}
def test_default_as_kwarg_to_NonNull(): def test_default_as_kwarg_to_NonNull():
@ -488,7 +486,7 @@ def test_default_as_kwarg_to_NonNull():
return User(name="foo") return User(name="foo")
schema = Schema(query=Query) schema = Schema(query=Query)
expected = {'user': {'name': 'foo', 'isAdmin': False}} expected = {"user": {"name": "foo", "isAdmin": False}}
result = schema.execute("{ user { name isAdmin } }") result = schema.execute("{ user { name isAdmin } }")
assert not result.errors assert not result.errors

View File

@ -1,38 +1,40 @@
from ..resolver import (attr_resolver, dict_resolver, get_default_resolver, from ..resolver import (
set_default_resolver) attr_resolver,
dict_resolver,
get_default_resolver,
set_default_resolver,
)
args = {} args = {}
context = None context = None
info = None info = None
demo_dict = { demo_dict = {"attr": "value"}
'attr': 'value'
}
class demo_obj(object): class demo_obj(object):
attr = 'value' attr = "value"
def test_attr_resolver(): def test_attr_resolver():
resolved = attr_resolver('attr', None, demo_obj, info, **args) resolved = attr_resolver("attr", None, demo_obj, info, **args)
assert resolved == 'value' assert resolved == "value"
def test_attr_resolver_default_value(): def test_attr_resolver_default_value():
resolved = attr_resolver('attr2', 'default', demo_obj, info, **args) resolved = attr_resolver("attr2", "default", demo_obj, info, **args)
assert resolved == 'default' assert resolved == "default"
def test_dict_resolver(): def test_dict_resolver():
resolved = dict_resolver('attr', None, demo_dict, info, **args) resolved = dict_resolver("attr", None, demo_dict, info, **args)
assert resolved == 'value' assert resolved == "value"
def test_dict_resolver_default_value(): def test_dict_resolver_default_value():
resolved = dict_resolver('attr2', 'default', demo_dict, info, **args) resolved = dict_resolver("attr2", "default", demo_dict, info, **args)
assert resolved == 'default' assert resolved == "default"
def test_get_default_resolver_is_attr_resolver(): def test_get_default_resolver_is_attr_resolver():

View File

@ -4,7 +4,7 @@ from ..scalars import Scalar
def test_scalar(): def test_scalar():
class JSONScalar(Scalar): class JSONScalar(Scalar):
'''Documentation''' """Documentation"""
assert JSONScalar._meta.name == "JSONScalar" assert JSONScalar._meta.name == "JSONScalar"
assert JSONScalar._meta.description == "Documentation" assert JSONScalar._meta.description == "Documentation"

View File

@ -13,8 +13,8 @@ def test_serializes_output_int():
assert Int.serialize(-9876504321) is None assert Int.serialize(-9876504321) is None
assert Int.serialize(1e100) is None assert Int.serialize(1e100) is None
assert Int.serialize(-1e100) is None assert Int.serialize(-1e100) is None
assert Int.serialize('-1.1') == -1 assert Int.serialize("-1.1") == -1
assert Int.serialize('one') is None assert Int.serialize("one") is None
assert Int.serialize(False) == 0 assert Int.serialize(False) == 0
assert Int.serialize(True) == 1 assert Int.serialize(True) == 1
@ -26,24 +26,24 @@ def test_serializes_output_float():
assert Float.serialize(0.1) == 0.1 assert Float.serialize(0.1) == 0.1
assert Float.serialize(1.1) == 1.1 assert Float.serialize(1.1) == 1.1
assert Float.serialize(-1.1) == -1.1 assert Float.serialize(-1.1) == -1.1
assert Float.serialize('-1.1') == -1.1 assert Float.serialize("-1.1") == -1.1
assert Float.serialize('one') is None assert Float.serialize("one") is None
assert Float.serialize(False) == 0 assert Float.serialize(False) == 0
assert Float.serialize(True) == 1 assert Float.serialize(True) == 1
def test_serializes_output_string(): def test_serializes_output_string():
assert String.serialize('string') == 'string' assert String.serialize("string") == "string"
assert String.serialize(1) == '1' assert String.serialize(1) == "1"
assert String.serialize(-1.1) == '-1.1' assert String.serialize(-1.1) == "-1.1"
assert String.serialize(True) == 'true' assert String.serialize(True) == "true"
assert String.serialize(False) == 'false' assert String.serialize(False) == "false"
assert String.serialize(u'\U0001F601') == u'\U0001F601' assert String.serialize(u"\U0001F601") == u"\U0001F601"
def test_serializes_output_boolean(): def test_serializes_output_boolean():
assert Boolean.serialize('string') is True assert Boolean.serialize("string") is True
assert Boolean.serialize('') is False assert Boolean.serialize("") is False
assert Boolean.serialize(1) is True assert Boolean.serialize(1) is True
assert Boolean.serialize(0) is False assert Boolean.serialize(0) is False
assert Boolean.serialize(True) is True assert Boolean.serialize(True) is True

View File

@ -35,7 +35,9 @@ def test_schema_get_type_error():
def test_schema_str(): def test_schema_str():
schema = Schema(Query) schema = Schema(Query)
assert str(schema) == """schema { assert (
str(schema)
== """schema {
query: Query query: Query
} }
@ -47,8 +49,9 @@ type Query {
inner: MyOtherType inner: MyOtherType
} }
""" """
)
def test_schema_introspect(): def test_schema_introspect():
schema = Schema(Query) schema = Schema(Query)
assert '__schema' in schema.introspect() assert "__schema" in schema.introspect()

View File

@ -10,14 +10,17 @@ from .utils import MyLazyType
def test_list(): def test_list():
_list = List(String) _list = List(String)
assert _list.of_type == String assert _list.of_type == String
assert str(_list) == '[String]' assert str(_list) == "[String]"
def test_list_with_unmounted_type(): def test_list_with_unmounted_type():
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
List(String()) List(String())
assert str(exc_info.value) == 'List could not have a mounted String() as inner type. Try with List(String).' assert (
str(exc_info.value)
== "List could not have a mounted String() as inner type. Try with List(String)."
)
def test_list_with_lazy_type(): def test_list_with_lazy_type():
@ -52,7 +55,7 @@ def test_list_inherited_works_nonnull():
def test_nonnull(): def test_nonnull():
nonnull = NonNull(String) nonnull = NonNull(String)
assert nonnull.of_type == String assert nonnull.of_type == String
assert str(nonnull) == 'String!' assert str(nonnull) == "String!"
def test_nonnull_with_lazy_type(): def test_nonnull_with_lazy_type():
@ -82,14 +85,20 @@ def test_nonnull_inherited_dont_work_nonnull():
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
NonNull(NonNull(String)) NonNull(NonNull(String))
assert str(exc_info.value) == 'Can only create NonNull of a Nullable GraphQLType but got: String!.' assert (
str(exc_info.value)
== "Can only create NonNull of a Nullable GraphQLType but got: String!."
)
def test_nonnull_with_unmounted_type(): def test_nonnull_with_unmounted_type():
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
NonNull(String()) NonNull(String())
assert str(exc_info.value) == 'NonNull could not have a mounted String() as inner type. Try with NonNull(String).' assert (
str(exc_info.value)
== "NonNull could not have a mounted String() as inner type. Try with NonNull(String)."
)
def test_list_comparasion(): def test_list_comparasion():

View File

@ -1,8 +1,15 @@
import pytest import pytest
from graphql.type import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, from graphql.type import (
GraphQLField, GraphQLInputObjectField, GraphQLArgument,
GraphQLInputObjectType, GraphQLInterfaceType, GraphQLEnumType,
GraphQLObjectType, GraphQLString) GraphQLEnumValue,
GraphQLField,
GraphQLInputObjectField,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLString,
)
from ..dynamic import Dynamic from ..dynamic import Dynamic
from ..enum import Enum from ..enum import Enum
@ -18,107 +25,128 @@ from ..typemap import TypeMap, resolve_type
def test_enum(): def test_enum():
class MyEnum(Enum): class MyEnum(Enum):
'''Description''' """Description"""
foo = 1 foo = 1
bar = 2 bar = 2
@property @property
def description(self): def description(self):
return 'Description {}={}'.format(self.name, self.value) return "Description {}={}".format(self.name, self.value)
@property @property
def deprecation_reason(self): def deprecation_reason(self):
if self == MyEnum.foo: if self == MyEnum.foo:
return 'Is deprecated' return "Is deprecated"
typemap = TypeMap([MyEnum]) typemap = TypeMap([MyEnum])
assert 'MyEnum' in typemap assert "MyEnum" in typemap
graphql_enum = typemap['MyEnum'] graphql_enum = typemap["MyEnum"]
assert isinstance(graphql_enum, GraphQLEnumType) assert isinstance(graphql_enum, GraphQLEnumType)
assert graphql_enum.name == 'MyEnum' assert graphql_enum.name == "MyEnum"
assert graphql_enum.description == 'Description' assert graphql_enum.description == "Description"
values = graphql_enum.values values = graphql_enum.values
assert values == [ assert values == [
GraphQLEnumValue(name='foo', value=1, description='Description foo=1', GraphQLEnumValue(
deprecation_reason='Is deprecated'), name="foo",
GraphQLEnumValue(name='bar', value=2, description='Description bar=2'), value=1,
description="Description foo=1",
deprecation_reason="Is deprecated",
),
GraphQLEnumValue(name="bar", value=2, description="Description bar=2"),
] ]
def test_objecttype(): def test_objecttype():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Description''' """Description"""
foo = String(bar=String(description='Argument description',
default_value='x'), description='Field description') foo = String(
bar = String(name='gizmo') bar=String(description="Argument description", default_value="x"),
description="Field description",
)
bar = String(name="gizmo")
def resolve_foo(self, bar): def resolve_foo(self, bar):
return bar return bar
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
assert 'MyObjectType' in typemap assert "MyObjectType" in typemap
graphql_type = typemap['MyObjectType'] graphql_type = typemap["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == 'MyObjectType' assert graphql_type.name == "MyObjectType"
assert graphql_type.description == 'Description' assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['foo', 'gizmo'] assert list(fields.keys()) == ["foo", "gizmo"]
foo_field = fields['foo'] foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.description == 'Field description' assert foo_field.description == "Field description"
assert foo_field.args == { assert foo_field.args == {
'bar': GraphQLArgument(GraphQLString, description='Argument description', default_value='x', out_name='bar') "bar": GraphQLArgument(
GraphQLString,
description="Argument description",
default_value="x",
out_name="bar",
)
} }
def test_dynamic_objecttype(): def test_dynamic_objecttype():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Description''' """Description"""
bar = Dynamic(lambda: Field(String)) bar = Dynamic(lambda: Field(String))
own = Field(lambda: MyObjectType) own = Field(lambda: MyObjectType)
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
assert 'MyObjectType' in typemap assert "MyObjectType" in typemap
assert list(MyObjectType._meta.fields.keys()) == ['bar', 'own'] assert list(MyObjectType._meta.fields.keys()) == ["bar", "own"]
graphql_type = typemap['MyObjectType'] graphql_type = typemap["MyObjectType"]
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['bar', 'own'] assert list(fields.keys()) == ["bar", "own"]
assert fields['bar'].type == GraphQLString assert fields["bar"].type == GraphQLString
assert fields['own'].type == graphql_type assert fields["own"].type == graphql_type
def test_interface(): def test_interface():
class MyInterface(Interface): class MyInterface(Interface):
'''Description''' """Description"""
foo = String(bar=String(description='Argument description',
default_value='x'), description='Field description') foo = String(
bar = String(name='gizmo', first_arg=String(), bar=String(description="Argument description", default_value="x"),
other_arg=String(name='oth_arg')) description="Field description",
)
bar = String(name="gizmo", first_arg=String(), other_arg=String(name="oth_arg"))
own = Field(lambda: MyInterface) own = Field(lambda: MyInterface)
def resolve_foo(self, args, info): def resolve_foo(self, args, info):
return args.get('bar') return args.get("bar")
typemap = TypeMap([MyInterface]) typemap = TypeMap([MyInterface])
assert 'MyInterface' in typemap assert "MyInterface" in typemap
graphql_type = typemap['MyInterface'] graphql_type = typemap["MyInterface"]
assert isinstance(graphql_type, GraphQLInterfaceType) assert isinstance(graphql_type, GraphQLInterfaceType)
assert graphql_type.name == 'MyInterface' assert graphql_type.name == "MyInterface"
assert graphql_type.description == 'Description' assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['foo', 'gizmo', 'own'] assert list(fields.keys()) == ["foo", "gizmo", "own"]
assert fields['own'].type == graphql_type assert fields["own"].type == graphql_type
assert list(fields['gizmo'].args.keys()) == ['firstArg', 'oth_arg'] assert list(fields["gizmo"].args.keys()) == ["firstArg", "oth_arg"]
foo_field = fields['foo'] foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.description == 'Field description' assert foo_field.description == "Field description"
assert not foo_field.resolver # Resolver not attached in interfaces assert not foo_field.resolver # Resolver not attached in interfaces
assert foo_field.args == { assert foo_field.args == {
'bar': GraphQLArgument(GraphQLString, description='Argument description', default_value='x', out_name='bar') "bar": GraphQLArgument(
GraphQLString,
description="Argument description",
default_value="x",
out_name="bar",
)
} }
@ -131,103 +159,111 @@ def test_inputobject():
some_other_field = List(OtherObjectType) some_other_field = List(OtherObjectType)
class MyInputObjectType(InputObjectType): class MyInputObjectType(InputObjectType):
'''Description''' """Description"""
foo_bar = String(description='Field description')
bar = String(name='gizmo') foo_bar = String(description="Field description")
bar = String(name="gizmo")
baz = NonNull(MyInnerObjectType) baz = NonNull(MyInnerObjectType)
own = InputField(lambda: MyInputObjectType) own = InputField(lambda: MyInputObjectType)
def resolve_foo_bar(self, args, info): def resolve_foo_bar(self, args, info):
return args.get('bar') return args.get("bar")
typemap = TypeMap([MyInputObjectType]) typemap = TypeMap([MyInputObjectType])
assert 'MyInputObjectType' in typemap assert "MyInputObjectType" in typemap
graphql_type = typemap['MyInputObjectType'] graphql_type = typemap["MyInputObjectType"]
assert isinstance(graphql_type, GraphQLInputObjectType) assert isinstance(graphql_type, GraphQLInputObjectType)
assert graphql_type.name == 'MyInputObjectType' assert graphql_type.name == "MyInputObjectType"
assert graphql_type.description == 'Description' assert graphql_type.description == "Description"
other_graphql_type = typemap['OtherObjectType'] other_graphql_type = typemap["OtherObjectType"]
inner_graphql_type = typemap['MyInnerObjectType'] inner_graphql_type = typemap["MyInnerObjectType"]
container = graphql_type.create_container({ container = graphql_type.create_container(
'bar': 'oh!', {
'baz': inner_graphql_type.create_container({ "bar": "oh!",
'some_other_field': [ "baz": inner_graphql_type.create_container(
other_graphql_type.create_container({'thingy': 1}), {
other_graphql_type.create_container({'thingy': 2}) "some_other_field": [
other_graphql_type.create_container({"thingy": 1}),
other_graphql_type.create_container({"thingy": 2}),
] ]
}) }
}) ),
}
)
assert isinstance(container, MyInputObjectType) assert isinstance(container, MyInputObjectType)
assert 'bar' in container assert "bar" in container
assert container.bar == 'oh!' assert container.bar == "oh!"
assert 'foo_bar' not in container assert "foo_bar" not in container
assert container.foo_bar is None assert container.foo_bar is None
assert container.baz.some_field is None assert container.baz.some_field is None
assert container.baz.some_other_field[0].thingy == 1 assert container.baz.some_other_field[0].thingy == 1
assert container.baz.some_other_field[1].thingy == 2 assert container.baz.some_other_field[1].thingy == 2
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['fooBar', 'gizmo', 'baz', 'own'] assert list(fields.keys()) == ["fooBar", "gizmo", "baz", "own"]
own_field = fields['own'] own_field = fields["own"]
assert own_field.type == graphql_type assert own_field.type == graphql_type
foo_field = fields['fooBar'] foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLInputObjectField) assert isinstance(foo_field, GraphQLInputObjectField)
assert foo_field.description == 'Field description' assert foo_field.description == "Field description"
def test_objecttype_camelcase(): def test_objecttype_camelcase():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Description''' """Description"""
foo_bar = String(bar_foo=String()) foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
assert 'MyObjectType' in typemap assert "MyObjectType" in typemap
graphql_type = typemap['MyObjectType'] graphql_type = typemap["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == 'MyObjectType' assert graphql_type.name == "MyObjectType"
assert graphql_type.description == 'Description' assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['fooBar'] assert list(fields.keys()) == ["fooBar"]
foo_field = fields['fooBar'] foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.args == { assert foo_field.args == {
'barFoo': GraphQLArgument(GraphQLString, out_name='bar_foo') "barFoo": GraphQLArgument(GraphQLString, out_name="bar_foo")
} }
def test_objecttype_camelcase_disabled(): def test_objecttype_camelcase_disabled():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Description''' """Description"""
foo_bar = String(bar_foo=String()) foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType], auto_camelcase=False) typemap = TypeMap([MyObjectType], auto_camelcase=False)
assert 'MyObjectType' in typemap assert "MyObjectType" in typemap
graphql_type = typemap['MyObjectType'] graphql_type = typemap["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == 'MyObjectType' assert graphql_type.name == "MyObjectType"
assert graphql_type.description == 'Description' assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ['foo_bar'] assert list(fields.keys()) == ["foo_bar"]
foo_field = fields['foo_bar'] foo_field = fields["foo_bar"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.args == { assert foo_field.args == {
'bar_foo': GraphQLArgument(GraphQLString, out_name='bar_foo') "bar_foo": GraphQLArgument(GraphQLString, out_name="bar_foo")
} }
def test_objecttype_with_possible_types(): def test_objecttype_with_possible_types():
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
'''Description''' """Description"""
class Meta: class Meta:
possible_types = (dict,) possible_types = (dict,)
foo_bar = String() foo_bar = String()
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
graphql_type = typemap['MyObjectType'] graphql_type = typemap["MyObjectType"]
assert graphql_type.is_type_of assert graphql_type.is_type_of
assert graphql_type.is_type_of({}, None) is True assert graphql_type.is_type_of({}, None) is True
assert graphql_type.is_type_of(MyObjectType(), None) is False assert graphql_type.is_type_of(MyObjectType(), None) is False
@ -245,8 +281,6 @@ def test_resolve_type_with_missing_type():
typemap = TypeMap([MyObjectType]) typemap = TypeMap([MyObjectType])
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
resolve_type( resolve_type(resolve_type_func, typemap, "MyOtherObjectType", {}, {})
resolve_type_func, typemap, 'MyOtherObjectType', {}, {}
)
assert 'MyOtherObjectTyp' in str(excinfo.value) assert "MyOtherObjectTyp" in str(excinfo.value)

View File

@ -16,7 +16,8 @@ class MyObjectType2(ObjectType):
def test_generate_union(): def test_generate_union():
class MyUnion(Union): class MyUnion(Union):
'''Documentation''' """Documentation"""
class Meta: class Meta:
types = (MyObjectType1, MyObjectType2) types = (MyObjectType1, MyObjectType2)
@ -27,10 +28,9 @@ def test_generate_union():
def test_generate_union_with_meta(): def test_generate_union_with_meta():
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
name = 'MyOtherUnion' name = "MyOtherUnion"
description = 'Documentation' description = "Documentation"
types = (MyObjectType1, MyObjectType2) types = (MyObjectType1, MyObjectType2)
assert MyUnion._meta.name == "MyOtherUnion" assert MyUnion._meta.name == "MyOtherUnion"
@ -39,15 +39,15 @@ def test_generate_union_with_meta():
def test_generate_union_with_no_types(): def test_generate_union_with_no_types():
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
class MyUnion(Union): class MyUnion(Union):
pass pass
assert str(exc_info.value) == 'Must provide types for Union MyUnion.' assert str(exc_info.value) == "Must provide types for Union MyUnion."
def test_union_can_be_mounted(): def test_union_can_be_mounted():
class MyUnion(Union): class MyUnion(Union):
class Meta: class Meta:
types = (MyObjectType1, MyObjectType2) types = (MyObjectType1, MyObjectType2)

View File

@ -14,22 +14,18 @@ schema = Schema(query=Query)
def test_uuidstring_query(): def test_uuidstring_query():
uuid_value = 'dfeb3bcf-70fd-11e7-a61a-6003088f8204' uuid_value = "dfeb3bcf-70fd-11e7-a61a-6003088f8204"
result = schema.execute('''{ uuid(input: "%s") }''' % uuid_value) result = schema.execute("""{ uuid(input: "%s") }""" % uuid_value)
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"uuid": uuid_value}
'uuid': uuid_value
}
def test_uuidstring_query_variable(): def test_uuidstring_query_variable():
uuid_value = 'dfeb3bcf-70fd-11e7-a61a-6003088f8204' uuid_value = "dfeb3bcf-70fd-11e7-a61a-6003088f8204"
result = schema.execute( result = schema.execute(
'''query Test($uuid: UUID){ uuid(input: $uuid) }''', """query Test($uuid: UUID){ uuid(input: $uuid) }""",
variable_values={'uuid': uuid_value} variable_values={"uuid": uuid_value},
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {"uuid": uuid_value}
'uuid': uuid_value
}

View File

@ -2,19 +2,33 @@ import inspect
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from graphql import (GraphQLArgument, GraphQLBoolean, GraphQLField, from graphql import (
GraphQLFloat, GraphQLID, GraphQLInputObjectField, GraphQLArgument,
GraphQLInt, GraphQLList, GraphQLNonNull, GraphQLString) GraphQLBoolean,
GraphQLField,
GraphQLFloat,
GraphQLID,
GraphQLInputObjectField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLString,
)
from graphql.execution.executor import get_default_resolve_type_fn from graphql.execution.executor import get_default_resolve_type_fn
from graphql.type import GraphQLEnumValue from graphql.type import GraphQLEnumValue
from graphql.type.typemap import GraphQLTypeMap from graphql.type.typemap import GraphQLTypeMap
from ..utils.get_unbound_function import get_unbound_function from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
from .definitions import (GrapheneEnumType, GrapheneGraphQLType, from .definitions import (
GrapheneInputObjectType, GrapheneInterfaceType, GrapheneEnumType,
GrapheneObjectType, GrapheneScalarType, GrapheneGraphQLType,
GrapheneUnionType) GrapheneInputObjectType,
GrapheneInterfaceType,
GrapheneObjectType,
GrapheneScalarType,
GrapheneUnionType,
)
from .dynamic import Dynamic from .dynamic import Dynamic
from .enum import Enum from .enum import Enum
from .field import Field from .field import Field
@ -31,9 +45,9 @@ from .utils import get_field_as
def is_graphene_type(_type): def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)): if isinstance(_type, (List, NonNull)):
return True return True
if inspect.isclass(_type) and issubclass(_type, if inspect.isclass(_type) and issubclass(
(ObjectType, InputObjectType, _type, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
Scalar, Interface, Union, Enum)): ):
return True return True
@ -46,11 +60,9 @@ def resolve_type(resolve_type_func, map, type_name, root, info):
if inspect.isclass(_type) and issubclass(_type, ObjectType): if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name) graphql_type = map.get(_type._meta.name)
assert graphql_type, "Can't find type {} in schema".format( assert graphql_type, "Can't find type {} in schema".format(_type._meta.name)
_type._meta.name
)
assert graphql_type.graphene_type == _type, ( assert graphql_type.graphene_type == _type, (
'The type {} does not match with the associated graphene type {}.' "The type {} does not match with the associated graphene type {}."
).format(_type, graphql_type.graphene_type) ).format(_type, graphql_type.graphene_type)
return graphql_type return graphql_type
@ -62,7 +74,6 @@ def is_type_of_from_possible_types(possible_types, root, info):
class TypeMap(GraphQLTypeMap): class TypeMap(GraphQLTypeMap):
def __init__(self, types, auto_camelcase=True, schema=None): def __init__(self, types, auto_camelcase=True, schema=None):
self.auto_camelcase = auto_camelcase self.auto_camelcase = auto_camelcase
self.schema = schema self.schema = schema
@ -84,7 +95,7 @@ class TypeMap(GraphQLTypeMap):
_type = map[type._meta.name] _type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, ( assert _type.graphene_type == type, (
'Found different types with the same name in the schema: {}, {}.' "Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type) ).format(_type.graphene_type, type)
return map return map
@ -101,8 +112,7 @@ class TypeMap(GraphQLTypeMap):
elif issubclass(type, Union): elif issubclass(type, Union):
internal_type = self.construct_union(map, type) internal_type = self.construct_union(map, type)
else: else:
raise Exception( raise Exception("Expected Graphene type, but received: {}.".format(type))
"Expected Graphene type, but received: {}.".format(type))
return GraphQLTypeMap.reducer(map, internal_type) return GraphQLTypeMap.reducer(map, internal_type)
@ -114,7 +124,7 @@ class TypeMap(GraphQLTypeMap):
Int: GraphQLInt, Int: GraphQLInt,
Float: GraphQLFloat, Float: GraphQLFloat,
Boolean: GraphQLBoolean, Boolean: GraphQLBoolean,
ID: GraphQLID ID: GraphQLID,
} }
if type in _scalars: if type in _scalars:
return _scalars[type] return _scalars[type]
@ -123,15 +133,16 @@ class TypeMap(GraphQLTypeMap):
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
serialize=getattr(type, 'serialize', None), serialize=getattr(type, "serialize", None),
parse_value=getattr(type, 'parse_value', None), parse_value=getattr(type, "parse_value", None),
parse_literal=getattr(type, 'parse_literal', None), ) parse_literal=getattr(type, "parse_literal", None),
)
def construct_enum(self, map, type): def construct_enum(self, map, type):
values = OrderedDict() values = OrderedDict()
for name, value in type._meta.enum.__members__.items(): for name, value in type._meta.enum.__members__.items():
description = getattr(value, 'description', None) description = getattr(value, "description", None)
deprecation_reason = getattr(value, 'deprecation_reason', None) deprecation_reason = getattr(value, "deprecation_reason", None)
if not description and callable(type._meta.description): if not description and callable(type._meta.description):
description = type._meta.description(value) description = type._meta.description(value)
@ -142,22 +153,28 @@ class TypeMap(GraphQLTypeMap):
name=name, name=name,
value=value.value, value=value.value,
description=description, description=description,
deprecation_reason=deprecation_reason) deprecation_reason=deprecation_reason,
)
type_description = type._meta.description(None) if callable(type._meta.description) else type._meta.description type_description = (
type._meta.description(None)
if callable(type._meta.description)
else type._meta.description
)
return GrapheneEnumType( return GrapheneEnumType(
graphene_type=type, graphene_type=type,
values=values, values=values,
name=type._meta.name, name=type._meta.name,
description=type_description, ) description=type_description,
)
def construct_objecttype(self, map, type): def construct_objecttype(self, map, type):
if type._meta.name in map: if type._meta.name in map:
_type = map[type._meta.name] _type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType): if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, ( assert _type.graphene_type == type, (
'Found different types with the same name in the schema: {}, {}.' "Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type) ).format(_type.graphene_type, type)
return _type return _type
@ -171,8 +188,9 @@ class TypeMap(GraphQLTypeMap):
return interfaces return interfaces
if type._meta.possible_types: if type._meta.possible_types:
is_type_of = partial(is_type_of_from_possible_types, is_type_of = partial(
type._meta.possible_types) is_type_of_from_possible_types, type._meta.possible_types
)
else: else:
is_type_of = type.is_type_of is_type_of = type.is_type_of
@ -182,27 +200,30 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description, description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type), fields=partial(self.construct_fields_for_type, map, type),
is_type_of=is_type_of, is_type_of=is_type_of,
interfaces=interfaces) interfaces=interfaces,
)
def construct_interface(self, map, type): def construct_interface(self, map, type):
if type._meta.name in map: if type._meta.name in map:
_type = map[type._meta.name] _type = map[type._meta.name]
if isinstance(_type, GrapheneInterfaceType): if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type, ( assert _type.graphene_type == type, (
'Found different types with the same name in the schema: {}, {}.' "Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type) ).format(_type.graphene_type, type)
return _type return _type
_resolve_type = None _resolve_type = None
if type.resolve_type: if type.resolve_type:
_resolve_type = partial(resolve_type, type.resolve_type, map, _resolve_type = partial(
type._meta.name) resolve_type, type.resolve_type, map, type._meta.name
)
return GrapheneInterfaceType( return GrapheneInterfaceType(
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
description=type._meta.description, description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type), fields=partial(self.construct_fields_for_type, map, type),
resolve_type=_resolve_type, ) resolve_type=_resolve_type,
)
def construct_inputobjecttype(self, map, type): def construct_inputobjecttype(self, map, type):
return GrapheneInputObjectType( return GrapheneInputObjectType(
@ -211,14 +232,16 @@ class TypeMap(GraphQLTypeMap):
description=type._meta.description, description=type._meta.description,
container_type=type._meta.container, container_type=type._meta.container,
fields=partial( fields=partial(
self.construct_fields_for_type, map, type, is_input_type=True), self.construct_fields_for_type, map, type, is_input_type=True
),
) )
def construct_union(self, map, type): def construct_union(self, map, type):
_resolve_type = None _resolve_type = None
if type.resolve_type: if type.resolve_type:
_resolve_type = partial(resolve_type, type.resolve_type, map, _resolve_type = partial(
type._meta.name) resolve_type, type.resolve_type, map, type._meta.name
)
def types(): def types():
union_types = [] union_types = []
@ -233,7 +256,8 @@ class TypeMap(GraphQLTypeMap):
graphene_type=type, graphene_type=type,
name=type._meta.name, name=type._meta.name,
types=types, types=types,
resolve_type=_resolve_type, ) resolve_type=_resolve_type,
)
def get_name(self, name): def get_name(self, name):
if self.auto_camelcase: if self.auto_camelcase:
@ -254,7 +278,8 @@ class TypeMap(GraphQLTypeMap):
field_type, field_type,
default_value=field.default_value, default_value=field.default_value,
out_name=name, out_name=name,
description=field.description) description=field.description,
)
else: else:
args = OrderedDict() args = OrderedDict()
for arg_name, arg in field.args.items(): for arg_name, arg in field.args.items():
@ -265,19 +290,17 @@ class TypeMap(GraphQLTypeMap):
arg_type, arg_type,
out_name=arg_name, out_name=arg_name,
description=arg.description, description=arg.description,
default_value=arg.default_value) default_value=arg.default_value,
)
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolver=field.get_resolver( resolver=field.get_resolver(
self.get_resolver_for_type( self.get_resolver_for_type(type, name, field.default_value)
type,
name,
field.default_value
)
), ),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description) description=field.description,
)
field_name = field.name or self.get_name(name) field_name = field.name or self.get_name(name)
fields[field_name] = _field fields[field_name] = _field
return fields return fields
@ -285,7 +308,7 @@ class TypeMap(GraphQLTypeMap):
def get_resolver_for_type(self, type, name, default_value): def get_resolver_for_type(self, type, name, default_value):
if not issubclass(type, ObjectType): if not issubclass(type, ObjectType):
return return
resolver = getattr(type, 'resolve_{}'.format(name), None) resolver = getattr(type, "resolve_{}".format(name), None)
if not resolver: if not resolver:
# If we don't find the resolver in the ObjectType class, then try to # If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces # find it in each of the interfaces
@ -293,8 +316,7 @@ class TypeMap(GraphQLTypeMap):
for interface in type._meta.interfaces: for interface in type._meta.interfaces:
if name not in interface._meta.fields: if name not in interface._meta.fields:
continue continue
interface_resolver = getattr(interface, interface_resolver = getattr(interface, "resolve_{}".format(name), None)
'resolve_{}'.format(name), None)
if interface_resolver: if interface_resolver:
break break
resolver = interface_resolver resolver = interface_resolver
@ -303,8 +325,7 @@ class TypeMap(GraphQLTypeMap):
if resolver: if resolver:
return get_unbound_function(resolver) return get_unbound_function(resolver)
default_resolver = type._meta.default_resolver or get_default_resolver( default_resolver = type._meta.default_resolver or get_default_resolver()
)
return partial(default_resolver, name, default_value) return partial(default_resolver, name, default_value)
def get_field_type(self, map, type): def get_field_type(self, map, type):

View File

@ -13,19 +13,19 @@ class UnionOptions(BaseOptions):
class Union(UnmountedType, BaseType): class Union(UnmountedType, BaseType):
''' """
Union Type Definition Union Type Definition
When a field can return one of a heterogeneous set of types, a Union type When a field can return one of a heterogeneous set of types, a Union type
is used to describe what types are possible as well as providing a function is used to describe what types are possible as well as providing a function
to determine which type is actually used when the field is resolved. to determine which type is actually used when the field is resolved.
''' """
@classmethod @classmethod
def __init_subclass_with_meta__(cls, types=None, **options): def __init_subclass_with_meta__(cls, types=None, **options):
assert ( assert (
isinstance(types, (list, tuple)) and isinstance(types, (list, tuple)) and len(types) > 0
len(types) > 0 ), "Must provide types for Union {name}.".format(name=cls.__name__)
), 'Must provide types for Union {name}.'.format(name=cls.__name__)
_meta = UnionOptions(cls) _meta = UnionOptions(cls)
_meta.types = types _meta.types = types
@ -33,14 +33,15 @@ class Union(UnmountedType, BaseType):
@classmethod @classmethod
def get_type(cls): def get_type(cls):
''' """
This function is called when the unmounted type (Union instance) This function is called when the unmounted type (Union instance)
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
return cls return cls
@classmethod @classmethod
def resolve_type(cls, instance, info): def resolve_type(cls, instance, info):
from .objecttype import ObjectType # NOQA from .objecttype import ObjectType # NOQA
if isinstance(instance, ObjectType): if isinstance(instance, ObjectType):
return type(instance) return type(instance)

View File

@ -2,7 +2,7 @@ from ..utils.orderedtype import OrderedType
class UnmountedType(OrderedType): class UnmountedType(OrderedType):
''' """
This class acts a proxy for a Graphene Type, so it can be mounted This class acts a proxy for a Graphene Type, so it can be mounted
dynamically as Field, InputField or Argument. dynamically as Field, InputField or Argument.
@ -13,7 +13,7 @@ class UnmountedType(OrderedType):
It let you write It let you write
>>> class MyObjectType(ObjectType): >>> class MyObjectType(ObjectType):
>>> my_field = String(description='Description here') >>> my_field = String(description='Description here')
''' """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(UnmountedType, self).__init__() super(UnmountedType, self).__init__()
@ -21,42 +21,43 @@ class UnmountedType(OrderedType):
self.kwargs = kwargs self.kwargs = kwargs
def get_type(self): def get_type(self):
''' """
This function is called when the UnmountedType instance This function is called when the UnmountedType instance
is mounted (as a Field, InputField or Argument) is mounted (as a Field, InputField or Argument)
''' """
raise NotImplementedError("get_type not implemented in {}".format(self)) raise NotImplementedError("get_type not implemented in {}".format(self))
def mount_as(self, _as): def mount_as(self, _as):
return _as.mounted(self) return _as.mounted(self)
def Field(self): # noqa: N802 def Field(self): # noqa: N802
''' """
Mount the UnmountedType as Field Mount the UnmountedType as Field
''' """
from .field import Field from .field import Field
return self.mount_as(Field) return self.mount_as(Field)
def InputField(self): # noqa: N802 def InputField(self): # noqa: N802
''' """
Mount the UnmountedType as InputField Mount the UnmountedType as InputField
''' """
from .inputfield import InputField from .inputfield import InputField
return self.mount_as(InputField) return self.mount_as(InputField)
def Argument(self): # noqa: N802 def Argument(self): # noqa: N802
''' """
Mount the UnmountedType as Argument Mount the UnmountedType as Argument
''' """
from .argument import Argument from .argument import Argument
return self.mount_as(Argument) return self.mount_as(Argument)
def __eq__(self, other): def __eq__(self, other):
return ( return self is other or (
self is other or ( isinstance(other, UnmountedType)
isinstance(other, UnmountedType) and and self.get_type() == other.get_type()
self.get_type() == other.get_type() and and self.args == other.args
self.args == other.args and and self.kwargs == other.kwargs
self.kwargs == other.kwargs
)
) )

View File

@ -10,9 +10,9 @@ from .unmountedtype import UnmountedType
def get_field_as(value, _as=None): def get_field_as(value, _as=None):
''' """
Get type mounted Get type mounted
''' """
if isinstance(value, MountedType): if isinstance(value, MountedType):
return value return value
elif isinstance(value, UnmountedType): elif isinstance(value, UnmountedType):
@ -22,10 +22,10 @@ def get_field_as(value, _as=None):
def yank_fields_from_attrs(attrs, _as=None, sort=True): def yank_fields_from_attrs(attrs, _as=None, sort=True):
''' """
Extract all the fields in given attributes (dict) Extract all the fields in given attributes (dict)
and return them ordered and return them ordered
''' """
fields_with_names = [] fields_with_names = []
for attname, value in list(attrs.items()): for attname, value in list(attrs.items()):
field = get_field_as(value, _as) field = get_field_as(value, _as)

View File

@ -8,13 +8,15 @@ from .scalars import Scalar
class UUID(Scalar): class UUID(Scalar):
'''UUID''' """UUID"""
@staticmethod @staticmethod
def serialize(uuid): def serialize(uuid):
if isinstance(uuid, str): if isinstance(uuid, str):
uuid = _UUID(uuid) uuid = _UUID(uuid)
assert isinstance(uuid, _UUID), "Expected UUID instance, received {}".format(uuid) assert isinstance(uuid, _UUID), "Expected UUID instance, received {}".format(
uuid
)
return str(uuid) return str(uuid)
@staticmethod @staticmethod

View File

@ -12,8 +12,10 @@ def annotate(_func=None, _trigger_warning=True, **annotations):
) )
if not _func: if not _func:
def _func(f): def _func(f):
return annotate(f, **annotations) return annotate(f, **annotations)
return _func return _func
func_signature = signature(_func) func_signature = signature(_func)
@ -22,12 +24,9 @@ def annotate(_func=None, _trigger_warning=True, **annotations):
for key, value in annotations.items(): for key, value in annotations.items():
assert key in func_signature.parameters, ( assert key in func_signature.parameters, (
'The key {key} is not a function parameter in the function "{func_name}".' 'The key {key} is not a function parameter in the function "{func_name}".'
).format( ).format(key=key, func_name=func_name(_func))
key=key,
func_name=func_name(_func)
)
func_annotations = getattr(_func, '__annotations__', None) func_annotations = getattr(_func, "__annotations__", None)
if func_annotations is None: if func_annotations is None:
_func.__annotations__ = annotations _func.__annotations__ = annotations
else: else:

35
graphene/utils/crunch.py Normal file
View File

@ -0,0 +1,35 @@
import json
from collections import Mapping
def to_key(value):
return json.dumps(value)
def insert(value, index, values):
key = to_key(value)
if key not in index:
index[key] = len(values)
values.append(value)
return len(values) - 1
return index.get(key)
def flatten(data, index, values):
if isinstance(data, (list, tuple)):
flattened = [flatten(child, index, values) for child in data]
elif isinstance(data, Mapping):
flattened = {key: flatten(child, index, values) for key, child in data.items()}
else:
flattened = data
return insert(flattened, index, values)
def crunch(data):
index = {}
values = []
flatten(data, index, values)
return values

View File

@ -0,0 +1,33 @@
from collections import Mapping, OrderedDict
def deflate(node, index=None, path=None):
if index is None:
index = {}
if path is None:
path = []
if node and "id" in node and "__typename" in node:
route = ",".join(path)
cache_key = ":".join([route, str(node["__typename"]), str(node["id"])])
if index.get(cache_key) is True:
return {"__typename": node["__typename"], "id": node["id"]}
else:
index[cache_key] = True
field_names = node.keys()
result = OrderedDict()
for field_name in field_names:
value = node[field_name]
new_path = path + [field_name]
if isinstance(value, (list, tuple)):
result[field_name] = [deflate(child, index, new_path) for child in value]
elif isinstance(value, Mapping):
result[field_name] = deflate(value, index, new_path)
else:
result[field_name] = value
return result

View File

@ -2,15 +2,11 @@ import functools
import inspect import inspect
import warnings import warnings
string_types = (type(b''), type(u'')) string_types = (type(b""), type(u""))
def warn_deprecation(text): def warn_deprecation(text):
warnings.warn( warnings.warn(text, category=DeprecationWarning, stacklevel=2)
text,
category=DeprecationWarning,
stacklevel=2
)
def deprecated(reason): def deprecated(reason):
@ -39,9 +35,7 @@ def deprecated(reason):
@functools.wraps(func1) @functools.wraps(func1)
def new_func1(*args, **kwargs): def new_func1(*args, **kwargs):
warn_deprecation( warn_deprecation(fmt1.format(name=func1.__name__, reason=reason))
fmt1.format(name=func1.__name__, reason=reason),
)
return func1(*args, **kwargs) return func1(*args, **kwargs)
return new_func1 return new_func1
@ -67,9 +61,7 @@ def deprecated(reason):
@functools.wraps(func2) @functools.wraps(func2)
def new_func2(*args, **kwargs): def new_func2(*args, **kwargs):
warn_deprecation( warn_deprecation(fmt2.format(name=func2.__name__))
fmt2.format(name=func2.__name__),
)
return func2(*args, **kwargs) return func2(*args, **kwargs)
return new_func2 return new_func2

View File

@ -1,4 +1,4 @@
def get_unbound_function(func): def get_unbound_function(func):
if not getattr(func, '__self__', True): if not getattr(func, "__self__", True):
return func.__func__ return func.__func__
return func return func

View File

@ -11,7 +11,7 @@ def import_string(dotted_path, dotted_attributes=None):
attribute path. Raise ImportError if the import failed. attribute path. Raise ImportError if the import failed.
""" """
try: try:
module_path, class_name = dotted_path.rsplit('.', 1) module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError: except ValueError:
raise ImportError("%s doesn't look like a module path" % dotted_path) raise ImportError("%s doesn't look like a module path" % dotted_path)
@ -20,14 +20,15 @@ def import_string(dotted_path, dotted_attributes=None):
try: try:
result = getattr(module, class_name) result = getattr(module, class_name)
except AttributeError: except AttributeError:
raise ImportError('Module "%s" does not define a "%s" attribute/class' % ( raise ImportError(
module_path, class_name) 'Module "%s" does not define a "%s" attribute/class'
% (module_path, class_name)
) )
if not dotted_attributes: if not dotted_attributes:
return result return result
else: else:
attributes = dotted_attributes.split('.') attributes = dotted_attributes.split(".")
traveled_attributes = [] traveled_attributes = []
try: try:
for attribute in attributes: for attribute in attributes:
@ -35,9 +36,10 @@ def import_string(dotted_path, dotted_attributes=None):
result = getattr(result, attribute) result = getattr(result, attribute)
return result return result
except AttributeError: except AttributeError:
raise ImportError('Module "%s" does not define a "%s" attribute inside attribute/class "%s"' % ( raise ImportError(
module_path, '.'.join(traveled_attributes), class_name 'Module "%s" does not define a "%s" attribute inside attribute/class "%s"'
)) % (module_path, ".".join(traveled_attributes), class_name)
)
def lazy_import(dotted_path, dotted_attributes=None): def lazy_import(dotted_path, dotted_attributes=None):

View File

@ -3,7 +3,7 @@ from functools import wraps
from .deprecated import deprecated from .deprecated import deprecated
@deprecated('This function is deprecated') @deprecated("This function is deprecated")
def resolve_only_args(func): def resolve_only_args(func):
@wraps(func) @wraps(func)
def wrapped_func(root, info, **args): def wrapped_func(root, info, **args):

View File

@ -4,18 +4,18 @@ import re
# Adapted from this response in Stackoverflow # Adapted from this response in Stackoverflow
# http://stackoverflow.com/a/19053800/1072990 # http://stackoverflow.com/a/19053800/1072990
def to_camel_case(snake_str): def to_camel_case(snake_str):
components = snake_str.split('_') components = snake_str.split("_")
# We capitalize the first letter of each component except the first one # We capitalize the first letter of each component except the first one
# with the 'capitalize' method and join them together. # with the 'capitalize' method and join them together.
return components[0] + ''.join(x.capitalize() if x else '_' for x in components[1:]) return components[0] + "".join(x.capitalize() if x else "_" for x in components[1:])
# From this response in Stackoverflow # From this response in Stackoverflow
# http://stackoverflow.com/a/1176023/1072990 # http://stackoverflow.com/a/1176023/1072990
def to_snake_case(name): def to_snake_case(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def to_const(string): def to_const(string):
return re.sub('[\W|^]+', '_', string).upper() return re.sub("[\W|^]+", "_", string).upper()

View File

@ -20,6 +20,7 @@ class SubclassWithMeta_Meta(InitSubclassMeta):
class SubclassWithMeta(six.with_metaclass(SubclassWithMeta_Meta)): class SubclassWithMeta(six.with_metaclass(SubclassWithMeta_Meta)):
"""This class improves __init_subclass__ to receive automatically the options from meta""" """This class improves __init_subclass__ to receive automatically the options from meta"""
# We will only have the metaclass in Python 2 # We will only have the metaclass in Python 2
def __init_subclass__(cls, **meta_options): def __init_subclass__(cls, **meta_options):
"""This method just terminates the super() chain""" """This method just terminates the super() chain"""
@ -32,19 +33,22 @@ class SubclassWithMeta(six.with_metaclass(SubclassWithMeta_Meta)):
_meta_props = props(_Meta) _meta_props = props(_Meta)
else: else:
raise Exception( raise Exception(
"Meta have to be either a class or a dict. Received {}".format(_Meta)) "Meta have to be either a class or a dict. Received {}".format(
_Meta
)
)
delattr(cls, "Meta") delattr(cls, "Meta")
options = dict(meta_options, **_meta_props) options = dict(meta_options, **_meta_props)
abstract = options.pop('abstract', False) abstract = options.pop("abstract", False)
if abstract: if abstract:
assert not options, ( assert not options, (
"Abstract types can only contain the abstract attribute. " "Abstract types can only contain the abstract attribute. "
"Received: abstract, {option_keys}" "Received: abstract, {option_keys}"
).format(option_keys=', '.join(options.keys())) ).format(option_keys=", ".join(options.keys()))
else: else:
super_class = super(cls, cls) super_class = super(cls, cls)
if hasattr(super_class, '__init_subclass_with_meta__'): if hasattr(super_class, "__init_subclass_with_meta__"):
super_class.__init_subclass_with_meta__(**options) super_class.__init_subclass_with_meta__(**options)
@classmethod @classmethod

Some files were not shown because too many files have changed in this diff Show More