Change dependency to graphql-core-next (#988)

* Changed dependencies to core-next

* Converted Scalars

* ResolveInfo name change
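
A minimal sketch of what the rename means for user code, assuming graphene keeps re-exporting the type under its old name (as the graphene/types/__init__.py hunk further down does): resolvers now receive a graphql.GraphQLResolveInfo.

from graphene import ObjectType, String
from graphql import GraphQLResolveInfo  # the class formerly exposed as graphql.ResolveInfo

class Query(ObjectType):
    hello = String()

    def resolve_hello(parent, info: GraphQLResolveInfo):
        # graphene/types/__init__.py re-exports this class as ResolveInfo,
        # so existing imports of graphene's ResolveInfo keep working
        return "world"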

* Ignore .venv

* Make Schema compatible with GraphQL-core-next
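
A rough sketch of the resulting compatibility surface, based on the schema and test hunks below: graphene.Schema wraps a graphql-core GraphQLSchema (reachable as schema.graphql_schema) and offers execute_async alongside execute. Query and resolver names here are illustrative.

import asyncio

from graphene import ObjectType, Schema, String

class Query(ObjectType):
    hello = String()

    def resolve_hello(parent, info):
        return "world"

schema = Schema(query=Query)
graphql_schema = schema.graphql_schema             # the wrapped graphql-core GraphQLSchema
print(schema.execute("{ hello }").data)            # {'hello': 'world'}
print(asyncio.run(schema.execute_async("{ hello }")).data)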

* Ignore more venv names and mypy and pytest caches

* Remove print statements for debugging in schema test

* core-next now provides out_type and out_name

* Adapt date and time scalar types to core-next
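
A sketch of the adapted parse pattern (the ISODate scalar is made up for illustration; the real changes are in the datetime.py hunk further down): literal AST nodes are now *ValueNode classes, and parse failures signal graphql's INVALID sentinel instead of returning None.

from aniso8601 import parse_date
from graphql.error import INVALID
from graphql.language import StringValueNode
from graphene import Scalar

class ISODate(Scalar):
    """Hypothetical date scalar following the core-next conventions."""

    @staticmethod
    def serialize(value):
        return value.isoformat()

    @classmethod
    def parse_literal(cls, node):
        if isinstance(node, StringValueNode):  # was: ast.StringValue
            return cls.parse_value(node.value)

    @staticmethod
    def parse_value(value):
        try:
            return parse_date(value)
        except ValueError:
            return INVALID  # was: return None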

* Ignore the non-standard result.invalid flag

* Results are named tuples in core-next (immutable)
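
Together with the previous point about result.invalid, this means response handling builds a fresh dict rather than mutating the result; a condensed sketch of the pattern used in graphene/test/__init__.py below:

def format_execution_result(execution_result, format_error):
    # ExecutionResult is an immutable named tuple in core-next, so collect
    # the pieces into a new dict instead of assigning to its attributes.
    response = {}
    if execution_result.errors:
        response["errors"] = [format_error(e) for e in execution_result.errors]
    response["data"] = execution_result.data  # no result.invalid check any more
    return response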

* Enum values are returned as dict in core-next
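
A small illustration, assuming the lookup goes through the schema's type map (the Episode enum is made up): core-next exposes GraphQLEnumType.values as a name-to-value dict rather than a list.

import graphene

class Episode(graphene.Enum):
    NEWHOPE = 4
    EMPIRE = 5
    JEDI = 6

class Query(graphene.ObjectType):
    episode = graphene.Field(Episode)

schema = graphene.Schema(query=Query)
enum_type = schema.graphql_schema.type_map["Episode"]
for name, enum_value in enum_type.values.items():  # dict in core-next, list before
    print(name, enum_value.value)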

* Fix mutation tests with promises
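
The promise-based mutations in the tests are replaced with coroutines; a trimmed sketch of the pattern from the relay mutation test hunk below (class and field names follow that hunk):

from graphene import ObjectType, Schema, String
from graphene.relay import ClientIDMutation

class SaySomethingAsync(ClientIDMutation):
    class Input:
        what = String()

    phrase = String()

    @staticmethod
    async def mutate_and_get_payload(root, info, what, client_mutation_id=None):
        # previously: return Promise.resolve(SaySomething(phrase=str(what)))
        return SaySomethingAsync(phrase=str(what))

class Query(ObjectType):
    ok = String()

class Mutation(ObjectType):
    say_async = SaySomethingAsync.Field()

schema = Schema(query=Query, mutation=Mutation)

# executed in an asyncio test via:
#   result = await schema.execute_async(
#       'mutation { sayAsync(input: {what: "hello"}) { phrase } }'
#   )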

* Make all 345 tests pass with graphql-core-next

* Remove the compat module, which was only needed for older Python versions

* Remove object as base class (not needed in Py 3)

* We can assume that dicts are ordered in Py 3.6+

* Make use of the fact that dicts are iterable
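
A compact sketch of these Python 3 clean-ups as applied across the codebase (names are illustrative): no explicit object base class, plain dicts instead of OrderedDict (insertion order is preserved in CPython 3.6 and guaranteed from 3.7), and iterating a dict directly instead of calling .keys().

class BaseOptions:                              # was: class BaseOptions(object):
    name = None

fields = {"page_info": None, "edges": None}     # was: OrderedDict([...])
assert list(fields) == ["page_info", "edges"]   # was: list(fields.keys())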

* Use consistent style of importing from pytest

* Restore compatibility with graphql-relay-py v3

Add adapters for the PageInfo and Connection args.
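
A self-contained sketch of the adapter pattern from the connection.py hunk below (the resolve_connection wrapper is only for illustration): graphql-relay v3's connection_from_array passes camelCase keyword arguments and expects factory callables, which the adapters translate into graphene's snake_case constructors.

from functools import partial

from graphene.relay import PageInfo
from graphql_relay import connection_from_array

def page_info_adapter(startCursor, endCursor, hasPreviousPage, hasNextPage):
    """Adapter for creating PageInfo instances from graphql-relay kwargs."""
    return PageInfo(
        start_cursor=startCursor,
        end_cursor=endCursor,
        has_previous_page=hasPreviousPage,
        has_next_page=hasNextPage,
    )

def connection_adapter(cls, edges, pageInfo):
    """Adapter for creating Connection instances of the requested type."""
    return cls(edges=edges, page_info=pageInfo)

def resolve_connection(connection_type, args, resolved):
    # hypothetical helper: hand the resolved iterable to graphql-relay v3
    return connection_from_array(
        resolved,
        args,
        connection_type=partial(connection_adapter, connection_type),
        edge_type=connection_type.Edge,
        page_info_type=page_info_adapter,
    )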

* Avoid various deprecation warnings

* Use graphql-core 3 instead of graphql-core-next

* Update dependencies, reformat changes with black

* Update graphene/relay/connection.py

Co-Authored-By: Jonathan Kim <jkimbo@gmail.com>

* Run black on setup.py

* Remove trailing whitespace
Eran Kampf 2019-08-12 11:04:02 -07:00 committed by Mel van Londen
parent dc812fe028
commit 7ef3c8ee3e
68 changed files with 1092 additions and 1897 deletions

.gitignore
View File

@ -10,9 +10,6 @@ __pycache__/
# Distribution / packaging
.Python
env/
venv/
.venv/
build/
develop-eggs/
dist/
@ -47,7 +44,8 @@ htmlcov/
.pytest_cache
nosetests.xml
coverage.xml
*,cover
*.cover
.pytest_cache/
# Translations
*.mo
@ -62,6 +60,14 @@ docs/_build/
# PyBuilder
target/
# VirtualEnv
.env
.venv
env/
venv/
# Typing
.mypy_cache/
/tests/django.sqlite

View File

@ -23,6 +23,6 @@ repos:
- id: black
language_version: python3
- repo: https://github.com/PyCQA/flake8
rev: 3.7.7
rev: 3.7.8
hooks:
- id: flake8

View File

@ -30,21 +30,22 @@ snapshots["test_correctly_refetches_xwing 1"] = {
snapshots[
"test_str_schema 1"
] = """schema {
query: Query
mutation: Mutation
}
] = '''"""A faction in the Star Wars saga"""
type Faction implements Node {
"""The ID of the object"""
id: ID!
"""The name of the faction."""
name: String
ships(before: String, after: String, first: Int, last: Int): ShipConnection
"""The ships used by the faction."""
ships(before: String = null, after: String = null, first: Int = null, last: Int = null): ShipConnection
}
input IntroduceShipInput {
shipName: String!
factionId: String!
clientMutationId: String
clientMutationId: String = null
}
type IntroduceShipPayload {
@ -57,35 +58,60 @@ type Mutation {
introduceShip(input: IntroduceShipInput!): IntroduceShipPayload
}
"""An object with an ID"""
interface Node {
"""The ID of the object"""
id: ID!
}
"""
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
"""
type PageInfo {
"""When paginating forwards, are there more items?"""
hasNextPage: Boolean!
"""When paginating backwards, are there more items?"""
hasPreviousPage: Boolean!
"""When paginating backwards, the cursor to continue."""
startCursor: String
"""When paginating forwards, the cursor to continue."""
endCursor: String
}
type Query {
rebels: Faction
empire: Faction
"""The ID of the object"""
node(id: ID!): Node
}
"""A ship in the Star Wars saga"""
type Ship implements Node {
"""The ID of the object"""
id: ID!
"""The name of the ship."""
name: String
}
type ShipConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo!
"""Contains the nodes in this connection."""
edges: [ShipEdge]!
}
"""A Relay edge containing a `Ship` and its cursor."""
type ShipEdge {
"""The item at the end of the edge"""
node: Ship
"""A cursor for use in pagination"""
cursor: String!
}
"""
'''

View File

@ -1,8 +0,0 @@
from __future__ import absolute_import
from graphql.pyutils.compat import Enum
try:
from inspect import signature
except ImportError:
from .signature import signature

View File

@ -1,850 +0,0 @@
# Copyright 2001-2013 Python Software Foundation; All Rights Reserved
"""Function signature objects for callables
Back port of Python 3.3's function signature tools from the inspect module,
modified to be compatible with Python 2.7 and 3.2+.
"""
from __future__ import absolute_import, division, print_function
import functools
import itertools
import re
import types
from collections import OrderedDict
__version__ = "0.4"
__all__ = ["BoundArguments", "Parameter", "Signature", "signature"]
_WrapperDescriptor = type(type.__call__)
_MethodWrapper = type(all.__call__)
_NonUserDefinedCallables = (
_WrapperDescriptor,
_MethodWrapper,
types.BuiltinFunctionType,
)
def formatannotation(annotation, base_module=None):
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", "__builtin__", base_module):
return annotation.__name__
return annotation.__module__ + "." + annotation.__name__
return repr(annotation)
def _get_user_defined_method(cls, method_name, *nested):
try:
if cls is type:
return
meth = getattr(cls, method_name)
for name in nested:
meth = getattr(meth, name, meth)
except AttributeError:
return
else:
if not isinstance(meth, _NonUserDefinedCallables):
# Once '__signature__' will be added to 'C'-level
# callables, this check won't be necessary
return meth
def signature(obj):
"""Get a signature object for the passed callable."""
if not callable(obj):
raise TypeError("{!r} is not a callable object".format(obj))
if isinstance(obj, types.MethodType):
sig = signature(obj.__func__)
if obj.__self__ is None:
# Unbound method: the first parameter becomes positional-only
if sig.parameters:
first = sig.parameters.values()[0].replace(kind=_POSITIONAL_ONLY)
return sig.replace(
parameters=(first,) + tuple(sig.parameters.values())[1:]
)
else:
return sig
else:
# In this case we skip the first parameter of the underlying
# function (usually `self` or `cls`).
return sig.replace(parameters=tuple(sig.parameters.values())[1:])
try:
sig = obj.__signature__
except AttributeError:
pass
else:
if sig is not None:
return sig
try:
# Was this function wrapped by a decorator?
wrapped = obj.__wrapped__
except AttributeError:
pass
else:
return signature(wrapped)
if isinstance(obj, types.FunctionType):
return Signature.from_function(obj)
if isinstance(obj, functools.partial):
sig = signature(obj.func)
new_params = OrderedDict(sig.parameters.items())
partial_args = obj.args or ()
partial_keywords = obj.keywords or {}
try:
ba = sig.bind_partial(*partial_args, **partial_keywords)
except TypeError as ex:
msg = "partial object {!r} has incorrect arguments".format(obj)
raise ValueError(msg)
for arg_name, arg_value in ba.arguments.items():
param = new_params[arg_name]
if arg_name in partial_keywords:
# We set a new default value, because the following code
# is correct:
#
# >>> def foo(a): print(a)
# >>> print(partial(partial(foo, a=10), a=20)())
# 20
# >>> print(partial(partial(foo, a=10), a=20)(a=30))
# 30
#
# So, with 'partial' objects, passing a keyword argument is
# like setting a new default value for the corresponding
# parameter
#
# We also mark this parameter with '_partial_kwarg'
# flag. Later, in '_bind', the 'default' value of this
# parameter will be added to 'kwargs', to simulate
# the 'functools.partial' real call.
new_params[arg_name] = param.replace(
default=arg_value, _partial_kwarg=True
)
elif (
param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL)
and not param._partial_kwarg
):
new_params.pop(arg_name)
return sig.replace(parameters=new_params.values())
sig = None
if isinstance(obj, type):
# obj is a class or a metaclass
# First, let's see if it has an overloaded __call__ defined
# in its metaclass
call = _get_user_defined_method(type(obj), "__call__")
if call is not None:
sig = signature(call)
else:
# Now we check if the 'obj' class has a '__new__' method
new = _get_user_defined_method(obj, "__new__")
if new is not None:
sig = signature(new)
else:
# Finally, we should have at least __init__ implemented
init = _get_user_defined_method(obj, "__init__")
if init is not None:
sig = signature(init)
elif not isinstance(obj, _NonUserDefinedCallables):
# An object with __call__
# We also check that the 'obj' is not an instance of
# _WrapperDescriptor or _MethodWrapper to avoid
# infinite recursion (and even potential segfault)
call = _get_user_defined_method(type(obj), "__call__", "im_func")
if call is not None:
sig = signature(call)
if sig is not None:
# For classes and objects we skip the first parameter of their
# __call__, __new__, or __init__ methods
return sig.replace(parameters=tuple(sig.parameters.values())[1:])
if isinstance(obj, types.BuiltinFunctionType):
# Raise a nicer error message for builtins
msg = "no signature found for builtin function {!r}".format(obj)
raise ValueError(msg)
raise ValueError("callable {!r} is not supported by signature".format(obj))
class _void(object):
"""A private marker - used in Parameter & Signature"""
class _empty(object):
pass
class _ParameterKind(int):
def __new__(self, *args, **kwargs):
obj = int.__new__(self, *args)
obj._name = kwargs["name"]
return obj
def __str__(self):
return self._name
def __repr__(self):
return "<_ParameterKind: {!r}>".format(self._name)
_POSITIONAL_ONLY = _ParameterKind(0, name="POSITIONAL_ONLY")
_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name="POSITIONAL_OR_KEYWORD")
_VAR_POSITIONAL = _ParameterKind(2, name="VAR_POSITIONAL")
_KEYWORD_ONLY = _ParameterKind(3, name="KEYWORD_ONLY")
_VAR_KEYWORD = _ParameterKind(4, name="VAR_KEYWORD")
class Parameter(object):
"""Represents a parameter in a function signature.
Has the following public attributes:
* name : str
The name of the parameter as a string.
* default : object
The default value for the parameter if specified. If the
parameter has no default value, this attribute is not set.
* annotation
The annotation for the parameter if specified. If the
parameter has no annotation, this attribute is not set.
* kind : str
Describes how argument values are bound to the parameter.
Possible values: `Parameter.POSITIONAL_ONLY`,
`Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`,
`Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
"""
__slots__ = ("_name", "_kind", "_default", "_annotation", "_partial_kwarg")
POSITIONAL_ONLY = _POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
VAR_POSITIONAL = _VAR_POSITIONAL
KEYWORD_ONLY = _KEYWORD_ONLY
VAR_KEYWORD = _VAR_KEYWORD
empty = _empty
def __init__(
self, name, kind, default=_empty, annotation=_empty, _partial_kwarg=False
):
if kind not in (
_POSITIONAL_ONLY,
_POSITIONAL_OR_KEYWORD,
_VAR_POSITIONAL,
_KEYWORD_ONLY,
_VAR_KEYWORD,
):
raise ValueError("invalid value for 'Parameter.kind' attribute")
self._kind = kind
if default is not _empty:
if kind in (_VAR_POSITIONAL, _VAR_KEYWORD):
msg = "{} parameters cannot have default values".format(kind)
raise ValueError(msg)
self._default = default
self._annotation = annotation
if name is None:
if kind != _POSITIONAL_ONLY:
raise ValueError(
"None is not a valid name for a " "non-positional-only parameter"
)
self._name = name
else:
name = str(name)
if kind != _POSITIONAL_ONLY and not re.match(r"[a-z_]\w*$", name, re.I):
msg = "{!r} is not a valid parameter name".format(name)
raise ValueError(msg)
self._name = name
self._partial_kwarg = _partial_kwarg
@property
def name(self):
return self._name
@property
def default(self):
return self._default
@property
def annotation(self):
return self._annotation
@property
def kind(self):
return self._kind
def replace(
self,
name=_void,
kind=_void,
annotation=_void,
default=_void,
_partial_kwarg=_void,
):
"""Creates a customized copy of the Parameter."""
if name is _void:
name = self._name
if kind is _void:
kind = self._kind
if annotation is _void:
annotation = self._annotation
if default is _void:
default = self._default
if _partial_kwarg is _void:
_partial_kwarg = self._partial_kwarg
return type(self)(
name,
kind,
default=default,
annotation=annotation,
_partial_kwarg=_partial_kwarg,
)
def __str__(self):
kind = self.kind
formatted = self._name
if kind == _POSITIONAL_ONLY:
if formatted is None:
formatted = ""
formatted = "<{}>".format(formatted)
# Add annotation and default value
if self._annotation is not _empty:
formatted = "{}:{}".format(formatted, formatannotation(self._annotation))
if self._default is not _empty:
formatted = "{}={}".format(formatted, repr(self._default))
if kind == _VAR_POSITIONAL:
formatted = "*" + formatted
elif kind == _VAR_KEYWORD:
formatted = "**" + formatted
return formatted
def __repr__(self):
return "<{} at {:#x} {!r}>".format(self.__class__.__name__, id(self), self.name)
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
return (
issubclass(other.__class__, Parameter)
and self._name == other._name
and self._kind == other._kind
and self._default == other._default
and self._annotation == other._annotation
)
def __ne__(self, other):
return not self.__eq__(other)
class BoundArguments(object):
"""Result of `Signature.bind` call. Holds the mapping of arguments
to the function's parameters.
Has the following public attributes:
* arguments : OrderedDict
An ordered mutable mapping of parameters' names to arguments' values.
Does not contain arguments' default values.
* signature : Signature
The Signature object that created this instance.
* args : tuple
Tuple of positional arguments values.
* kwargs : dict
Dict of keyword arguments values.
"""
def __init__(self, signature, arguments):
self.arguments = arguments
self._signature = signature
@property
def signature(self):
return self._signature
@property
def args(self):
args = []
for param_name, param in self._signature.parameters.items():
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
# Keyword arguments mapped by 'functools.partial'
# (Parameter._partial_kwarg is True) are mapped
# in 'BoundArguments.kwargs', along with VAR_KEYWORD &
# KEYWORD_ONLY
break
try:
arg = self.arguments[param_name]
except KeyError:
# We're done here. Other arguments
# will be mapped in 'BoundArguments.kwargs'
break
else:
if param.kind == _VAR_POSITIONAL:
# *args
args.extend(arg)
else:
# plain argument
args.append(arg)
return tuple(args)
@property
def kwargs(self):
kwargs = {}
kwargs_started = False
for param_name, param in self._signature.parameters.items():
if not kwargs_started:
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
kwargs_started = True
else:
if param_name not in self.arguments:
kwargs_started = True
continue
if not kwargs_started:
continue
try:
arg = self.arguments[param_name]
except KeyError:
pass
else:
if param.kind == _VAR_KEYWORD:
# **kwargs
kwargs.update(arg)
else:
# plain keyword argument
kwargs[param_name] = arg
return kwargs
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
return (
issubclass(other.__class__, BoundArguments)
and self.signature == other.signature
and self.arguments == other.arguments
)
def __ne__(self, other):
return not self.__eq__(other)
class Signature(object):
"""A Signature object represents the overall signature of a function.
It stores a Parameter object for each parameter accepted by the
function, as well as information specific to the function itself.
A Signature object has the following public attributes and methods:
* parameters : OrderedDict
An ordered mapping of parameters' names to the corresponding
Parameter objects (keyword-only arguments are in the same order
as listed in `code.co_varnames`).
* return_annotation : object
The annotation for the return type of the function if specified.
If the function has no annotation for its return type, this
attribute is not set.
* bind(*args, **kwargs) -> BoundArguments
Creates a mapping from positional and keyword arguments to
parameters.
* bind_partial(*args, **kwargs) -> BoundArguments
Creates a partial mapping from positional and keyword arguments
to parameters (simulating 'functools.partial' behavior.)
"""
__slots__ = ("_return_annotation", "_parameters")
_parameter_cls = Parameter
_bound_arguments_cls = BoundArguments
empty = _empty
def __init__(
self, parameters=None, return_annotation=_empty, __validate_parameters__=True
):
"""Constructs Signature from the given list of Parameter
objects and 'return_annotation'. All arguments are optional.
"""
if parameters is None:
params = OrderedDict()
else:
if __validate_parameters__:
params = OrderedDict()
top_kind = _POSITIONAL_ONLY
for idx, param in enumerate(parameters):
kind = param.kind
if kind < top_kind:
msg = "wrong parameter order: {0} before {1}"
msg = msg.format(top_kind, param.kind)
raise ValueError(msg)
else:
top_kind = kind
name = param.name
if name is None:
name = str(idx)
param = param.replace(name=name)
if name in params:
msg = "duplicate parameter name: {!r}".format(name)
raise ValueError(msg)
params[name] = param
else:
params = OrderedDict(((param.name, param) for param in parameters))
self._parameters = params
self._return_annotation = return_annotation
@classmethod
def from_function(cls, func):
"""Constructs Signature for the given python function"""
if not isinstance(func, types.FunctionType):
raise TypeError("{!r} is not a Python function".format(func))
Parameter = cls._parameter_cls
# Parameter information.
func_code = func.__code__
pos_count = func_code.co_argcount
arg_names = func_code.co_varnames
positional = tuple(arg_names[:pos_count])
keyword_only_count = getattr(func_code, "co_kwonlyargcount", 0)
keyword_only = arg_names[pos_count : (pos_count + keyword_only_count)]
annotations = getattr(func, "__annotations__", {})
defaults = func.__defaults__
kwdefaults = getattr(func, "__kwdefaults__", None)
if defaults:
pos_default_count = len(defaults)
else:
pos_default_count = 0
parameters = []
# Non-keyword-only parameters w/o defaults.
non_default_count = pos_count - pos_default_count
for name in positional[:non_default_count]:
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(name, annotation=annotation, kind=_POSITIONAL_OR_KEYWORD)
)
# ... w/ defaults.
for offset, name in enumerate(positional[non_default_count:]):
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(
name,
annotation=annotation,
kind=_POSITIONAL_OR_KEYWORD,
default=defaults[offset],
)
)
# *args
if func_code.co_flags & 0x04:
name = arg_names[pos_count + keyword_only_count]
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(name, annotation=annotation, kind=_VAR_POSITIONAL)
)
# Keyword-only parameters.
for name in keyword_only:
default = _empty
if kwdefaults is not None:
default = kwdefaults.get(name, _empty)
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(
name, annotation=annotation, kind=_KEYWORD_ONLY, default=default
)
)
# **kwargs
if func_code.co_flags & 0x08:
index = pos_count + keyword_only_count
if func_code.co_flags & 0x04:
index += 1
name = arg_names[index]
annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, kind=_VAR_KEYWORD))
return cls(
parameters,
return_annotation=annotations.get("return", _empty),
__validate_parameters__=False,
)
@property
def parameters(self):
try:
return types.MappingProxyType(self._parameters)
except AttributeError:
return OrderedDict(self._parameters.items())
@property
def return_annotation(self):
return self._return_annotation
def replace(self, parameters=_void, return_annotation=_void):
"""Creates a customized copy of the Signature.
Pass 'parameters' and/or 'return_annotation' arguments
to override them in the new copy.
"""
if parameters is _void:
parameters = self.parameters.values()
if return_annotation is _void:
return_annotation = self._return_annotation
return type(self)(parameters, return_annotation=return_annotation)
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
if (
not issubclass(type(other), Signature)
or self.return_annotation != other.return_annotation
or len(self.parameters) != len(other.parameters)
):
return False
other_positions = {
param: idx for idx, param in enumerate(other.parameters.keys())
}
for idx, (param_name, param) in enumerate(self.parameters.items()):
if param.kind == _KEYWORD_ONLY:
try:
other_param = other.parameters[param_name]
except KeyError:
return False
else:
if param != other_param:
return False
else:
try:
other_idx = other_positions[param_name]
except KeyError:
return False
else:
if idx != other_idx or param != other.parameters[param_name]:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def _bind(self, args, kwargs, partial=False):
"""Private method. Don't use directly."""
arguments = OrderedDict()
parameters = iter(self.parameters.values())
parameters_ex = ()
arg_vals = iter(args)
if partial:
# Support for binding arguments to 'functools.partial' objects.
# See 'functools.partial' case in 'signature()' implementation
# for details.
for param_name, param in self.parameters.items():
if param._partial_kwarg and param_name not in kwargs:
# Simulating 'functools.partial' behavior
kwargs[param_name] = param.default
while True:
# Let's iterate through the positional arguments and corresponding
# parameters
try:
arg_val = next(arg_vals)
except StopIteration:
# No more positional arguments
try:
param = next(parameters)
except StopIteration:
# No more parameters. That's it. Just need to check that
# we have no `kwargs` after this while loop
break
else:
if param.kind == _VAR_POSITIONAL:
# That's OK, just empty *args. Let's start parsing
# kwargs
break
elif param.name in kwargs:
if param.kind == _POSITIONAL_ONLY:
msg = (
"{arg!r} parameter is positional only, "
"but was passed as a keyword"
)
msg = msg.format(arg=param.name)
raise TypeError(msg)
parameters_ex = (param,)
break
elif param.kind == _VAR_KEYWORD or param.default is not _empty:
# That's fine too - we have a default value for this
# parameter. So, lets start parsing `kwargs`, starting
# with the current parameter
parameters_ex = (param,)
break
else:
if partial:
parameters_ex = (param,)
break
else:
msg = "{arg!r} parameter lacking default value"
msg = msg.format(arg=param.name)
raise TypeError(msg)
else:
# We have a positional argument to process
try:
param = next(parameters)
except StopIteration:
raise TypeError("too many positional arguments")
else:
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
# Looks like we have no parameter for this positional
# argument
raise TypeError("too many positional arguments")
if param.kind == _VAR_POSITIONAL:
# We have an '*args'-like argument, let's fill it with
# all positional arguments we have left and move on to
# the next phase
values = [arg_val]
values.extend(arg_vals)
arguments[param.name] = tuple(values)
break
if param.name in kwargs:
raise TypeError(
"multiple values for argument "
"{arg!r}".format(arg=param.name)
)
arguments[param.name] = arg_val
# Now, we iterate through the remaining parameters to process
# keyword arguments
kwargs_param = None
for param in itertools.chain(parameters_ex, parameters):
if param.kind == _POSITIONAL_ONLY:
# This should never happen in case of a properly built
# Signature object (but let's have this check here
# to ensure correct behaviour just in case)
raise TypeError(
"{arg!r} parameter is positional only, "
"but was passed as a keyword".format(arg=param.name)
)
if param.kind == _VAR_KEYWORD:
# Memorize that we have a '**kwargs'-like parameter
kwargs_param = param
continue
param_name = param.name
try:
arg_val = kwargs.pop(param_name)
except KeyError:
# We have no value for this parameter. It's fine though,
# if it has a default value, or it is an '*args'-like
# parameter, left alone by the processing of positional
# arguments.
if (
not partial
and param.kind != _VAR_POSITIONAL
and param.default is _empty
):
raise TypeError(
"{arg!r} parameter lacking default value".format(arg=param_name)
)
else:
arguments[param_name] = arg_val
if kwargs:
if kwargs_param is not None:
# Process our '**kwargs'-like parameter
arguments[kwargs_param.name] = kwargs
else:
raise TypeError("too many keyword arguments")
return self._bound_arguments_cls(self, arguments)
def bind(self, *args, **kwargs):
"""Get a BoundArguments object, that maps the passed `args`
and `kwargs` to the function's signature. Raises `TypeError`
if the passed arguments can not be bound.
"""
return self._bind(args, kwargs)
def bind_partial(self, *args, **kwargs):
"""Get a BoundArguments object, that partially maps the
passed `args` and `kwargs` to the function's signature.
Raises `TypeError` if the passed arguments can not be bound.
"""
return self._bind(args, kwargs, partial=True)
def __str__(self):
result = []
render_kw_only_separator = True
for idx, param in enumerate(self.parameters.values()):
formatted = str(param)
kind = param.kind
if kind == _VAR_POSITIONAL:
# OK, we have an '*args'-like parameter, so we won't need
# a '*' to separate keyword-only arguments
render_kw_only_separator = False
elif kind == _KEYWORD_ONLY and render_kw_only_separator:
# We have a keyword-only parameter to render and we haven't
# rendered an '*args'-like parameter before, so add a '*'
# separator to the parameters list ("foo(arg1, *, arg2)" case)
result.append("*")
# This condition should be only triggered once, so
# reset the flag
render_kw_only_separator = False
result.append(formatted)
rendered = "({})".format(", ".join(result))
if self.return_annotation is not _empty:
anno = formatannotation(self.return_annotation)
rendered += " -> {}".format(anno)
return rendered

View File

@ -1,8 +1,8 @@
import re
from collections import Iterable, OrderedDict
from collections.abc import Iterable
from functools import partial
from graphql_relay import connection_from_list
from graphql_relay import connection_from_array
from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union
from ..types.field import Field
@ -41,6 +41,17 @@ class PageInfo(ObjectType):
)
# noinspection PyPep8Naming
def page_info_adapter(startCursor, endCursor, hasPreviousPage, hasNextPage):
"""Adapter for creating PageInfo instances"""
return PageInfo(
start_cursor=startCursor,
end_cursor=endCursor,
has_previous_page=hasPreviousPage,
has_next_page=hasNextPage,
)
class ConnectionOptions(ObjectTypeOptions):
node = None
@ -66,7 +77,7 @@ class Connection(ObjectType):
edge_class = getattr(cls, "Edge", None)
_node = node
class EdgeBase(object):
class EdgeBase:
node = Field(_node, description="The item at the end of the edge")
cursor = String(required=True, description="A cursor for use in pagination")
@ -86,31 +97,29 @@ class Connection(ObjectType):
options["name"] = name
_meta.node = node
_meta.fields = OrderedDict(
[
(
"page_info",
Field(
PageInfo,
name="pageInfo",
required=True,
description="Pagination data for this connection.",
),
),
(
"edges",
Field(
NonNull(List(edge)),
description="Contains the nodes in this connection.",
),
),
]
)
_meta.fields = {
"page_info": Field(
PageInfo,
name="pageInfo",
required=True,
description="Pagination data for this connection.",
),
"edges": Field(
NonNull(List(edge)),
description="Contains the nodes in this connection.",
),
}
return super(Connection, cls).__init_subclass_with_meta__(
_meta=_meta, **options
)
# noinspection PyPep8Naming
def connection_adapter(cls, edges, pageInfo):
"""Adapter for creating Connection instances"""
return cls(edges=edges, page_info=pageInfo)
class IterableConnectionField(Field):
def __init__(self, type, *args, **kwargs):
kwargs.setdefault("before", String())
@ -133,7 +142,7 @@ class IterableConnectionField(Field):
)
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".'
'{} type has to be a subclass of Connection. Received "{}".'
).format(self.__class__.__name__, connection_type)
return type
@ -143,15 +152,15 @@ class IterableConnectionField(Field):
return resolved
assert isinstance(resolved, Iterable), (
"Resolved value from the connection field have to be iterable or instance of {}. "
"Resolved value from the connection field has to be an iterable or instance of {}. "
'Received "{}"'
).format(connection_type, resolved)
connection = connection_from_list(
connection = connection_from_array(
resolved,
args,
connection_type=connection_type,
connection_type=partial(connection_adapter, connection_type),
edge_type=connection_type.Edge,
pageinfo_type=PageInfo,
page_info_type=page_info_adapter,
)
connection.iterable = resolved
return connection

View File

@ -1,5 +1,4 @@
import re
from collections import OrderedDict
from ..types import Field, InputObjectType, String
from ..types.mutation import Mutation
@ -30,12 +29,10 @@ class ClientIDMutation(Mutation):
cls.Input = type(
"{}Input".format(base_name),
bases,
OrderedDict(
input_fields, client_mutation_id=String(name="clientMutationId")
),
dict(input_fields, client_mutation_id=String(name="clientMutationId")),
)
arguments = OrderedDict(
arguments = dict(
input=cls.Input(required=True)
# 'client_mutation_id': String(name='clientMutationId')
)

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from functools import partial
from inspect import isclass
@ -72,9 +71,7 @@ class AbstractNode(Interface):
@classmethod
def __init_subclass_with_meta__(cls, **options):
_meta = InterfaceOptions(cls)
_meta.fields = OrderedDict(
id=GlobalID(cls, description="The ID of the object.")
)
_meta.fields = {"id": GlobalID(cls, description="The ID of the object")}
super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options)

View File

@ -1,4 +1,4 @@
import pytest
from pytest import raises
from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo
@ -24,7 +24,7 @@ def test_connection():
assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ["page_info", "edges", "extra"]
assert list(fields) == ["page_info", "edges", "extra"]
edge_field = fields["edges"]
pageinfo_field = fields["page_info"]
@ -39,7 +39,7 @@ def test_connection():
def test_connection_inherit_abstracttype():
class BaseConnection(object):
class BaseConnection:
extra = String()
class MyObjectConnection(BaseConnection, Connection):
@ -48,13 +48,13 @@ def test_connection_inherit_abstracttype():
assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ["page_info", "edges", "extra"]
assert list(fields) == ["page_info", "edges", "extra"]
def test_connection_name():
custom_name = "MyObjectCustomNameConnection"
class BaseConnection(object):
class BaseConnection:
extra = String()
class MyObjectConnection(BaseConnection, Connection):
@ -76,7 +76,7 @@ def test_edge():
Edge = MyObjectConnection.Edge
assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ["node", "cursor", "other"]
assert list(edge_fields) == ["node", "cursor", "other"]
assert isinstance(edge_fields["node"], Field)
assert edge_fields["node"].type == MyObject
@ -86,7 +86,7 @@ def test_edge():
def test_edge_with_bases():
class BaseEdge(object):
class BaseEdge:
extra = String()
class MyObjectConnection(Connection):
@ -99,7 +99,7 @@ def test_edge_with_bases():
Edge = MyObjectConnection.Edge
assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ["node", "cursor", "extra", "other"]
assert list(edge_fields) == ["node", "cursor", "extra", "other"]
assert isinstance(edge_fields["node"], Field)
assert edge_fields["node"].type == MyObject
@ -122,7 +122,7 @@ def test_edge_with_nonnull_node():
def test_pageinfo():
assert PageInfo._meta.name == "PageInfo"
fields = PageInfo._meta.fields
assert list(fields.keys()) == [
assert list(fields) == [
"has_next_page",
"has_previous_page",
"start_cursor",
@ -146,7 +146,7 @@ def test_connectionfield():
def test_connectionfield_node_deprecated():
field = ConnectionField(MyObject)
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
field.type
assert "ConnectionFields now need a explicit ConnectionType for Nodes." in str(

View File

@ -1,7 +1,6 @@
from collections import OrderedDict
from pytest import mark
from graphql_relay.utils import base64
from promise import Promise
from ...types import ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo
@ -25,15 +24,15 @@ class LetterConnection(Connection):
class Query(ObjectType):
letters = ConnectionField(LetterConnection)
connection_letters = ConnectionField(LetterConnection)
promise_letters = ConnectionField(LetterConnection)
async_letters = ConnectionField(LetterConnection)
node = Node.Field()
def resolve_letters(self, info, **args):
return list(letters.values())
def resolve_promise_letters(self, info, **args):
return Promise.resolve(list(letters.values()))
async def resolve_async_letters(self, info, **args):
return list(letters.values())
def resolve_connection_letters(self, info, **args):
return LetterConnection(
@ -46,9 +45,7 @@ class Query(ObjectType):
schema = Schema(Query)
letters = OrderedDict()
for i, letter in enumerate(letter_chars):
letters[letter] = Letter(id=i, letter=letter)
letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter_chars)}
def edges(selected_letters):
@ -66,11 +63,11 @@ def cursor_for(ltr):
return base64("arrayconnection:%s" % letter.id)
def execute(args=""):
async def execute(args=""):
if args:
args = "(" + args + ")"
return schema.execute(
return await schema.execute_async(
"""
{
letters%s {
@ -94,8 +91,8 @@ def execute(args=""):
)
def check(args, letters, has_previous_page=False, has_next_page=False):
result = execute(args)
async def check(args, letters, has_previous_page=False, has_next_page=False):
result = await execute(args)
expected_edges = edges(letters)
expected_page_info = {
"hasPreviousPage": has_previous_page,
@ -110,96 +107,118 @@ def check(args, letters, has_previous_page=False, has_next_page=False):
}
def test_returns_all_elements_without_filters():
check("", "ABCDE")
@mark.asyncio
async def test_returns_all_elements_without_filters():
await check("", "ABCDE")
def test_respects_a_smaller_first():
check("first: 2", "AB", has_next_page=True)
@mark.asyncio
async def test_respects_a_smaller_first():
await check("first: 2", "AB", has_next_page=True)
def test_respects_an_overly_large_first():
check("first: 10", "ABCDE")
@mark.asyncio
async def test_respects_an_overly_large_first():
await check("first: 10", "ABCDE")
def test_respects_a_smaller_last():
check("last: 2", "DE", has_previous_page=True)
@mark.asyncio
async def test_respects_a_smaller_last():
await check("last: 2", "DE", has_previous_page=True)
def test_respects_an_overly_large_last():
check("last: 10", "ABCDE")
@mark.asyncio
async def test_respects_an_overly_large_last():
await check("last: 10", "ABCDE")
def test_respects_first_and_after():
check('first: 2, after: "{}"'.format(cursor_for("B")), "CD", has_next_page=True)
@mark.asyncio
async def test_respects_first_and_after():
await check(
'first: 2, after: "{}"'.format(cursor_for("B")), "CD", has_next_page=True
)
def test_respects_first_and_after_with_long_first():
check('first: 10, after: "{}"'.format(cursor_for("B")), "CDE")
@mark.asyncio
async def test_respects_first_and_after_with_long_first():
await check('first: 10, after: "{}"'.format(cursor_for("B")), "CDE")
def test_respects_last_and_before():
check('last: 2, before: "{}"'.format(cursor_for("D")), "BC", has_previous_page=True)
@mark.asyncio
async def test_respects_last_and_before():
await check(
'last: 2, before: "{}"'.format(cursor_for("D")), "BC", has_previous_page=True
)
def test_respects_last_and_before_with_long_last():
check('last: 10, before: "{}"'.format(cursor_for("D")), "ABC")
@mark.asyncio
async def test_respects_last_and_before_with_long_last():
await check('last: 10, before: "{}"'.format(cursor_for("D")), "ABC")
def test_respects_first_and_after_and_before_too_few():
check(
@mark.asyncio
async def test_respects_first_and_after_and_before_too_few():
await check(
'first: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BC",
has_next_page=True,
)
def test_respects_first_and_after_and_before_too_many():
check(
@mark.asyncio
async def test_respects_first_and_after_and_before_too_many():
await check(
'first: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_first_and_after_and_before_exactly_right():
check(
@mark.asyncio
async def test_respects_first_and_after_and_before_exactly_right():
await check(
'first: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_last_and_after_and_before_too_few():
check(
@mark.asyncio
async def test_respects_last_and_after_and_before_too_few():
await check(
'last: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"CD",
has_previous_page=True,
)
def test_respects_last_and_after_and_before_too_many():
check(
@mark.asyncio
async def test_respects_last_and_after_and_before_too_many():
await check(
'last: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_respects_last_and_after_and_before_exactly_right():
check(
@mark.asyncio
async def test_respects_last_and_after_and_before_exactly_right():
await check(
'last: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD",
)
def test_returns_no_elements_if_first_is_0():
check("first: 0", "", has_next_page=True)
@mark.asyncio
async def test_returns_no_elements_if_first_is_0():
await check("first: 0", "", has_next_page=True)
def test_returns_all_elements_if_cursors_are_invalid():
check('before: "invalid" after: "invalid"', "ABCDE")
@mark.asyncio
async def test_returns_all_elements_if_cursors_are_invalid():
await check('before: "invalid" after: "invalid"', "ABCDE")
def test_returns_all_elements_if_cursors_are_on_the_outside():
check(
@mark.asyncio
async def test_returns_all_elements_if_cursors_are_on_the_outside():
await check(
'before: "{}" after: "{}"'.format(
base64("arrayconnection:%s" % 6), base64("arrayconnection:%s" % -1)
),
@ -207,8 +226,9 @@ def test_returns_all_elements_if_cursors_are_on_the_outside():
)
def test_returns_no_elements_if_cursors_cross():
check(
@mark.asyncio
async def test_returns_no_elements_if_cursors_cross():
await check(
'before: "{}" after: "{}"'.format(
base64("arrayconnection:%s" % 2), base64("arrayconnection:%s" % 4)
),
@ -216,8 +236,9 @@ def test_returns_no_elements_if_cursors_cross():
)
def test_connection_type_nodes():
result = schema.execute(
@mark.asyncio
async def test_connection_type_nodes():
result = await schema.execute_async(
"""
{
connectionLetters {
@ -248,11 +269,12 @@ def test_connection_type_nodes():
}
def test_connection_promise():
result = schema.execute(
@mark.asyncio
async def test_connection_async():
result = await schema.execute_async(
"""
{
promiseLetters(first:1) {
asyncLetters(first:1) {
edges {
node {
id
@ -270,7 +292,7 @@ def test_connection_promise():
assert not result.errors
assert result.data == {
"promiseLetters": {
"asyncLetters": {
"edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}],
"pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
}

View File

@ -17,7 +17,7 @@ class User(ObjectType):
name = String()
class Info(object):
class Info:
def __init__(self, parent_type):
self.parent_type = GrapheneObjectType(
graphene_type=parent_type,

View File

@ -1,5 +1,4 @@
import pytest
from promise import Promise
from pytest import mark, raises
from ...types import (
ID,
@ -15,7 +14,7 @@ from ...types.scalars import String
from ..mutation import ClientIDMutation
class SharedFields(object):
class SharedFields:
shared = String()
@ -37,7 +36,7 @@ class SaySomething(ClientIDMutation):
return SaySomething(phrase=str(what))
class FixedSaySomething(object):
class FixedSaySomething:
__slots__ = ("phrase",)
def __init__(self, phrase):
@ -55,15 +54,15 @@ class SaySomethingFixed(ClientIDMutation):
return FixedSaySomething(phrase=str(what))
class SaySomethingPromise(ClientIDMutation):
class SaySomethingAsync(ClientIDMutation):
class Input:
what = String()
phrase = String()
@staticmethod
def mutate_and_get_payload(self, info, what, client_mutation_id=None):
return Promise.resolve(SaySomething(phrase=str(what)))
async def mutate_and_get_payload(self, info, what, client_mutation_id=None):
return SaySomething(phrase=str(what))
# MyEdge = MyNode.Connection.Edge
@ -81,11 +80,11 @@ class OtherMutation(ClientIDMutation):
@staticmethod
def mutate_and_get_payload(
self, info, shared="", additional_field="", client_mutation_id=None
self, info, shared, additional_field, client_mutation_id=None
):
edge_type = MyEdge
return OtherMutation(
name=shared + additional_field,
name=(shared or "") + (additional_field or ""),
my_node_edge=edge_type(cursor="1", node=MyNode(name="name")),
)
@ -97,7 +96,7 @@ class RootQuery(ObjectType):
class Mutation(ObjectType):
say = SaySomething.Field()
say_fixed = SaySomethingFixed.Field()
say_promise = SaySomethingPromise.Field()
say_async = SaySomethingAsync.Field()
other = OtherMutation.Field()
@ -105,7 +104,7 @@ schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo:
with raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation):
pass
@ -118,12 +117,12 @@ def test_no_mutate_and_get_payload():
def test_mutation():
fields = SaySomething._meta.fields
assert list(fields.keys()) == ["phrase", "client_mutation_id"]
assert list(fields) == ["phrase", "client_mutation_id"]
assert SaySomething._meta.name == "SaySomethingPayload"
assert isinstance(fields["phrase"], Field)
field = SaySomething.Field()
assert field.type == SaySomething
assert list(field.args.keys()) == ["input"]
assert list(field.args) == ["input"]
assert isinstance(field.args["input"], Argument)
assert isinstance(field.args["input"].type, NonNull)
assert field.args["input"].type.of_type == SaySomething.Input
@ -136,7 +135,7 @@ def test_mutation_input():
Input = SaySomething.Input
assert issubclass(Input, InputObjectType)
fields = Input._meta.fields
assert list(fields.keys()) == ["what", "client_mutation_id"]
assert list(fields) == ["what", "client_mutation_id"]
assert isinstance(fields["what"], InputField)
assert fields["what"].type == String
assert isinstance(fields["client_mutation_id"], InputField)
@ -145,11 +144,11 @@ def test_mutation_input():
def test_subclassed_mutation():
fields = OtherMutation._meta.fields
assert list(fields.keys()) == ["name", "my_node_edge", "client_mutation_id"]
assert list(fields) == ["name", "my_node_edge", "client_mutation_id"]
assert isinstance(fields["name"], Field)
field = OtherMutation.Field()
assert field.type == OtherMutation
assert list(field.args.keys()) == ["input"]
assert list(field.args) == ["input"]
assert isinstance(field.args["input"], Argument)
assert isinstance(field.args["input"].type, NonNull)
assert field.args["input"].type.of_type == OtherMutation.Input
@ -159,7 +158,7 @@ def test_subclassed_mutation_input():
Input = OtherMutation.Input
assert issubclass(Input, InputObjectType)
fields = Input._meta.fields
assert list(fields.keys()) == ["shared", "additional_field", "client_mutation_id"]
assert list(fields) == ["shared", "additional_field", "client_mutation_id"]
assert isinstance(fields["shared"], InputField)
assert fields["shared"].type == String
assert isinstance(fields["additional_field"], InputField)
@ -185,12 +184,13 @@ def test_node_query_fixed():
)
def test_node_query_promise():
executed = schema.execute(
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
@mark.asyncio
async def test_node_query_async():
executed = await schema.execute_async(
'mutation a { sayAsync(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
)
assert not executed.errors
assert executed.data == {"sayPromise": {"phrase": "hello"}}
assert executed.data == {"sayAsync": {"phrase": "hello"}}
def test_edge_query():

View File

@ -1,12 +1,12 @@
from collections import OrderedDict
from graphql_relay import to_global_id
from graphql.pyutils import dedent
from ...types import ObjectType, Schema, String
from ..node import Node, is_node
class SharedNodeFields(object):
class SharedNodeFields:
shared = String()
something_else = String()
@ -70,17 +70,13 @@ def test_subclassed_node_query():
% to_global_id("MyOtherNode", 1)
)
assert not executed.errors
assert executed.data == OrderedDict(
{
"node": OrderedDict(
[
("shared", "1"),
("extraField", "extra field info."),
("somethingElse", "----"),
]
)
assert executed.data == {
"node": {
"shared": "1",
"extraField": "extra field info.",
"somethingElse": "----",
}
)
}
def test_node_requesting_non_node():
@ -124,7 +120,7 @@ def test_node_field_only_type_wrong():
% Node.to_global_id("MyOtherNode", 1)
)
assert len(executed.errors) == 1
assert str(executed.errors[0]) == "Must receive a MyNode id."
assert str(executed.errors[0]).startswith("Must receive a MyNode id.")
assert executed.data == {"onlyNode": None}
@ -143,39 +139,48 @@ def test_node_field_only_lazy_type_wrong():
% Node.to_global_id("MyOtherNode", 1)
)
assert len(executed.errors) == 1
assert str(executed.errors[0]) == "Must receive a MyNode id."
assert str(executed.errors[0]).startswith("Must receive a MyNode id.")
assert executed.data == {"onlyNodeLazy": None}
def test_str_schema():
assert (
str(schema)
== """
schema {
query: RootQuery
}
assert str(schema) == dedent(
'''
schema {
query: RootQuery
}
type MyNode implements Node {
id: ID!
name: String
}
type MyNode implements Node {
"""The ID of the object"""
id: ID!
name: String
}
type MyOtherNode implements Node {
id: ID!
shared: String
somethingElse: String
extraField: String
}
type MyOtherNode implements Node {
"""The ID of the object"""
id: ID!
shared: String
somethingElse: String
extraField: String
}
interface Node {
id: ID!
}
"""An object with an ID"""
interface Node {
"""The ID of the object"""
id: ID!
}
type RootQuery {
first: String
node(id: ID!): Node
onlyNode(id: ID!): MyNode
onlyNodeLazy(id: ID!): MyNode
}
""".lstrip()
type RootQuery {
first: String
"""The ID of the object"""
node(id: ID!): Node
"""The ID of the object"""
onlyNode(id: ID!): MyNode
"""The ID of the object"""
onlyNodeLazy(id: ID!): MyNode
}
'''
)

View File

@ -1,4 +1,5 @@
from graphql import graphql
from graphql import graphql_sync
from graphql.pyutils import dedent
from ...types import Interface, ObjectType, Schema
from ...types.scalars import Int, String
@ -15,7 +16,7 @@ class CustomNode(Node):
@staticmethod
def get_node_from_global_id(info, id, only_type=None):
assert info.schema == schema
assert info.schema is graphql_schema
if id in user_data:
return user_data.get(id)
else:
@ -23,14 +24,14 @@ class CustomNode(Node):
class BasePhoto(Interface):
width = Int()
width = Int(description="The width of the photo in pixels")
class User(ObjectType):
class Meta:
interfaces = [CustomNode]
name = String()
name = String(description="The full name of the user")
class Photo(ObjectType):
@ -48,37 +49,47 @@ class RootQuery(ObjectType):
schema = Schema(query=RootQuery, types=[User, Photo])
graphql_schema = schema.graphql_schema
def test_str_schema_correct():
assert (
str(schema)
== """schema {
query: RootQuery
}
assert str(schema) == dedent(
'''
schema {
query: RootQuery
}
interface BasePhoto {
width: Int
}
interface BasePhoto {
"""The width of the photo in pixels"""
width: Int
}
interface Node {
id: ID!
}
interface Node {
"""The ID of the object"""
id: ID!
}
type Photo implements Node, BasePhoto {
id: ID!
width: Int
}
type Photo implements Node & BasePhoto {
"""The ID of the object"""
id: ID!
type RootQuery {
node(id: ID!): Node
}
"""The width of the photo in pixels"""
width: Int
}
type User implements Node {
id: ID!
name: String
}
"""
type RootQuery {
"""The ID of the object"""
node(id: ID!): Node
}
type User implements Node {
"""The ID of the object"""
id: ID!
"""The full name of the user"""
name: String
}
'''
)
@ -91,7 +102,7 @@ def test_gets_the_correct_id_for_users():
}
"""
expected = {"node": {"id": "1"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -105,7 +116,7 @@ def test_gets_the_correct_id_for_photos():
}
"""
expected = {"node": {"id": "4"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -122,7 +133,7 @@ def test_gets_the_correct_name_for_users():
}
"""
expected = {"node": {"id": "1", "name": "John Doe"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -139,7 +150,7 @@ def test_gets_the_correct_width_for_photos():
}
"""
expected = {"node": {"id": "4", "width": 400}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -154,7 +165,7 @@ def test_gets_the_correct_typename_for_users():
}
"""
expected = {"node": {"id": "1", "__typename": "User"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -169,7 +180,7 @@ def test_gets_the_correct_typename_for_photos():
}
"""
expected = {"node": {"id": "4", "__typename": "Photo"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -186,7 +197,7 @@ def test_ignores_photo_fragments_on_user():
}
"""
expected = {"node": {"id": "1"}}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -200,7 +211,7 @@ def test_returns_null_for_bad_ids():
}
"""
expected = {"node": None}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -239,7 +250,7 @@ def test_have_correct_node_interface():
],
}
}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected
@ -291,6 +302,6 @@ def test_has_correct_node_root_field():
}
}
}
result = graphql(schema, query)
result = graphql_sync(graphql_schema, query)
assert not result.errors
assert result.data == expected

View File

@ -8,24 +8,19 @@ from graphene.types.schema import Schema
def default_format_error(error):
if isinstance(error, GraphQLError):
return format_graphql_error(error)
return {"message": str(error)}
def format_execution_result(execution_result, format_error):
if execution_result:
response = {}
if execution_result.errors:
response["errors"] = [format_error(e) for e in execution_result.errors]
if not execution_result.invalid:
response["data"] = execution_result.data
response["data"] = execution_result.data
return response
class Client(object):
class Client:
def __init__(self, schema, format_error=None, **execute_options):
assert isinstance(schema, Schema)
self.schema = schema

View File

@ -21,7 +21,7 @@ class CreatePostResult(graphene.Union):
class CreatePost(graphene.Mutation):
class Input:
class Arguments:
text = graphene.String(required=True)
result = graphene.Field(CreatePostResult)

View File

@ -1,6 +1,6 @@
# https://github.com/graphql-python/graphene/issues/356
import pytest
from pytest import raises
import graphene
from graphene import relay
@ -23,10 +23,11 @@ def test_issue():
class Query(graphene.ObjectType):
things = relay.ConnectionField(MyUnion)
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
graphene.Schema(query=Query)
assert str(exc_info.value) == (
"IterableConnectionField type have to be a subclass of Connection. "
'Received "MyUnion".'
"Query fields cannot be resolved:"
" IterableConnectionField type has to be a subclass of Connection."
' Received "MyUnion".'
)

View File

@ -1,5 +1,5 @@
# flake8: noqa
from graphql import ResolveInfo
from graphql import GraphQLResolveInfo as ResolveInfo
from .objecttype import ObjectType
from .interface import Interface

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from itertools import chain
from .dynamic import Dynamic
@ -81,7 +80,7 @@ def to_arguments(args, extra_args=None):
else:
extra_args = []
iter_arguments = chain(args.items(), extra_args)
arguments = OrderedDict()
arguments = {}
for default_name, arg in iter_arguments:
if isinstance(arg, Dynamic):
arg = arg.get_type()

View File

@ -4,7 +4,7 @@ from ..utils.subclass_with_meta import SubclassWithMeta
from ..utils.trim_docstring import trim_docstring
class BaseOptions(object):
class BaseOptions:
name = None # type: str
description = None # type: str

View File

@ -1,4 +1,4 @@
class Context(object):
class Context:
"""
Context can be used to make a convenient container for attributes to provide
for execution for resolvers of a GraphQL operation like a query.

View File

@ -3,7 +3,8 @@ from __future__ import absolute_import
import datetime
from aniso8601 import parse_date, parse_datetime, parse_time
from graphql.language import ast
from graphql.error import INVALID
from graphql.language import StringValueNode
from .scalars import Scalar
@ -26,7 +27,7 @@ class Date(Scalar):
@classmethod
def parse_literal(cls, node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return cls.parse_value(node.value)
@staticmethod
@ -37,7 +38,7 @@ class Date(Scalar):
elif isinstance(value, str):
return parse_date(value)
except ValueError:
return None
return INVALID
class DateTime(Scalar):
@ -56,7 +57,7 @@ class DateTime(Scalar):
@classmethod
def parse_literal(cls, node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return cls.parse_value(node.value)
@staticmethod
@ -67,7 +68,7 @@ class DateTime(Scalar):
elif isinstance(value, str):
return parse_datetime(value)
except ValueError:
return None
return INVALID
class Time(Scalar):
@ -86,7 +87,7 @@ class Time(Scalar):
@classmethod
def parse_literal(cls, node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return cls.parse_value(node.value)
@classmethod
@ -97,4 +98,4 @@ class Time(Scalar):
elif isinstance(value, str):
return parse_time(value)
except ValueError:
return None
return INVALID

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
from decimal import Decimal as _Decimal
from graphql.language import ast
from graphql.language.ast import StringValueNode
from .scalars import Scalar
@ -23,7 +23,7 @@ class Decimal(Scalar):
@classmethod
def parse_literal(cls, node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return cls.parse_value(node.value)
@staticmethod

View File

@ -8,7 +8,7 @@ from graphql import (
)
class GrapheneGraphQLType(object):
class GrapheneGraphQLType:
"""
A class for extending the base GraphQLType with the related
graphene_type

View File

@ -1,7 +1,7 @@
from collections import OrderedDict
from enum import Enum as PyEnum
from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta
from ..pyutils.compat import Enum as PyEnum
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
@ -22,13 +22,13 @@ class EnumOptions(BaseOptions):
class EnumMeta(SubclassWithMeta_Meta):
def __new__(cls, name, bases, classdict, **options):
enum_members = OrderedDict(classdict, __eq__=eq_enum)
enum_members = dict(classdict, __eq__=eq_enum)
# We remove the Meta attribute from the class to not collide
# with the enum values.
enum_members.pop("Meta", None)
enum = PyEnum(cls.__name__, enum_members)
return SubclassWithMeta_Meta.__new__(
cls, name, bases, OrderedDict(classdict, __enum__=enum), **options
cls, name, bases, dict(classdict, __enum__=enum), **options
)
def get(cls, value):
@ -38,7 +38,7 @@ class EnumMeta(SubclassWithMeta_Meta):
return cls._meta.enum[value]
def __prepare__(name, bases, **kwargs): # noqa: N805
return OrderedDict()
return {}
def __call__(cls, *args, **kwargs): # noqa: N805
if cls is Enum:

View File

@ -1,5 +1,5 @@
import inspect
from collections import Mapping, OrderedDict
from collections.abc import Mapping
from functools import partial
from .argument import Argument, to_arguments
@ -100,7 +100,7 @@ class Field(MountedType):
self.name = name
self._type = type
self.args = to_arguments(args or OrderedDict(), extra_args)
self.args = to_arguments(args or {}, extra_args)
if source:
resolver = partial(source_resolver, source)
self.resolver = resolver

View File

@ -1,12 +1,12 @@
from __future__ import unicode_literals
from graphql.language.ast import (
BooleanValue,
FloatValue,
IntValue,
ListValue,
ObjectValue,
StringValue,
BooleanValueNode,
FloatValueNode,
IntValueNode,
ListValueNode,
ObjectValueNode,
StringValueNode,
)
from graphene.types.scalars import MAX_INT, MIN_INT
@ -30,17 +30,17 @@ class GenericScalar(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, (StringValue, BooleanValue)):
if isinstance(ast, (StringValueNode, BooleanValueNode)):
return ast.value
elif isinstance(ast, IntValue):
elif isinstance(ast, IntValueNode):
num = int(ast.value)
if MIN_INT <= num <= MAX_INT:
return num
elif isinstance(ast, FloatValue):
elif isinstance(ast, FloatValueNode):
return float(ast.value)
elif isinstance(ast, ListValue):
elif isinstance(ast, ListValueNode):
return [GenericScalar.parse_literal(value) for value in ast.values]
elif isinstance(ast, ObjectValue):
elif isinstance(ast, ObjectValueNode):
return {
field.name.value: GenericScalar.parse_literal(field.value)
for field in ast.fields
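A hedged sketch (constructing AST nodes by hand, which the diff itself does not do) of how the renamed node classes flow through GenericScalar.parse_literal:

# Illustrative only; graphql-core 3 AST nodes accept keyword arguments.
from graphql.language.ast import IntValueNode, ListValueNode, StringValueNode

from graphene.types.generic import GenericScalar

literal = ListValueNode(
    values=[IntValueNode(value="1"), StringValueNode(value="two")]
)
assert GenericScalar.parse_literal(literal) == [1, "two"]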


@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType
from .inputfield import InputField
from .unmountedtype import UnmountedType
@ -22,7 +20,7 @@ class InputObjectTypeContainer(dict, BaseType):
def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
for key in self._meta.fields.keys():
for key in self._meta.fields:
setattr(self, key, self.get(key, None))
def __init_subclass__(cls, *args, **kwargs):
@ -70,7 +68,7 @@ class InputObjectType(UnmountedType, BaseType):
if not _meta:
_meta = InputObjectTypeOptions(cls)
fields = OrderedDict()
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField))


@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType
from .field import Field
from .utils import yank_fields_from_attrs
@ -51,7 +49,7 @@ class Interface(BaseType):
if not _meta:
_meta = InterfaceOptions(cls)
fields = OrderedDict()
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))


@ -2,7 +2,7 @@ from __future__ import absolute_import
import json
from graphql.language import ast
from graphql.language.ast import StringValueNode
from .scalars import Scalar
@ -21,7 +21,7 @@ class JSONString(Scalar):
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return json.loads(node.value)
@staticmethod


@ -1,5 +1,3 @@
from collections import OrderedDict
from ..utils.deprecated import warn_deprecation
from ..utils.get_unbound_function import get_unbound_function
from ..utils.props import props
@ -90,7 +88,7 @@ class Mutation(ObjectType):
if not output:
# If output is defined, we don't need to get the fields
fields = OrderedDict()
fields = {}
for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
output = cls
@ -103,7 +101,7 @@ class Mutation(ObjectType):
warn_deprecation(
(
"Please use {name}.Arguments instead of {name}.Input."
"Input is now only used in ClientMutationID.\n"
" Input is now only used in ClientMutationID.\n"
"Read more:"
" https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"
).format(name=cls.__name__)
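Since the message above steers users from the inner Input class to Arguments, here is a minimal hypothetical mutation using the recommended form (names are made up, not from this diff):

from graphene import Boolean, Mutation, String


class SetName(Mutation):
    class Arguments:  # preferred over the deprecated inner `Input` class
        name = String(required=True)

    ok = Boolean()

    def mutate(root, info, name):
        # ... persist `name` somewhere (omitted in this sketch) ...
        return SetName(ok=True)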


@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType
from .field import Field
from .interface import Interface
@ -100,7 +98,7 @@ class ObjectType(BaseType):
if not _meta:
_meta = ObjectTypeOptions(cls)
fields = OrderedDict()
fields = {}
for interface in interfaces:
assert issubclass(interface, Interface), (


@ -1,6 +1,11 @@
from typing import Any
from graphql.language.ast import BooleanValue, FloatValue, IntValue, StringValue
from graphql.language.ast import (
BooleanValueNode,
FloatValueNode,
IntValueNode,
StringValueNode,
)
from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType
@ -71,7 +76,7 @@ class Int(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, IntValue):
if isinstance(ast, IntValueNode):
num = int(ast.value)
if MIN_INT <= num <= MAX_INT:
return num
@ -97,7 +102,7 @@ class Float(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, (FloatValue, IntValue)):
if isinstance(ast, (FloatValueNode, IntValueNode)):
return float(ast.value)
@ -119,7 +124,7 @@ class String(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, StringValue):
if isinstance(ast, StringValueNode):
return ast.value
@ -133,7 +138,7 @@ class Boolean(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, BooleanValue):
if isinstance(ast, BooleanValueNode):
return ast.value
@ -151,5 +156,5 @@ class ID(Scalar):
@staticmethod
def parse_literal(ast):
if isinstance(ast, (StringValue, IntValue)):
if isinstance(ast, (StringValueNode, IntValueNode)):
return ast.value
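A small hypothetical check (not in the diff) of the renamed node classes against the built-in Int scalar, which still enforces the 32-bit GraphQL Int range:

from graphql.language.ast import IntValueNode

from graphene import Int

assert Int.parse_literal(IntValueNode(value="42")) == 42
# Out-of-range literals are rejected; parse_literal falls through and returns None.
assert Int.parse_literal(IntValueNode(value=str(2 ** 31))) is None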


@ -1,104 +1,133 @@
import inspect
from functools import partial
from graphql import GraphQLObjectType, GraphQLSchema, graphql, is_type
from graphql.type.directives import (
GraphQLDirective,
GraphQLIncludeDirective,
GraphQLSkipDirective,
from graphql import (
default_type_resolver,
get_introspection_query,
graphql,
graphql_sync,
introspection_types,
is_type,
print_schema,
GraphQLArgument,
GraphQLBoolean,
GraphQLEnumValue,
GraphQLField,
GraphQLFloat,
GraphQLID,
GraphQLInputField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLString,
INVALID,
)
from graphql.type.introspection import IntrospectionSchema
from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema
from .definitions import GrapheneGraphQLType
from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function
from .definitions import (
GrapheneEnumType,
GrapheneGraphQLType,
GrapheneInputObjectType,
GrapheneInterfaceType,
GrapheneObjectType,
GrapheneScalarType,
GrapheneUnionType,
)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
from .inputobjecttype import InputObjectType
from .interface import Interface
from .objecttype import ObjectType
from .typemap import TypeMap, is_graphene_type
from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull
from .union import Union
from .utils import get_field_as
introspection_query = get_introspection_query()
IntrospectionSchema = introspection_types["__Schema"]
def assert_valid_root_type(_type):
if _type is None:
def assert_valid_root_type(type_):
if type_ is None:
return
is_graphene_objecttype = inspect.isclass(_type) and issubclass(_type, ObjectType)
is_graphql_objecttype = isinstance(_type, GraphQLObjectType)
is_graphene_objecttype = inspect.isclass(type_) and issubclass(type_, ObjectType)
is_graphql_objecttype = isinstance(type_, GraphQLObjectType)
assert is_graphene_objecttype or is_graphql_objecttype, (
"Type {} is not a valid ObjectType."
).format(_type)
).format(type_)
class Schema(GraphQLSchema):
"""
Graphene Schema can execute operations (query, mutation, subscription) against the defined
types.
def is_graphene_type(type_):
if isinstance(type_, (List, NonNull)):
return True
if inspect.isclass(type_) and issubclass(
type_, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
):
return True
For advanced purposes, the schema can be used to lookup type definitions and answer questions
about the types through introspection.
Args:
query (ObjectType): Root query *ObjectType*. Describes entry point for fields to *read*
data in your Schema.
mutation (ObjectType, optional): Root mutation *ObjectType*. Describes entry point for
fields to *create, update or delete* data in your API.
subscription (ObjectType, optional): Root subscription *ObjectType*. Describes entry point
for fields to receive continuous updates.
directives (List[GraphQLDirective], optional): List of custom directives to include in
GraphQL schema. Defaults to only include directives definved by GraphQL spec (@include
and @skip) [GraphQLIncludeDirective, GraphQLSkipDirective].
types (List[GraphQLType], optional): List of any types to include in schema that
may not be introspected through root types.
auto_camelcase (bool): Fieldnames will be transformed in Schema's TypeMap from snake_case
to camelCase (preferred by GraphQL standard). Default True.
"""
def resolve_type(resolve_type_func, map_, type_name, root, info, _type):
type_ = resolve_type_func(root, info)
if not type_:
return_type = map_[type_name]
return default_type_resolver(root, info, return_type)
if inspect.isclass(type_) and issubclass(type_, ObjectType):
graphql_type = map_.get(type_._meta.name)
assert graphql_type, "Can't find type {} in schema".format(type_._meta.name)
assert graphql_type.graphene_type == type_, (
"The type {} does not match with the associated graphene type {}."
).format(type_, graphql_type.graphene_type)
return graphql_type
return type_
def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types)
class GrapheneGraphQLSchema(GraphQLSchema):
"""A GraphQLSchema that can deal with Graphene types as well."""
def __init__(
self,
query=None,
mutation=None,
subscription=None,
directives=None,
types=None,
directives=None,
auto_camelcase=True,
):
assert_valid_root_type(query)
assert_valid_root_type(mutation)
assert_valid_root_type(subscription)
self._query = query
self._mutation = mutation
self._subscription = subscription
self.types = types
self.auto_camelcase = auto_camelcase
if directives is None:
directives = [GraphQLIncludeDirective, GraphQLSkipDirective]
super().__init__(query, mutation, subscription, types, directives)
assert all(
isinstance(d, GraphQLDirective) for d in directives
), "Schema directives must be List[GraphQLDirective] if provided but got: {}.".format(
directives
)
self._directives = directives
self.build_typemap()
def get_query_type(self):
return self.get_graphql_type(self._query)
def get_mutation_type(self):
return self.get_graphql_type(self._mutation)
def get_subscription_type(self):
return self.get_graphql_type(self._subscription)
def __getattr__(self, type_name):
"""
This function let the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
"""
_type = super(Schema, self).get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
if query:
self.query_type = self.get_type(
query.name if isinstance(query, GraphQLObjectType) else query._meta.name
)
if mutation:
self.mutation_type = self.get_type(
mutation.name
if isinstance(mutation, GraphQLObjectType)
else mutation._meta.name
)
if subscription:
self.subscription_type = self.get_type(
subscription.name
if isinstance(subscription, GraphQLObjectType)
else subscription._meta.name
)
def get_graphql_type(self, _type):
if not _type:
@ -114,56 +143,383 @@ class Schema(GraphQLSchema):
return graphql_type
raise Exception("{} is not a valid GraphQL type.".format(_type))
def execute(self, *args, **kwargs):
"""
Use the `graphql` function from `graphql-core` to provide the result for a query string.
Most of the time this method will be called by one of the Graphene :ref:`Integrations`
via a web request.
# noinspection PyMethodOverriding
def type_map_reducer(self, map_, type_):
if not type_:
return map_
if inspect.isfunction(type_):
type_ = type_()
if is_graphene_type(type_):
return self.graphene_reducer(map_, type_)
return super().type_map_reducer(map_, type_)
Args:
request_string (str or Document): GraphQL request (query, mutation or subscription) in
string or parsed AST form from `graphql-core`.
root (Any, optional): Value to use as the parent value object when resolving root
types.
context (Any, optional): Value to be made avaiable to all resolvers via
`info.context`. Can be used to share authorization, dataloaders or other
information needed to resolve an operation.
variables (dict, optional): If variables are used in the request string, they can be
provided in dictionary form mapping the variable name to the variable value.
operation_name (str, optional): If mutiple operations are provided in the
request_string, an operation name must be provided for the result to be provided.
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
defined in `graphql-core`.
backend (GraphQLCoreBackend, optional): Override the default GraphQLCoreBackend.
**execute_options (Any): Depends on backend selected. Default backend has several
options such as: validate, allow_subscriptions, return_promise, executor.
def graphene_reducer(self, map_, type_):
if isinstance(type_, (List, NonNull)):
return self.type_map_reducer(map_, type_.of_type)
if type_._meta.name in map_:
_type = map_[type_._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type_, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type_)
return map_
Returns:
:obj:`ExecutionResult` containing any data and errors for the operation.
"""
return graphql(self, *args, **kwargs)
if issubclass(type_, ObjectType):
internal_type = self.construct_objecttype(map_, type_)
elif issubclass(type_, InputObjectType):
internal_type = self.construct_inputobjecttype(map_, type_)
elif issubclass(type_, Interface):
internal_type = self.construct_interface(map_, type_)
elif issubclass(type_, Scalar):
internal_type = self.construct_scalar(type_)
elif issubclass(type_, Enum):
internal_type = self.construct_enum(type_)
elif issubclass(type_, Union):
internal_type = self.construct_union(map_, type_)
else:
raise Exception("Expected Graphene type, but received: {}.".format(type_))
def introspect(self):
instrospection = self.execute(introspection_query)
if instrospection.errors:
raise instrospection.errors[0]
return instrospection.data
return super().type_map_reducer(map_, internal_type)
@staticmethod
def construct_scalar(type_):
# We have a mapping to the original GraphQL types
# so there are no collisions.
_scalars = {
String: GraphQLString,
Int: GraphQLInt,
Float: GraphQLFloat,
Boolean: GraphQLBoolean,
ID: GraphQLID,
}
if type_ in _scalars:
return _scalars[type_]
return GrapheneScalarType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
serialize=getattr(type_, "serialize", None),
parse_value=getattr(type_, "parse_value", None),
parse_literal=getattr(type_, "parse_literal", None),
)
@staticmethod
def construct_enum(type_):
values = {}
for name, value in type_._meta.enum.__members__.items():
description = getattr(value, "description", None)
deprecation_reason = getattr(value, "deprecation_reason", None)
if not description and callable(type_._meta.description):
description = type_._meta.description(value)
if not deprecation_reason and callable(type_._meta.deprecation_reason):
deprecation_reason = type_._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue(
value=value.value,
description=description,
deprecation_reason=deprecation_reason,
)
type_description = (
type_._meta.description(None)
if callable(type_._meta.description)
else type_._meta.description
)
return GrapheneEnumType(
graphene_type=type_,
values=values,
name=type_._meta.name,
description=type_description,
)
def construct_objecttype(self, map_, type_):
if type_._meta.name in map_:
_type = map_[type_._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type_, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type_)
return _type
def interfaces():
interfaces = []
for interface in type_._meta.interfaces:
self.graphene_reducer(map_, interface)
internal_type = map_[interface._meta.name]
assert internal_type.graphene_type == interface
interfaces.append(internal_type)
return interfaces
if type_._meta.possible_types:
is_type_of = partial(
is_type_of_from_possible_types, type_._meta.possible_types
)
else:
is_type_of = type_.is_type_of
return GrapheneObjectType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
fields=partial(self.construct_fields_for_type, map_, type_),
is_type_of=is_type_of,
interfaces=interfaces,
)
def construct_interface(self, map_, type_):
if type_._meta.name in map_:
_type = map_[type_._meta.name]
if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type_, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type_)
return _type
_resolve_type = None
if type_.resolve_type:
_resolve_type = partial(
resolve_type, type_.resolve_type, map_, type_._meta.name
)
return GrapheneInterfaceType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
fields=partial(self.construct_fields_for_type, map_, type_),
resolve_type=_resolve_type,
)
def construct_inputobjecttype(self, map_, type_):
return GrapheneInputObjectType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
out_type=type_._meta.container,
fields=partial(
self.construct_fields_for_type, map_, type_, is_input_type=True
),
)
def construct_union(self, map_, type_):
_resolve_type = None
if type_.resolve_type:
_resolve_type = partial(
resolve_type, type_.resolve_type, map_, type_._meta.name
)
def types():
union_types = []
for objecttype in type_._meta.types:
self.graphene_reducer(map_, objecttype)
internal_type = map_[objecttype._meta.name]
assert internal_type.graphene_type == objecttype
union_types.append(internal_type)
return union_types
return GrapheneUnionType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
types=types,
resolve_type=_resolve_type,
)
def get_name(self, name):
if self.auto_camelcase:
return to_camel_case(name)
return name
def construct_fields_for_type(self, map_, type_, is_input_type=False):
fields = {}
for name, field in type_._meta.fields.items():
if isinstance(field, Dynamic):
field = get_field_as(field.get_type(self), _as=Field)
if not field:
continue
map_ = self.type_map_reducer(map_, field.type)
field_type = self.get_field_type(map_, field.type)
if is_input_type:
_field = GraphQLInputField(
field_type,
default_value=field.default_value,
out_name=name,
description=field.description,
)
else:
args = {}
for arg_name, arg in field.args.items():
map_ = self.type_map_reducer(map_, arg.type)
arg_type = self.get_field_type(map_, arg.type)
processed_arg_name = arg.name or self.get_name(arg_name)
args[processed_arg_name] = GraphQLArgument(
arg_type,
out_name=arg_name,
description=arg.description,
default_value=INVALID
if isinstance(arg.type, NonNull)
else arg.default_value,
)
_field = GraphQLField(
field_type,
args=args,
resolve=field.get_resolver(
self.get_resolver_for_type(type_, name, field.default_value)
),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
field_name = field.name or self.get_name(name)
fields[field_name] = _field
return fields
def get_resolver_for_type(self, type_, name, default_value):
if not issubclass(type_, ObjectType):
return
resolver = getattr(type_, "resolve_{}".format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type_._meta.interfaces:
if name not in interface._meta.fields:
continue
interface_resolver = getattr(interface, "resolve_{}".format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if it is not decorated with classmethod
if resolver:
return get_unbound_function(resolver)
default_resolver = type_._meta.default_resolver or get_default_resolver()
return partial(default_resolver, name, default_value)
def get_field_type(self, map_, type_):
if isinstance(type_, List):
return GraphQLList(self.get_field_type(map_, type_.of_type))
if isinstance(type_, NonNull):
return GraphQLNonNull(self.get_field_type(map_, type_.of_type))
return map_.get(type_._meta.name)
class Schema:
"""Schema Definition.
A Graphene Schema can execute operations (query, mutation, subscription) against the defined
types. For advanced purposes, the schema can be used to look up type definitions and answer
questions about the types through introspection.
Args:
query (ObjectType): Root query *ObjectType*. Describes entry point for fields to *read*
data in your Schema.
mutation (ObjectType, optional): Root mutation *ObjectType*. Describes entry point for
fields to *create, update or delete* data in your API.
subscription (ObjectType, optional): Root subscription *ObjectType*. Describes entry point
for fields to receive continuous updates.
directives (List[GraphQLDirective], optional): List of custom directives to include in the
GraphQL schema. Defaults to only include directives defined by GraphQL spec (@include
and @skip) [GraphQLIncludeDirective, GraphQLSkipDirective].
types (List[GraphQLType], optional): List of any types to include in schema that
may not be introspected through root types.
auto_camelcase (bool): Fieldnames will be transformed in Schema's TypeMap from snake_case
to camelCase (preferred by GraphQL standard). Default True.
"""
def __init__(
self,
query=None,
mutation=None,
subscription=None,
types=None,
directives=None,
auto_camelcase=True,
):
self.query = query
self.mutation = mutation
self.subscription = subscription
self.graphql_schema = GrapheneGraphQLSchema(
query,
mutation,
subscription,
types,
directives,
auto_camelcase=auto_camelcase,
)
def __str__(self):
return print_schema(self)
return print_schema(self.graphql_schema)
def __getattr__(self, type_name):
"""
This function lets the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
"""
_type = self.graphql_schema.get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
def lazy(self, _type):
return lambda: self.get_type(_type)
def build_typemap(self):
initial_types = [
self._query,
self._mutation,
self._subscription,
IntrospectionSchema,
]
if self.types:
initial_types += self.types
self._type_map = TypeMap(
initial_types, auto_camelcase=self.auto_camelcase, schema=self
)
def execute(self, *args, **kwargs):
"""Execute a GraphQL query on the schema.
Use the `graphql_sync` function from `graphql-core` to provide the result
for a query string. Most of the time this method will be called by one of the Graphene
:ref:`Integrations` via a web request.
Args:
request_string (str or Document): GraphQL request (query, mutation or subscription)
as string or parsed AST form from `graphql-core`.
root_value (Any, optional): Value to use as the parent value object when resolving
root types.
context_value (Any, optional): Value to be made available to all resolvers via
`info.context`. Can be used to share authorization, dataloaders or other
information needed to resolve an operation.
variable_values (dict, optional): If variables are used in the request string, they can
be provided in dictionary form mapping the variable name to the variable value.
operation_name (str, optional): If multiple operations are provided in the
request_string, an operation name must be provided for the result to be provided.
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
defined in `graphql-core`.
Returns:
:obj:`ExecutionResult` containing any data and errors for the operation.
"""
kwargs = normalize_execute_kwargs(kwargs)
return graphql_sync(self.graphql_schema, *args, **kwargs)
async def execute_async(self, *args, **kwargs):
"""Execute a GraphQL query on the schema asynchronously.
Same as `execute`, but uses `graphql` instead of `graphql_sync`.
"""
kwargs = normalize_execute_kwargs(kwargs)
return await graphql(self.graphql_schema, *args, **kwargs)
def introspect(self):
introspection = self.execute(introspection_query)
if introspection.errors:
raise introspection.errors[0]
return introspection.data
def normalize_execute_kwargs(kwargs):
"""Replace alias names in keyword arguments for graphql()"""
if "root" in kwargs and "root_value" not in kwargs:
kwargs["root_value"] = kwargs.pop("root")
if "context" in kwargs and "context_value" not in kwargs:
kwargs["context_value"] = kwargs.pop("context")
if "variables" in kwargs and "variable_values" not in kwargs:
kwargs["variable_values"] = kwargs.pop("variables")
if "operation" in kwargs and "operation_name" not in kwargs:
kwargs["operation_name"] = kwargs.pop("operation")
return kwargs
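To tie the new pieces together, a hypothetical end-to-end sketch (Query, greeting and subject are made-up names) showing the Schema wrapper and the legacy keyword aliases handled by normalize_execute_kwargs above:

from graphene import ObjectType, Schema, String


class Query(ObjectType):
    greeting = String(subject=String(default_value="world"))

    def resolve_greeting(root, info, subject):
        return "Hello " + subject


schema = Schema(query=Query)

# The old aliases still work: root -> root_value, context -> context_value,
# variables -> variable_values, operation -> operation_name.
result = schema.execute(
    "query Greet($subject: String) { greeting(subject: $subject) }",
    variables={"subject": "GraphQL"},
)
assert not result.errors
assert result.data == {"greeting": "Hello GraphQL"}

# The wrapped graphql-core schema is exposed directly, and an async variant
# is available via: await schema.execute_async("{ greeting }")
assert schema.graphql_schema.query_type.name == "Query"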


@ -1,4 +1,5 @@
from .. import abstracttype
from pytest import deprecated_call
from ..abstracttype import AbstractType
from ..field import Field
from ..objecttype import ObjectType
@ -14,24 +15,25 @@ class MyScalar(UnmountedType):
return MyType
def test_abstract_objecttype_warn_deprecation(mocker):
mocker.patch.object(abstracttype, "warn_deprecation")
def test_abstract_objecttype_warn_deprecation():
with deprecated_call():
class MyAbstractType(AbstractType):
field1 = MyScalar()
assert abstracttype.warn_deprecation.called
# noinspection PyUnusedLocal
class MyAbstractType(AbstractType):
field1 = MyScalar()
def test_generate_objecttype_inherit_abstracttype():
class MyAbstractType(AbstractType):
field1 = MyScalar()
with deprecated_call():
class MyObjectType(ObjectType, MyAbstractType):
field2 = MyScalar()
class MyAbstractType(AbstractType):
field1 = MyScalar()
class MyObjectType(ObjectType, MyAbstractType):
field2 = MyScalar()
assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
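For reference, a minimal hypothetical use of pytest's deprecated_call, which replaces the mocker.patch approach in the old version of this test:

import warnings

from pytest import deprecated_call


def old_api():
    warnings.warn("use new_api() instead", DeprecationWarning)
    return 42


def test_old_api_warns():
    # deprecated_call() passes when the block emits a DeprecationWarning
    # (or PendingDeprecationWarning).
    with deprecated_call():
        assert old_api() == 42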


@ -1,6 +1,6 @@
from functools import partial
import pytest
from pytest import raises
from ..argument import Argument, to_arguments
from ..field import Field
@ -43,7 +43,7 @@ def test_to_arguments():
def test_to_arguments_raises_if_field():
args = {"arg_string": Field(String)}
with pytest.raises(ValueError) as exc_info:
with raises(ValueError) as exc_info:
to_arguments(args)
assert str(exc_info.value) == (
@ -55,7 +55,7 @@ def test_to_arguments_raises_if_field():
def test_to_arguments_raises_if_inputfield():
args = {"arg_string": InputField(String)}
with pytest.raises(ValueError) as exc_info:
with raises(ValueError) as exc_info:
to_arguments(args)
assert str(exc_info.value) == (


@ -2,7 +2,8 @@ import datetime
import pytz
from graphql import GraphQLError
import pytest
from pytest import fixture, mark
from ..datetime import Date, DateTime, Time
from ..objecttype import ObjectType
@ -27,13 +28,13 @@ class Query(ObjectType):
schema = Schema(query=Query)
@pytest.fixture
@fixture
def sample_datetime():
utc_datetime = datetime.datetime(2019, 5, 25, 5, 30, 15, 10, pytz.utc)
return utc_datetime
@pytest.fixture
@fixture
def sample_time(sample_datetime):
time = datetime.time(
sample_datetime.hour,
@ -45,7 +46,7 @@ def sample_time(sample_datetime):
return time
@pytest.fixture
@fixture
def sample_date(sample_datetime):
date = sample_datetime.date()
return date
@ -76,12 +77,16 @@ def test_time_query(sample_time):
def test_bad_datetime_query():
not_a_date = "Some string that's not a date"
not_a_date = "Some string that's not a datetime"
result = schema.execute("""{ datetime(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError)
assert result.errors and len(result.errors) == 1
error = result.errors[0]
assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type DateTime, found "Some string that\'s not a datetime".'
)
assert result.data is None
@ -90,18 +95,24 @@ def test_bad_date_query():
result = schema.execute("""{ date(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError)
error = result.errors[0]
assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type Date, found "Some string that\'s not a date".'
)
assert result.data is None
def test_bad_time_query():
not_a_date = "Some string that's not a date"
not_a_date = "Some string that's not a time"
result = schema.execute("""{ time(at: "%s") }""" % not_a_date)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError)
error = result.errors[0]
assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type Time, found "Some string that\'s not a time".'
)
assert result.data is None
@ -163,7 +174,7 @@ def test_time_query_variable(sample_time):
assert result.data == {"time": isoformat}
@pytest.mark.xfail(
@mark.xfail(
reason="creating the error message fails when un-parsable object is not JSON serializable."
)
def test_bad_variables(sample_date, sample_datetime, sample_time):
@ -174,11 +185,11 @@ def test_bad_variables(sample_date, sample_datetime, sample_time):
),
variables={"input": input_},
)
assert len(result.errors) == 1
# when `input` is not JSON serializable, formatting the error message in
# `graphql.utils.is_valid_value` line 79 fails with a TypeError
assert isinstance(result.errors, list)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError)
print(result.errors[0])
assert result.data is None
not_a_date = dict()


@ -69,7 +69,8 @@ class MyInputObjectType(InputObjectType):
def test_defines_a_query_only_schema():
blog_schema = Schema(Query)
assert blog_schema.get_query_type().graphene_type == Query
assert blog_schema.query == Query
assert blog_schema.graphql_schema.query_type.graphene_type == Query
article_field = Query._meta.fields["article"]
assert article_field.type == Article
@ -95,7 +96,8 @@ def test_defines_a_query_only_schema():
def test_defines_a_mutation_schema():
blog_schema = Schema(Query, mutation=Mutation)
assert blog_schema.get_mutation_type().graphene_type == Mutation
assert blog_schema.mutation == Mutation
assert blog_schema.graphql_schema.mutation_type.graphene_type == Mutation
write_mutation = Mutation._meta.fields["write_article"]
assert write_mutation.type == Article
@ -105,7 +107,8 @@ def test_defines_a_mutation_schema():
def test_defines_a_subscription_schema():
blog_schema = Schema(Query, subscription=Subscription)
assert blog_schema.get_subscription_type().graphene_type == Subscription
assert blog_schema.subscription == Subscription
assert blog_schema.graphql_schema.subscription_type.graphene_type == Subscription
subscription = Subscription._meta.fields["article_subscribe"]
assert subscription.type == Article
@ -126,8 +129,9 @@ def test_includes_nested_input_objects_in_the_map():
subscribe_to_something = Field(Article, input=Argument(SomeInputObject))
schema = Schema(query=Query, mutation=SomeMutation, subscription=SomeSubscription)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["NestedInputObject"].graphene_type is NestedInputObject
assert type_map["NestedInputObject"].graphene_type is NestedInputObject
def test_includes_interfaces_thunk_subtypes_in_the_type_map():
@ -142,8 +146,9 @@ def test_includes_interfaces_thunk_subtypes_in_the_type_map():
iface = Field(lambda: SomeInterface)
schema = Schema(query=Query, types=[SomeSubtype])
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype
assert type_map["SomeSubtype"].graphene_type is SomeSubtype
def test_includes_types_in_union():
@ -161,9 +166,10 @@ def test_includes_types_in_union():
union = Field(MyUnion)
schema = Schema(query=Query)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["OtherType"].graphene_type is OtherType
assert schema.get_type_map()["SomeType"].graphene_type is SomeType
assert type_map["OtherType"].graphene_type is OtherType
assert type_map["SomeType"].graphene_type is SomeType
def test_maps_enum():
@ -181,9 +187,10 @@ def test_maps_enum():
union = Field(MyUnion)
schema = Schema(query=Query)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["OtherType"].graphene_type is OtherType
assert schema.get_type_map()["SomeType"].graphene_type is SomeType
assert type_map["OtherType"].graphene_type is OtherType
assert type_map["SomeType"].graphene_type is SomeType
def test_includes_interfaces_subtypes_in_the_type_map():
@ -198,8 +205,9 @@ def test_includes_interfaces_subtypes_in_the_type_map():
iface = Field(SomeInterface)
schema = Schema(query=Query, types=[SomeSubtype])
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype
assert type_map["SomeSubtype"].graphene_type is SomeSubtype
def test_stringifies_simple_types():
@ -281,7 +289,7 @@ def test_stringifies_simple_types():
def test_does_not_mutate_passed_field_definitions():
class CommonFields(object):
class CommonFields:
field1 = String()
field2 = String(id=String())
@ -293,7 +301,7 @@ def test_does_not_mutate_passed_field_definitions():
assert TestObject1._meta.fields == TestObject2._meta.fields
class CommonFields(object):
class CommonFields:
field1 = String()
field2 = String()


@ -80,36 +80,19 @@ def test_enum_from_builtin_enum_accepts_lambda_description():
class Query(ObjectType):
foo = Episode()
schema = Schema(query=Query)
schema = Schema(query=Query).graphql_schema
GraphQLPyEpisode = schema._type_map["PyEpisode"].values
episode = schema.get_type("PyEpisode")
assert schema._type_map["PyEpisode"].description == "StarWars Episodes"
assert (
GraphQLPyEpisode[0].name == "NEWHOPE"
and GraphQLPyEpisode[0].description == "New Hope Episode"
)
assert (
GraphQLPyEpisode[1].name == "EMPIRE"
and GraphQLPyEpisode[1].description == "Other"
)
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].description == "Other"
)
assert (
GraphQLPyEpisode[0].name == "NEWHOPE"
and GraphQLPyEpisode[0].deprecation_reason == "meh"
)
assert (
GraphQLPyEpisode[1].name == "EMPIRE"
and GraphQLPyEpisode[1].deprecation_reason is None
)
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].deprecation_reason is None
)
assert episode.description == "StarWars Episodes"
assert [
(name, value.description, value.deprecation_reason)
for name, value in episode.values.items()
] == [
("NEWHOPE", "New Hope Episode", "meh"),
("EMPIRE", "Other", None),
("JEDI", "Other", None),
]
def test_enum_from_python3_enum_uses_enum_doc():


@ -1,6 +1,6 @@
from functools import partial
import pytest
from pytest import raises
from ..argument import Argument
from ..field import Field
@ -9,7 +9,7 @@ from ..structures import NonNull
from .utils import MyLazyType
class MyInstance(object):
class MyInstance:
value = "value"
value_func = staticmethod(lambda: "value_func")
@ -85,7 +85,7 @@ def test_field_with_string_type():
def test_field_not_source_and_resolver():
MyType = object()
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
Field(MyType, source="value", resolver=lambda: None)
assert (
str(exc_info.value)
@ -122,7 +122,7 @@ def test_field_name_as_argument():
def test_field_source_argument_as_kw():
MyType = object()
field = Field(MyType, b=NonNull(True), c=Argument(None), a=NonNull(False))
assert list(field.args.keys()) == ["b", "c", "a"]
assert list(field.args) == ["b", "c", "a"]
assert isinstance(field.args["b"], Argument)
assert isinstance(field.args["b"].type, NonNull)
assert field.args["b"].type.of_type is True


@ -8,7 +8,7 @@ from ..schema import Schema
from ..unmountedtype import UnmountedType
class MyType(object):
class MyType:
pass
@ -50,7 +50,7 @@ def test_ordered_fields_in_inputobjecttype():
field = MyScalar()
asa = InputField(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"]
assert list(MyInputObjectType._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_inputobjecttype_unmountedtype():
@ -78,13 +78,13 @@ def test_generate_inputobjecttype_as_argument():
def test_generate_inputobjecttype_inherit_abstracttype():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar(MyType)
class MyInputObjectType(InputObjectType, MyAbstractType):
field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(MyInputObjectType._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField,
InputField,
@ -92,13 +92,13 @@ def test_generate_inputobjecttype_inherit_abstracttype():
def test_generate_inputobjecttype_inherit_abstracttype_reversed():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar(MyType)
class MyInputObjectType(MyAbstractType, InputObjectType):
field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(MyInputObjectType._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField,
InputField,
@ -133,5 +133,6 @@ def test_inputobjecttype_of_input():
}
"""
)
assert not result.errors
assert result.data == {"isChild": True}


@ -3,7 +3,7 @@ from ..interface import Interface
from ..unmountedtype import UnmountedType
class MyType(object):
class MyType:
pass
@ -45,7 +45,7 @@ def test_ordered_fields_in_interface():
field = MyScalar()
asa = Field(MyType)
assert list(MyInterface._meta.fields.keys()) == ["b", "a", "field", "asa"]
assert list(MyInterface._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_interface_unmountedtype():
@ -57,13 +57,13 @@ def test_generate_interface_unmountedtype():
def test_generate_interface_inherit_abstracttype():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar()
class MyInterface(Interface, MyAbstractType):
field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
@ -75,16 +75,16 @@ def test_generate_interface_inherit_interface():
field2 = MyScalar()
assert MyInterface._meta.name == "MyInterface"
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_abstracttype_reversed():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar()
class MyInterface(MyAbstractType, Interface):
field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"]
assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]


@ -1,4 +1,4 @@
import pytest
from pytest import raises
from ..argument import Argument
from ..dynamic import Dynamic
@ -46,7 +46,7 @@ def test_generate_mutation_with_meta():
def test_mutation_raises_exception_if_no_mutate():
with pytest.raises(AssertionError) as excinfo:
with raises(AssertionError) as excinfo:
class MyMutation(Mutation):
pass


@ -1,4 +1,4 @@
import pytest
from pytest import raises
from ..field import Field
from ..interface import Interface
@ -91,7 +91,7 @@ def test_generate_objecttype_with_private_attributes():
m = MyObjectType(_private_state="custom")
assert m._private_state == "custom"
with pytest.raises(TypeError):
with raises(TypeError):
MyObjectType(_other_private_state="Wrong")
@ -102,11 +102,11 @@ def test_ordered_fields_in_objecttype():
field = MyScalar()
asa = Field(MyType)
assert list(MyObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"]
assert list(MyObjectType._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_objecttype_inherit_abstracttype():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar()
class MyObjectType(ObjectType, MyAbstractType):
@ -115,12 +115,12 @@ def test_generate_objecttype_inherit_abstracttype():
assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
def test_generate_objecttype_inherit_abstracttype_reversed():
class MyAbstractType(object):
class MyAbstractType:
field1 = MyScalar()
class MyObjectType(MyAbstractType, ObjectType):
@ -129,7 +129,7 @@ def test_generate_objecttype_inherit_abstracttype_reversed():
assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"]
assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
@ -142,15 +142,11 @@ def test_generate_objecttype_unmountedtype():
def test_parent_container_get_fields():
assert list(Container._meta.fields.keys()) == ["field1", "field2"]
assert list(Container._meta.fields) == ["field1", "field2"]
def test_parent_container_interface_get_fields():
assert list(ContainerWithInterface._meta.fields.keys()) == [
"ifield",
"field1",
"field2",
]
assert list(ContainerWithInterface._meta.fields) == ["ifield", "field1", "field2"]
def test_objecttype_as_container_only_args():
@ -177,14 +173,14 @@ def test_objecttype_as_container_all_kwargs():
def test_objecttype_as_container_extra_args():
with pytest.raises(IndexError) as excinfo:
with raises(IndexError) as excinfo:
Container("1", "2", "3")
assert "Number of args exceeds number of fields" == str(excinfo.value)
def test_objecttype_as_container_invalid_kwargs():
with pytest.raises(TypeError) as excinfo:
with raises(TypeError) as excinfo:
Container(unexisting_field="3")
assert "'unexisting_field' is an invalid keyword argument for Container" == str(
@ -218,7 +214,7 @@ def test_objecttype_with_possible_types():
def test_objecttype_with_possible_types_and_is_type_of_should_raise():
with pytest.raises(AssertionError) as excinfo:
with raises(AssertionError) as excinfo:
class MyObjectType(ObjectType):
class Meta:


@ -1,7 +1,13 @@
import json
from functools import partial
from graphql import GraphQLError, ResolveInfo, Source, execute, parse
from graphql import (
GraphQLError,
GraphQLResolveInfo as ResolveInfo,
Source,
execute,
parse,
)
from ..context import Context
from ..dynamic import Dynamic
@ -28,7 +34,7 @@ def test_query():
def test_query_source():
class Root(object):
class Root:
_hello = "World"
def hello(self):
@ -45,10 +51,10 @@ def test_query_source():
def test_query_union():
class one_object(object):
class one_object:
pass
class two_object(object):
class two_object:
pass
class One(ObjectType):
@ -83,10 +89,10 @@ def test_query_union():
def test_query_interface():
class one_object(object):
class one_object:
pass
class two_object(object):
class two_object:
pass
class MyInterface(Interface):
@ -175,7 +181,7 @@ def test_query_wrong_default_value():
assert len(executed.errors) == 1
assert (
executed.errors[0].message
== GraphQLError('Expected value of type "MyType" but got: str.').message
== GraphQLError("Expected value of type 'MyType' but got: 'hello'.").message
)
assert executed.data == {"hello": None}
@ -223,11 +229,11 @@ def test_query_arguments():
result = test_schema.execute("{ test }", None)
assert not result.errors
assert result.data == {"test": "[null,{}]"}
assert result.data == {"test": '[null,{"a_str":null,"a_int":null}]'}
result = test_schema.execute('{ test(aStr: "String!") }', "Source!")
assert not result.errors
assert result.data == {"test": '["Source!",{"a_str":"String!"}]'}
assert result.data == {"test": '["Source!",{"a_str":"String!","a_int":null}]'}
result = test_schema.execute('{ test(aInt: -123, aStr: "String!") }', "Source!")
assert not result.errors
@ -252,18 +258,21 @@ def test_query_input_field():
result = test_schema.execute("{ test }", None)
assert not result.errors
assert result.data == {"test": "[null,{}]"}
assert result.data == {"test": '[null,{"a_input":null}]'}
result = test_schema.execute('{ test(aInput: {aField: "String!"} ) }', "Source!")
assert not result.errors
assert result.data == {"test": '["Source!",{"a_input":{"a_field":"String!"}}]'}
assert result.data == {
"test": '["Source!",{"a_input":{"a_field":"String!","recursive_field":null}}]'
}
result = test_schema.execute(
'{ test(aInput: {recursiveField: {aField: "String!"}}) }', "Source!"
)
assert not result.errors
assert result.data == {
"test": '["Source!",{"a_input":{"recursive_field":{"a_field":"String!"}}}]'
"test": '["Source!",{"a_input":{"a_field":null,"recursive_field":'
'{"a_field":"String!","recursive_field":null}}}]'
}
@ -279,8 +288,7 @@ def test_query_middlewares():
return "other"
def reversed_middleware(next, *args, **kwargs):
p = next(*args, **kwargs)
return p.then(lambda x: x[::-1])
return next(*args, **kwargs)[::-1]
hello_schema = Schema(Query)
@ -342,10 +350,11 @@ def test_big_list_query_compiled_query_benchmark(benchmark):
return big_list
hello_schema = Schema(Query)
graphql_schema = hello_schema.graphql_schema
source = Source("{ allInts }")
query_ast = parse(source)
big_list_query = partial(execute, hello_schema, query_ast)
big_list_query = partial(execute, graphql_schema, query_ast)
result = benchmark(big_list_query)
assert not result.errors
assert result.data == {"allInts": list(big_list)}


@ -13,7 +13,7 @@ info = None
demo_dict = {"attr": "value"}
class demo_obj(object):
class demo_obj:
attr = "value"


@ -1,4 +1,6 @@
import pytest
from pytest import raises
from graphql.pyutils import dedent
from ..field import Field
from ..objecttype import ObjectType
@ -15,8 +17,8 @@ class Query(ObjectType):
def test_schema():
schema = Schema(Query)
assert schema.get_query_type() == schema.get_graphql_type(Query)
schema = Schema(Query).graphql_schema
assert schema.query_type == schema.get_graphql_type(Query)
def test_schema_get_type():
@ -27,7 +29,7 @@ def test_schema_get_type():
def test_schema_get_type_error():
schema = Schema(Query)
with pytest.raises(AttributeError) as exc_info:
with raises(AttributeError) as exc_info:
schema.X
assert str(exc_info.value) == 'Type "X" not found in the Schema'
@ -35,20 +37,16 @@ def test_schema_get_type_error():
def test_schema_str():
schema = Schema(Query)
assert (
str(schema)
== """schema {
query: Query
}
assert str(schema) == dedent(
"""
type MyOtherType {
field: String
}
type MyOtherType {
field: String
}
type Query {
inner: MyOtherType
}
"""
type Query {
inner: MyOtherType
}
"""
)


@ -1,6 +1,6 @@
from functools import partial
import pytest
from pytest import raises
from ..scalars import String
from ..structures import List, NonNull
@ -14,7 +14,7 @@ def test_list():
def test_list_with_unmounted_type():
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
List(String())
assert (
@ -82,7 +82,7 @@ def test_nonnull_inherited_works_list():
def test_nonnull_inherited_dont_work_nonnull():
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
NonNull(NonNull(String))
assert (
@ -92,7 +92,7 @@ def test_nonnull_inherited_dont_work_nonnull():
def test_nonnull_with_unmounted_type():
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
NonNull(String())
assert (


@ -1,10 +1,11 @@
import pytest
from pytest import raises
from graphql.type import (
GraphQLArgument,
GraphQLEnumType,
GraphQLEnumValue,
GraphQLField,
GraphQLInputObjectField,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
@ -20,7 +21,13 @@ from ..interface import Interface
from ..objecttype import ObjectType
from ..scalars import Int, String
from ..structures import List, NonNull
from ..typemap import TypeMap, resolve_type
from ..schema import GrapheneGraphQLSchema, resolve_type
def create_type_map(types, auto_camelcase=True):
query = GraphQLObjectType("Query", {})
schema = GrapheneGraphQLSchema(query, types=types, auto_camelcase=auto_camelcase)
return schema.type_map
def test_enum():
@ -39,22 +46,18 @@ def test_enum():
if self == MyEnum.foo:
return "Is deprecated"
typemap = TypeMap([MyEnum])
assert "MyEnum" in typemap
graphql_enum = typemap["MyEnum"]
type_map = create_type_map([MyEnum])
assert "MyEnum" in type_map
graphql_enum = type_map["MyEnum"]
assert isinstance(graphql_enum, GraphQLEnumType)
assert graphql_enum.name == "MyEnum"
assert graphql_enum.description == "Description"
values = graphql_enum.values
assert values == [
GraphQLEnumValue(
name="foo",
value=1,
description="Description foo=1",
deprecation_reason="Is deprecated",
assert graphql_enum.values == {
"foo": GraphQLEnumValue(
value=1, description="Description foo=1", deprecation_reason="Is deprecated"
),
GraphQLEnumValue(name="bar", value=2, description="Description bar=2"),
]
"bar": GraphQLEnumValue(value=2, description="Description bar=2"),
}
def test_objecttype():
@ -70,15 +73,15 @@ def test_objecttype():
def resolve_foo(self, bar):
return bar
typemap = TypeMap([MyObjectType])
assert "MyObjectType" in typemap
graphql_type = typemap["MyObjectType"]
type_map = create_type_map([MyObjectType])
assert "MyObjectType" in type_map
graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description"
fields = graphql_type.fields
assert list(fields.keys()) == ["foo", "gizmo"]
assert list(fields) == ["foo", "gizmo"]
foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField)
assert foo_field.description == "Field description"
@ -100,13 +103,13 @@ def test_dynamic_objecttype():
bar = Dynamic(lambda: Field(String))
own = Field(lambda: MyObjectType)
typemap = TypeMap([MyObjectType])
assert "MyObjectType" in typemap
assert list(MyObjectType._meta.fields.keys()) == ["bar", "own"]
graphql_type = typemap["MyObjectType"]
type_map = create_type_map([MyObjectType])
assert "MyObjectType" in type_map
assert list(MyObjectType._meta.fields) == ["bar", "own"]
graphql_type = type_map["MyObjectType"]
fields = graphql_type.fields
assert list(fields.keys()) == ["bar", "own"]
assert list(fields) == ["bar", "own"]
assert fields["bar"].type == GraphQLString
assert fields["own"].type == graphql_type
@ -125,21 +128,21 @@ def test_interface():
def resolve_foo(self, args, info):
return args.get("bar")
typemap = TypeMap([MyInterface])
assert "MyInterface" in typemap
graphql_type = typemap["MyInterface"]
type_map = create_type_map([MyInterface])
assert "MyInterface" in type_map
graphql_type = type_map["MyInterface"]
assert isinstance(graphql_type, GraphQLInterfaceType)
assert graphql_type.name == "MyInterface"
assert graphql_type.description == "Description"
fields = graphql_type.fields
assert list(fields.keys()) == ["foo", "gizmo", "own"]
assert list(fields) == ["foo", "gizmo", "own"]
assert fields["own"].type == graphql_type
assert list(fields["gizmo"].args.keys()) == ["firstArg", "oth_arg"]
assert list(fields["gizmo"].args) == ["firstArg", "oth_arg"]
foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField)
assert foo_field.description == "Field description"
assert not foo_field.resolver # Resolver not attached in interfaces
assert not foo_field.resolve # Resolver not attached in interfaces
assert foo_field.args == {
"bar": GraphQLArgument(
GraphQLString,
@ -169,23 +172,23 @@ def test_inputobject():
def resolve_foo_bar(self, args, info):
return args.get("bar")
typemap = TypeMap([MyInputObjectType])
assert "MyInputObjectType" in typemap
graphql_type = typemap["MyInputObjectType"]
type_map = create_type_map([MyInputObjectType])
assert "MyInputObjectType" in type_map
graphql_type = type_map["MyInputObjectType"]
assert isinstance(graphql_type, GraphQLInputObjectType)
assert graphql_type.name == "MyInputObjectType"
assert graphql_type.description == "Description"
other_graphql_type = typemap["OtherObjectType"]
inner_graphql_type = typemap["MyInnerObjectType"]
container = graphql_type.create_container(
other_graphql_type = type_map["OtherObjectType"]
inner_graphql_type = type_map["MyInnerObjectType"]
container = graphql_type.out_type(
{
"bar": "oh!",
"baz": inner_graphql_type.create_container(
"baz": inner_graphql_type.out_type(
{
"some_other_field": [
other_graphql_type.create_container({"thingy": 1}),
other_graphql_type.create_container({"thingy": 2}),
other_graphql_type.out_type({"thingy": 1}),
other_graphql_type.out_type({"thingy": 2}),
]
}
),
@ -201,11 +204,11 @@ def test_inputobject():
assert container.baz.some_other_field[1].thingy == 2
fields = graphql_type.fields
assert list(fields.keys()) == ["fooBar", "gizmo", "baz", "own"]
assert list(fields) == ["fooBar", "gizmo", "baz", "own"]
own_field = fields["own"]
assert own_field.type == graphql_type
foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLInputObjectField)
assert isinstance(foo_field, GraphQLInputField)
assert foo_field.description == "Field description"
@ -215,19 +218,19 @@ def test_objecttype_camelcase():
foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType])
assert "MyObjectType" in typemap
graphql_type = typemap["MyObjectType"]
type_map = create_type_map([MyObjectType])
assert "MyObjectType" in type_map
graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description"
fields = graphql_type.fields
assert list(fields.keys()) == ["fooBar"]
assert list(fields) == ["fooBar"]
foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLField)
assert foo_field.args == {
"barFoo": GraphQLArgument(GraphQLString, out_name="bar_foo")
"barFoo": GraphQLArgument(GraphQLString, default_value=None, out_name="bar_foo")
}
@ -237,19 +240,21 @@ def test_objecttype_camelcase_disabled():
foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType], auto_camelcase=False)
assert "MyObjectType" in typemap
graphql_type = typemap["MyObjectType"]
type_map = create_type_map([MyObjectType], auto_camelcase=False)
assert "MyObjectType" in type_map
graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description"
fields = graphql_type.fields
assert list(fields.keys()) == ["foo_bar"]
assert list(fields) == ["foo_bar"]
foo_field = fields["foo_bar"]
assert isinstance(foo_field, GraphQLField)
assert foo_field.args == {
"bar_foo": GraphQLArgument(GraphQLString, out_name="bar_foo")
"bar_foo": GraphQLArgument(
GraphQLString, default_value=None, out_name="bar_foo"
)
}
@ -262,8 +267,8 @@ def test_objecttype_with_possible_types():
foo_bar = String()
typemap = TypeMap([MyObjectType])
graphql_type = typemap["MyObjectType"]
type_map = create_type_map([MyObjectType])
graphql_type = type_map["MyObjectType"]
assert graphql_type.is_type_of
assert graphql_type.is_type_of({}, None) is True
assert graphql_type.is_type_of(MyObjectType(), None) is False
@ -279,8 +284,8 @@ def test_resolve_type_with_missing_type():
def resolve_type_func(root, info):
return MyOtherObjectType
typemap = TypeMap([MyObjectType])
with pytest.raises(AssertionError) as excinfo:
resolve_type(resolve_type_func, typemap, "MyOtherObjectType", {}, {})
type_map = create_type_map([MyObjectType])
with raises(AssertionError) as excinfo:
resolve_type(resolve_type_func, type_map, "MyOtherObjectType", {}, {}, None)
assert "MyOtherObjectTyp" in str(excinfo.value)


@ -1,4 +1,4 @@
import pytest
from pytest import raises
from ..field import Field
from ..objecttype import ObjectType
@ -38,7 +38,7 @@ def test_generate_union_with_meta():
def test_generate_union_with_no_types():
with pytest.raises(Exception) as exc_info:
with raises(Exception) as exc_info:
class MyUnion(Union):
pass


@ -1,337 +0,0 @@
import inspect
from collections import OrderedDict
from functools import partial
from graphql import (
GraphQLArgument,
GraphQLBoolean,
GraphQLField,
GraphQLFloat,
GraphQLID,
GraphQLInputObjectField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLString,
)
from graphql.execution.executor import get_default_resolve_type_fn
from graphql.type import GraphQLEnumValue
from graphql.type.typemap import GraphQLTypeMap
from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case
from .definitions import (
GrapheneEnumType,
GrapheneGraphQLType,
GrapheneInputObjectType,
GrapheneInterfaceType,
GrapheneObjectType,
GrapheneScalarType,
GrapheneUnionType,
)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
from .inputobjecttype import InputObjectType
from .interface import Interface
from .objecttype import ObjectType
from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull
from .union import Union
from .utils import get_field_as
def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)):
return True
if inspect.isclass(_type) and issubclass(
_type, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
):
return True
def resolve_type(resolve_type_func, map, type_name, root, info):
_type = resolve_type_func(root, info)
if not _type:
return_type = map[type_name]
return get_default_resolve_type_fn(root, info, return_type)
if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name)
assert graphql_type, "Can't find type {} in schema".format(_type._meta.name)
assert graphql_type.graphene_type == _type, (
"The type {} does not match with the associated graphene type {}."
).format(_type, graphql_type.graphene_type)
return graphql_type
return _type
def is_type_of_from_possible_types(possible_types, root, info):
return isinstance(root, possible_types)
class TypeMap(GraphQLTypeMap):
def __init__(self, types, auto_camelcase=True, schema=None):
self.auto_camelcase = auto_camelcase
self.schema = schema
super(TypeMap, self).__init__(types)
def reducer(self, map, type):
if not type:
return map
if inspect.isfunction(type):
type = type()
if is_graphene_type(type):
return self.graphene_reducer(map, type)
return GraphQLTypeMap.reducer(map, type)
def graphene_reducer(self, map, type):
if isinstance(type, (List, NonNull)):
return self.reducer(map, type.of_type)
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return map
if issubclass(type, ObjectType):
internal_type = self.construct_objecttype(map, type)
elif issubclass(type, InputObjectType):
internal_type = self.construct_inputobjecttype(map, type)
elif issubclass(type, Interface):
internal_type = self.construct_interface(map, type)
elif issubclass(type, Scalar):
internal_type = self.construct_scalar(map, type)
elif issubclass(type, Enum):
internal_type = self.construct_enum(map, type)
elif issubclass(type, Union):
internal_type = self.construct_union(map, type)
else:
raise Exception("Expected Graphene type, but received: {}.".format(type))
return GraphQLTypeMap.reducer(map, internal_type)
def construct_scalar(self, map, type):
# We have a mapping to the original GraphQL types
# so there are no collisions.
_scalars = {
String: GraphQLString,
Int: GraphQLInt,
Float: GraphQLFloat,
Boolean: GraphQLBoolean,
ID: GraphQLID,
}
if type in _scalars:
return _scalars[type]
return GrapheneScalarType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
serialize=getattr(type, "serialize", None),
parse_value=getattr(type, "parse_value", None),
parse_literal=getattr(type, "parse_literal", None),
)
def construct_enum(self, map, type):
values = OrderedDict()
for name, value in type._meta.enum.__members__.items():
description = getattr(value, "description", None)
deprecation_reason = getattr(value, "deprecation_reason", None)
if not description and callable(type._meta.description):
description = type._meta.description(value)
if not deprecation_reason and callable(type._meta.deprecation_reason):
deprecation_reason = type._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue(
name=name,
value=value.value,
description=description,
deprecation_reason=deprecation_reason,
)
type_description = (
type._meta.description(None)
if callable(type._meta.description)
else type._meta.description
)
return GrapheneEnumType(
graphene_type=type,
values=values,
name=type._meta.name,
description=type_description,
)
def construct_objecttype(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return _type
def interfaces():
interfaces = []
for interface in type._meta.interfaces:
self.graphene_reducer(map, interface)
internal_type = map[interface._meta.name]
assert internal_type.graphene_type == interface
interfaces.append(internal_type)
return interfaces
if type._meta.possible_types:
is_type_of = partial(
is_type_of_from_possible_types, type._meta.possible_types
)
else:
is_type_of = type.is_type_of
return GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type),
is_type_of=is_type_of,
interfaces=interfaces,
)
def construct_interface(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return _type
_resolve_type = None
if type.resolve_type:
_resolve_type = partial(
resolve_type, type.resolve_type, map, type._meta.name
)
return GrapheneInterfaceType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type),
resolve_type=_resolve_type,
)
def construct_inputobjecttype(self, map, type):
return GrapheneInputObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
container_type=type._meta.container,
fields=partial(
self.construct_fields_for_type, map, type, is_input_type=True
),
)
def construct_union(self, map, type):
_resolve_type = None
if type.resolve_type:
_resolve_type = partial(
resolve_type, type.resolve_type, map, type._meta.name
)
def types():
union_types = []
for objecttype in type._meta.types:
self.graphene_reducer(map, objecttype)
internal_type = map[objecttype._meta.name]
assert internal_type.graphene_type == objecttype
union_types.append(internal_type)
return union_types
return GrapheneUnionType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
types=types,
resolve_type=_resolve_type,
)
def get_name(self, name):
if self.auto_camelcase:
return to_camel_case(name)
return name
def construct_fields_for_type(self, map, type, is_input_type=False):
fields = OrderedDict()
for name, field in type._meta.fields.items():
if isinstance(field, Dynamic):
field = get_field_as(field.get_type(self.schema), _as=Field)
if not field:
continue
map = self.reducer(map, field.type)
field_type = self.get_field_type(map, field.type)
if is_input_type:
_field = GraphQLInputObjectField(
field_type,
default_value=field.default_value,
out_name=name,
description=field.description,
)
else:
args = OrderedDict()
for arg_name, arg in field.args.items():
map = self.reducer(map, arg.type)
arg_type = self.get_field_type(map, arg.type)
processed_arg_name = arg.name or self.get_name(arg_name)
args[processed_arg_name] = GraphQLArgument(
arg_type,
out_name=arg_name,
description=arg.description,
default_value=arg.default_value,
)
_field = GraphQLField(
field_type,
args=args,
resolver=field.get_resolver(
self.get_resolver_for_type(type, name, field.default_value)
),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
field_name = field.name or self.get_name(name)
fields[field_name] = _field
return fields
def get_resolver_for_type(self, type, name, default_value):
if not issubclass(type, ObjectType):
return
resolver = getattr(type, "resolve_{}".format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type._meta.interfaces:
if name not in interface._meta.fields:
continue
interface_resolver = getattr(interface, "resolve_{}".format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if it is not decorated with classmethod
if resolver:
return get_unbound_function(resolver)
default_resolver = type._meta.default_resolver or get_default_resolver()
return partial(default_resolver, name, default_value)
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
if isinstance(type, NonNull):
return GraphQLNonNull(self.get_field_type(map, type.of_type))
return map.get(type._meta.name)
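The hand-rolled TypeMap above can be deleted because graphql-core 3 builds the type map itself: constructing a GraphQLSchema walks the root types and collects every referenced type. A minimal sketch of that behaviour using plain graphql-core (not code from this repository), shown only for illustration:

from graphql import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString

QueryType = GraphQLObjectType(
    "Query",
    {"hello": GraphQLField(GraphQLString, resolve=lambda root, info: "world")},
)
schema = GraphQLSchema(query=QueryType)

# graphql-core 3 collects referenced types on construction,
# so no GraphQLTypeMap subclass or reducer is needed any more.
assert "Query" in schema.type_map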

View File

@ -1,5 +1,4 @@
import inspect
from collections import OrderedDict
from functools import partial
from ..utils.module_loading import import_string
@ -33,7 +32,7 @@ def yank_fields_from_attrs(attrs, _as=None, sort=True):
if sort:
fields_with_names = sorted(fields_with_names, key=lambda f: f[1])
return OrderedDict(fields_with_names)
return dict(fields_with_names)
def get_type(_type):

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import
from uuid import UUID as _UUID
from graphql.language import ast
from graphql.language.ast import StringValueNode
from .scalars import Scalar
@ -24,7 +24,7 @@ class UUID(Scalar):
@staticmethod
def parse_literal(node):
if isinstance(node, ast.StringValue):
if isinstance(node, StringValueNode):
return _UUID(node.value)
@staticmethod

View File

@ -1,5 +1,5 @@
import json
from collections import Mapping
from collections.abc import Mapping
def to_key(value):

View File

@ -1,4 +1,4 @@
from collections import Mapping, OrderedDict
from collections.abc import Mapping
def deflate(node, index=None, path=None):
@ -16,10 +16,9 @@ def deflate(node, index=None, path=None):
else:
index[cache_key] = True
field_names = node.keys()
result = OrderedDict()
result = {}
for field_name in field_names:
for field_name in node:
value = node[field_name]
new_path = path + [field_name]

View File

@ -2,7 +2,7 @@ from functools import total_ordering
@total_ordering
class OrderedType(object):
class OrderedType:
creation_counter = 1
def __init__(self, _creation_counter=None):

View File

@ -2,7 +2,7 @@ class _OldClass:
pass
class _NewClass(object):
class _NewClass:
pass

View File

@ -43,7 +43,7 @@ class SubclassWithMeta(metaclass=SubclassWithMeta_Meta):
assert not options, (
"Abstract types can only contain the abstract attribute. "
"Received: abstract, {option_keys}"
).format(option_keys=", ".join(options.keys()))
).format(option_keys=", ".join(options))
else:
super_class = super(cls, cls)
if hasattr(super_class, "__init_subclass_with_meta__"):

View File

@ -1,10 +1,9 @@
import pytest
from collections import OrderedDict
from pytest import mark
from ..crunch import crunch
@pytest.mark.parametrize(
@mark.parametrize(
"description,uncrunched,crunched",
[
["number primitive", 0, [0]],
@ -28,28 +27,22 @@ from ..crunch import crunch
["single-item object", {"a": None}, [None, {"a": 0}]],
[
"multi-item all distinct object",
OrderedDict([("a", None), ("b", 0), ("c", True), ("d", "string")]),
{"a": None, "b": 0, "c": True, "d": "string"},
[None, 0, True, "string", {"a": 0, "b": 1, "c": 2, "d": 3}],
],
[
"multi-item repeated object",
OrderedDict([("a", True), ("b", True), ("c", True), ("d", True)]),
{"a": True, "b": True, "c": True, "d": True},
[True, {"a": 0, "b": 0, "c": 0, "d": 0}],
],
[
"complex array",
[OrderedDict([("a", True), ("b", [1, 2, 3])]), [1, 2, 3]],
[{"a": True, "b": [1, 2, 3]}, [1, 2, 3]],
[True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, [5, 4]],
],
[
"complex object",
OrderedDict(
[
("a", True),
("b", [1, 2, 3]),
("c", OrderedDict([("a", True), ("b", [1, 2, 3])])),
]
),
{"a": True, "b": [1, 2, 3], "c": {"a": True, "b": [1, 2, 3]}},
[True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, {"a": 0, "b": 4, "c": 5}],
],
],
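Read each case as crunch(uncrunched) == crunched: values are stored once in a flat list and objects become maps from key to index into that list. A quick usage sketch derived from the cases above, assuming the module lives at graphene.utils.crunch as the relative import suggests:

from graphene.utils.crunch import crunch

# Every value of the object is the same, so it is stored once (index 0)
assert crunch({"a": True, "b": True}) == [True, {"a": 0, "b": 0}]

# Distinct values each get their own slot; the object maps keys to indices
assert crunch({"a": None, "b": 0}) == [None, 0, {"a": 0, "b": 1}]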

View File

@ -150,8 +150,8 @@ def test_example_end_to_end():
result = schema.execute(query)
assert not result.errors
result.data = deflate(result.data)
assert result.data == {
data = deflate(result.data)
assert data == {
"events": [
{
"__typename": "Event",

View File

@ -1,4 +1,4 @@
import pytest
from pytest import raises
from .. import deprecated
from ..deprecated import deprecated as deprecated_decorator
@ -71,5 +71,5 @@ def test_deprecated_class_text(mocker):
def test_deprecated_other_object(mocker):
mocker.patch.object(deprecated, "warn_deprecation")
with pytest.raises(TypeError):
with raises(TypeError):
deprecated_decorator({})

View File

@ -2,7 +2,7 @@ from ..trim_docstring import trim_docstring
def test_trim_docstring():
class WellDocumentedObject(object):
class WellDocumentedObject:
"""
This object is very well-documented. It has multiple lines in its
description.
@ -16,7 +16,7 @@ def test_trim_docstring():
"description.\n\nMultiple paragraphs too"
)
class UndocumentedObject(object):
class UndocumentedObject:
pass
assert trim_docstring(UndocumentedObject.__doc__) is None

View File

@ -1,28 +1,15 @@
"""
This file is used mainly as a bridge for thenable abstractions.
This includes:
- Promises
- Asyncio Coroutines
"""
try:
from promise import Promise, is_thenable # type: ignore
except ImportError:
class Promise(object): # type: ignore
pass
def is_thenable(obj): # type: ignore
return False
from inspect import isawaitable
try:
from inspect import isawaitable
from .thenables_asyncio import await_and_execute
except ImportError:
def await_and_execute(obj, on_resolve):
async def build_resolve_async():
return on_resolve(await obj)
def isawaitable(obj): # type: ignore
return False
return build_resolve_async()
def maybe_thenable(obj, on_resolve):
@ -31,12 +18,8 @@ def maybe_thenable(obj, on_resolve):
returning the same type of object inputted.
If the object is not thenable, it should return on_resolve(obj)
"""
if isawaitable(obj) and not isinstance(obj, Promise):
if isawaitable(obj):
return await_and_execute(obj, on_resolve)
if is_thenable(obj):
return Promise.resolve(obj).then(on_resolve)
# If it's neither awaitable nor a Promise, return
# the function executed over the object
# If it's not awaitable, return the function executed over the object
return on_resolve(obj)
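After this simplification maybe_thenable only distinguishes awaitables from plain values: awaitables are wrapped in a coroutine that applies on_resolve to the awaited result, everything else is resolved synchronously. A hedged usage sketch; the graphene.utils.thenables import path is assumed from the relative imports, and add_one/fetch are illustrative helpers:

import asyncio
from graphene.utils.thenables import maybe_thenable

def add_one(value):
    return value + 1

async def fetch():
    return 41

# Plain values are resolved immediately
assert maybe_thenable(41, add_one) == 42

# Awaitables come back as a coroutine that applies on_resolve after awaiting
assert asyncio.run(maybe_thenable(fetch(), add_one)) == 42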

View File

@ -1,5 +0,0 @@
def await_and_execute(obj, on_resolve):
async def build_resolve_async():
return on_resolve(await obj)
return build_resolve_async()

View File

@ -80,15 +80,11 @@ setup(
keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests", "tests.*", "examples"]),
install_requires=[
"graphql-core>=2.1,<3",
"graphql-relay>=2,<3",
"aniso8601>=3,<=7",
"graphql-core>=3.0.0a0,<4",
"graphql-relay>=3.0.0a0,<4",
"aniso8601>=6,<8",
],
tests_require=tests_require,
extras_require={
"test": tests_require,
"django": ["graphene-django"],
"sqlalchemy": ["graphene-sqlalchemy"],
},
extras_require={"test": tests_require},
cmdclass={"test": PyTest},
)

View File

@ -1,7 +1,4 @@
import pytest
from collections import OrderedDict
from graphql.execution.executors.asyncio import AsyncioExecutor
from pytest import mark
from graphql_relay.utils import base64
@ -27,14 +24,14 @@ class LetterConnection(Connection):
class Query(ObjectType):
letters = ConnectionField(LetterConnection)
connection_letters = ConnectionField(LetterConnection)
promise_letters = ConnectionField(LetterConnection)
async_letters = ConnectionField(LetterConnection)
node = Node.Field()
def resolve_letters(self, info, **args):
return list(letters.values())
async def resolve_promise_letters(self, info, **args):
async def resolve_async_letters(self, info, **args):
return list(letters.values())
def resolve_connection_letters(self, info, **args):
@ -48,9 +45,7 @@ class Query(ObjectType):
schema = Schema(Query)
letters = OrderedDict()
for i, letter in enumerate(letter_chars):
letters[letter] = Letter(id=i, letter=letter)
letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter_chars)}
def edges(selected_letters):
@ -96,12 +91,12 @@ def execute(args=""):
)
@pytest.mark.asyncio
async def test_connection_promise():
result = await schema.execute(
@mark.asyncio
async def test_connection_async():
result = await schema.execute_async(
"""
{
promiseLetters(first:1) {
asyncLetters(first:1) {
edges {
node {
id
@ -114,14 +109,12 @@ async def test_connection_promise():
}
}
}
""",
executor=AsyncioExecutor(),
return_promise=True,
"""
)
assert not result.errors
assert result.data == {
"promiseLetters": {
"asyncLetters": {
"edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}],
"pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
}

View File

@ -1,5 +1,4 @@
import pytest
from graphql.execution.executors.asyncio import AsyncioExecutor
from pytest import mark
from graphene.types import ID, Field, ObjectType, Schema
from graphene.types.scalars import String
@ -43,11 +42,11 @@ class OtherMutation(ClientIDMutation):
@staticmethod
def mutate_and_get_payload(
self, info, shared="", additional_field="", client_mutation_id=None
self, info, shared, additional_field, client_mutation_id=None
):
edge_type = MyEdge
return OtherMutation(
name=shared + additional_field,
name=(shared or "") + (additional_field or ""),
my_node_edge=edge_type(cursor="1", node=MyNode(name="name")),
)
@ -64,23 +63,19 @@ class Mutation(ObjectType):
schema = Schema(query=RootQuery, mutation=Mutation)
@pytest.mark.asyncio
@mark.asyncio
async def test_node_query_promise():
executed = await schema.execute(
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }',
executor=AsyncioExecutor(),
return_promise=True,
executed = await schema.execute_async(
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
)
assert not executed.errors
assert executed.data == {"sayPromise": {"phrase": "hello"}}
@pytest.mark.asyncio
@mark.asyncio
async def test_edge_query():
executed = await schema.execute(
'mutation a { other(input: {clientMutationId:"1"}) { clientMutationId, myNodeEdge { cursor node { name }} } }',
executor=AsyncioExecutor(),
return_promise=True,
executed = await schema.execute_async(
'mutation a { other(input: {clientMutationId:"1"}) { clientMutationId, myNodeEdge { cursor node { name }} } }'
)
assert not executed.errors
assert dict(executed.data) == {

View File

@ -22,14 +22,16 @@ commands =
[testenv:mypy]
basepython=python3.7
deps =
mypy
mypy>=0.720
commands =
mypy graphene
[testenv:flake8]
deps = flake8
basepython=python3.6
deps =
flake8>=3.7,<4
commands =
pip install -e .
pip install --pre -e .
flake8 graphene
[pytest]