Change dependency to graphql-core-next (#988)

* Changed dependencies to core-next

* Converted Scalars

* ResolveInfo name change

* Ignore .venv

* Make Schema compatible with GraphQL-core-next

* Ignore more venv names and mypy and pytest caches

* Remove print statements for debugging in schema test

* core-next now provides out_type and out_name

* Adapt date and time scalar types to core-next

* Ignore the non-standard result.invalid flag

* Results are named tuples in core-next (immutable)

* Enum values are returned as dict in core-next

* Fix mutation tests with promises

* Make all 345 tests pass with graphql-core-next

* Remove the compat module which was only needed for older Py version

* Remove object as base class (not needed in Py 3)

* We can assume that dicts are ordered in Py 3.6+

* Make use of the fact that dicts are iterable

* Use consistent style of importing from pytest

* Restore compatibility with graphql-relay-py v3

Add adpaters for the PageInfo and Connection args.

* Avoid various deprecation warnings

* Use graphql-core 3 instead of graphql-core-next

* Update dependencies, reformat changes with black

* Update graphene/relay/connection.py

Co-Authored-By: Jonathan Kim <jkimbo@gmail.com>

* Run black on setup.py

* Remove trailing whitespace
This commit is contained in:
Eran Kampf 2019-08-12 11:04:02 -07:00 committed by Mel van Londen
parent dc812fe028
commit 7ef3c8ee3e
68 changed files with 1092 additions and 1897 deletions

14
.gitignore vendored
View File

@ -10,9 +10,6 @@ __pycache__/
# Distribution / packaging # Distribution / packaging
.Python .Python
env/
venv/
.venv/
build/ build/
develop-eggs/ develop-eggs/
dist/ dist/
@ -47,7 +44,8 @@ htmlcov/
.pytest_cache .pytest_cache
nosetests.xml nosetests.xml
coverage.xml coverage.xml
*,cover *.cover
.pytest_cache/
# Translations # Translations
*.mo *.mo
@ -62,6 +60,14 @@ docs/_build/
# PyBuilder # PyBuilder
target/ target/
# VirtualEnv
.env
.venv
env/
venv/
# Typing
.mypy_cache/
/tests/django.sqlite /tests/django.sqlite

View File

@ -23,6 +23,6 @@ repos:
- id: black - id: black
language_version: python3 language_version: python3
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 3.7.7 rev: 3.7.8
hooks: hooks:
- id: flake8 - id: flake8

View File

@ -30,21 +30,22 @@ snapshots["test_correctly_refetches_xwing 1"] = {
snapshots[ snapshots[
"test_str_schema 1" "test_str_schema 1"
] = """schema { ] = '''"""A faction in the Star Wars saga"""
query: Query
mutation: Mutation
}
type Faction implements Node { type Faction implements Node {
"""The ID of the object"""
id: ID! id: ID!
"""The name of the faction."""
name: String name: String
ships(before: String, after: String, first: Int, last: Int): ShipConnection
"""The ships used by the faction."""
ships(before: String = null, after: String = null, first: Int = null, last: Int = null): ShipConnection
} }
input IntroduceShipInput { input IntroduceShipInput {
shipName: String! shipName: String!
factionId: String! factionId: String!
clientMutationId: String clientMutationId: String = null
} }
type IntroduceShipPayload { type IntroduceShipPayload {
@ -57,35 +58,60 @@ type Mutation {
introduceShip(input: IntroduceShipInput!): IntroduceShipPayload introduceShip(input: IntroduceShipInput!): IntroduceShipPayload
} }
"""An object with an ID"""
interface Node { interface Node {
"""The ID of the object"""
id: ID! id: ID!
} }
"""
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
"""
type PageInfo { type PageInfo {
"""When paginating forwards, are there more items?"""
hasNextPage: Boolean! hasNextPage: Boolean!
"""When paginating backwards, are there more items?"""
hasPreviousPage: Boolean! hasPreviousPage: Boolean!
"""When paginating backwards, the cursor to continue."""
startCursor: String startCursor: String
"""When paginating forwards, the cursor to continue."""
endCursor: String endCursor: String
} }
type Query { type Query {
rebels: Faction rebels: Faction
empire: Faction empire: Faction
"""The ID of the object"""
node(id: ID!): Node node(id: ID!): Node
} }
"""A ship in the Star Wars saga"""
type Ship implements Node { type Ship implements Node {
"""The ID of the object"""
id: ID! id: ID!
"""The name of the ship."""
name: String name: String
} }
type ShipConnection { type ShipConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo! pageInfo: PageInfo!
"""Contains the nodes in this connection."""
edges: [ShipEdge]! edges: [ShipEdge]!
} }
"""A Relay edge containing a `Ship` and its cursor."""
type ShipEdge { type ShipEdge {
"""The item at the end of the edge"""
node: Ship node: Ship
"""A cursor for use in pagination"""
cursor: String! cursor: String!
} }
""" '''

View File

@ -1,8 +0,0 @@
from __future__ import absolute_import
from graphql.pyutils.compat import Enum
try:
from inspect import signature
except ImportError:
from .signature import signature

View File

@ -1,850 +0,0 @@
# Copyright 2001-2013 Python Software Foundation; All Rights Reserved
"""Function signature objects for callables
Back port of Python 3.3's function signature tools from the inspect module,
modified to be compatible with Python 2.7 and 3.2+.
"""
from __future__ import absolute_import, division, print_function
import functools
import itertools
import re
import types
from collections import OrderedDict
__version__ = "0.4"
__all__ = ["BoundArguments", "Parameter", "Signature", "signature"]
_WrapperDescriptor = type(type.__call__)
_MethodWrapper = type(all.__call__)
_NonUserDefinedCallables = (
_WrapperDescriptor,
_MethodWrapper,
types.BuiltinFunctionType,
)
def formatannotation(annotation, base_module=None):
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", "__builtin__", base_module):
return annotation.__name__
return annotation.__module__ + "." + annotation.__name__
return repr(annotation)
def _get_user_defined_method(cls, method_name, *nested):
try:
if cls is type:
return
meth = getattr(cls, method_name)
for name in nested:
meth = getattr(meth, name, meth)
except AttributeError:
return
else:
if not isinstance(meth, _NonUserDefinedCallables):
# Once '__signature__' will be added to 'C'-level
# callables, this check won't be necessary
return meth
def signature(obj):
"""Get a signature object for the passed callable."""
if not callable(obj):
raise TypeError("{!r} is not a callable object".format(obj))
if isinstance(obj, types.MethodType):
sig = signature(obj.__func__)
if obj.__self__ is None:
# Unbound method: the first parameter becomes positional-only
if sig.parameters:
first = sig.parameters.values()[0].replace(kind=_POSITIONAL_ONLY)
return sig.replace(
parameters=(first,) + tuple(sig.parameters.values())[1:]
)
else:
return sig
else:
# In this case we skip the first parameter of the underlying
# function (usually `self` or `cls`).
return sig.replace(parameters=tuple(sig.parameters.values())[1:])
try:
sig = obj.__signature__
except AttributeError:
pass
else:
if sig is not None:
return sig
try:
# Was this function wrapped by a decorator?
wrapped = obj.__wrapped__
except AttributeError:
pass
else:
return signature(wrapped)
if isinstance(obj, types.FunctionType):
return Signature.from_function(obj)
if isinstance(obj, functools.partial):
sig = signature(obj.func)
new_params = OrderedDict(sig.parameters.items())
partial_args = obj.args or ()
partial_keywords = obj.keywords or {}
try:
ba = sig.bind_partial(*partial_args, **partial_keywords)
except TypeError as ex:
msg = "partial object {!r} has incorrect arguments".format(obj)
raise ValueError(msg)
for arg_name, arg_value in ba.arguments.items():
param = new_params[arg_name]
if arg_name in partial_keywords:
# We set a new default value, because the following code
# is correct:
#
# >>> def foo(a): print(a)
# >>> print(partial(partial(foo, a=10), a=20)())
# 20
# >>> print(partial(partial(foo, a=10), a=20)(a=30))
# 30
#
# So, with 'partial' objects, passing a keyword argument is
# like setting a new default value for the corresponding
# parameter
#
# We also mark this parameter with '_partial_kwarg'
# flag. Later, in '_bind', the 'default' value of this
# parameter will be added to 'kwargs', to simulate
# the 'functools.partial' real call.
new_params[arg_name] = param.replace(
default=arg_value, _partial_kwarg=True
)
elif (
param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL)
and not param._partial_kwarg
):
new_params.pop(arg_name)
return sig.replace(parameters=new_params.values())
sig = None
if isinstance(obj, type):
# obj is a class or a metaclass
# First, let's see if it has an overloaded __call__ defined
# in its metaclass
call = _get_user_defined_method(type(obj), "__call__")
if call is not None:
sig = signature(call)
else:
# Now we check if the 'obj' class has a '__new__' method
new = _get_user_defined_method(obj, "__new__")
if new is not None:
sig = signature(new)
else:
# Finally, we should have at least __init__ implemented
init = _get_user_defined_method(obj, "__init__")
if init is not None:
sig = signature(init)
elif not isinstance(obj, _NonUserDefinedCallables):
# An object with __call__
# We also check that the 'obj' is not an instance of
# _WrapperDescriptor or _MethodWrapper to avoid
# infinite recursion (and even potential segfault)
call = _get_user_defined_method(type(obj), "__call__", "im_func")
if call is not None:
sig = signature(call)
if sig is not None:
# For classes and objects we skip the first parameter of their
# __call__, __new__, or __init__ methods
return sig.replace(parameters=tuple(sig.parameters.values())[1:])
if isinstance(obj, types.BuiltinFunctionType):
# Raise a nicer error message for builtins
msg = "no signature found for builtin function {!r}".format(obj)
raise ValueError(msg)
raise ValueError("callable {!r} is not supported by signature".format(obj))
class _void(object):
"""A private marker - used in Parameter & Signature"""
class _empty(object):
pass
class _ParameterKind(int):
def __new__(self, *args, **kwargs):
obj = int.__new__(self, *args)
obj._name = kwargs["name"]
return obj
def __str__(self):
return self._name
def __repr__(self):
return "<_ParameterKind: {!r}>".format(self._name)
_POSITIONAL_ONLY = _ParameterKind(0, name="POSITIONAL_ONLY")
_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name="POSITIONAL_OR_KEYWORD")
_VAR_POSITIONAL = _ParameterKind(2, name="VAR_POSITIONAL")
_KEYWORD_ONLY = _ParameterKind(3, name="KEYWORD_ONLY")
_VAR_KEYWORD = _ParameterKind(4, name="VAR_KEYWORD")
class Parameter(object):
"""Represents a parameter in a function signature.
Has the following public attributes:
* name : str
The name of the parameter as a string.
* default : object
The default value for the parameter if specified. If the
parameter has no default value, this attribute is not set.
* annotation
The annotation for the parameter if specified. If the
parameter has no annotation, this attribute is not set.
* kind : str
Describes how argument values are bound to the parameter.
Possible values: `Parameter.POSITIONAL_ONLY`,
`Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`,
`Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
"""
__slots__ = ("_name", "_kind", "_default", "_annotation", "_partial_kwarg")
POSITIONAL_ONLY = _POSITIONAL_ONLY
POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
VAR_POSITIONAL = _VAR_POSITIONAL
KEYWORD_ONLY = _KEYWORD_ONLY
VAR_KEYWORD = _VAR_KEYWORD
empty = _empty
def __init__(
self, name, kind, default=_empty, annotation=_empty, _partial_kwarg=False
):
if kind not in (
_POSITIONAL_ONLY,
_POSITIONAL_OR_KEYWORD,
_VAR_POSITIONAL,
_KEYWORD_ONLY,
_VAR_KEYWORD,
):
raise ValueError("invalid value for 'Parameter.kind' attribute")
self._kind = kind
if default is not _empty:
if kind in (_VAR_POSITIONAL, _VAR_KEYWORD):
msg = "{} parameters cannot have default values".format(kind)
raise ValueError(msg)
self._default = default
self._annotation = annotation
if name is None:
if kind != _POSITIONAL_ONLY:
raise ValueError(
"None is not a valid name for a " "non-positional-only parameter"
)
self._name = name
else:
name = str(name)
if kind != _POSITIONAL_ONLY and not re.match(r"[a-z_]\w*$", name, re.I):
msg = "{!r} is not a valid parameter name".format(name)
raise ValueError(msg)
self._name = name
self._partial_kwarg = _partial_kwarg
@property
def name(self):
return self._name
@property
def default(self):
return self._default
@property
def annotation(self):
return self._annotation
@property
def kind(self):
return self._kind
def replace(
self,
name=_void,
kind=_void,
annotation=_void,
default=_void,
_partial_kwarg=_void,
):
"""Creates a customized copy of the Parameter."""
if name is _void:
name = self._name
if kind is _void:
kind = self._kind
if annotation is _void:
annotation = self._annotation
if default is _void:
default = self._default
if _partial_kwarg is _void:
_partial_kwarg = self._partial_kwarg
return type(self)(
name,
kind,
default=default,
annotation=annotation,
_partial_kwarg=_partial_kwarg,
)
def __str__(self):
kind = self.kind
formatted = self._name
if kind == _POSITIONAL_ONLY:
if formatted is None:
formatted = ""
formatted = "<{}>".format(formatted)
# Add annotation and default value
if self._annotation is not _empty:
formatted = "{}:{}".format(formatted, formatannotation(self._annotation))
if self._default is not _empty:
formatted = "{}={}".format(formatted, repr(self._default))
if kind == _VAR_POSITIONAL:
formatted = "*" + formatted
elif kind == _VAR_KEYWORD:
formatted = "**" + formatted
return formatted
def __repr__(self):
return "<{} at {:#x} {!r}>".format(self.__class__.__name__, id(self), self.name)
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
return (
issubclass(other.__class__, Parameter)
and self._name == other._name
and self._kind == other._kind
and self._default == other._default
and self._annotation == other._annotation
)
def __ne__(self, other):
return not self.__eq__(other)
class BoundArguments(object):
"""Result of `Signature.bind` call. Holds the mapping of arguments
to the function's parameters.
Has the following public attributes:
* arguments : OrderedDict
An ordered mutable mapping of parameters' names to arguments' values.
Does not contain arguments' default values.
* signature : Signature
The Signature object that created this instance.
* args : tuple
Tuple of positional arguments values.
* kwargs : dict
Dict of keyword arguments values.
"""
def __init__(self, signature, arguments):
self.arguments = arguments
self._signature = signature
@property
def signature(self):
return self._signature
@property
def args(self):
args = []
for param_name, param in self._signature.parameters.items():
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
# Keyword arguments mapped by 'functools.partial'
# (Parameter._partial_kwarg is True) are mapped
# in 'BoundArguments.kwargs', along with VAR_KEYWORD &
# KEYWORD_ONLY
break
try:
arg = self.arguments[param_name]
except KeyError:
# We're done here. Other arguments
# will be mapped in 'BoundArguments.kwargs'
break
else:
if param.kind == _VAR_POSITIONAL:
# *args
args.extend(arg)
else:
# plain argument
args.append(arg)
return tuple(args)
@property
def kwargs(self):
kwargs = {}
kwargs_started = False
for param_name, param in self._signature.parameters.items():
if not kwargs_started:
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or param._partial_kwarg:
kwargs_started = True
else:
if param_name not in self.arguments:
kwargs_started = True
continue
if not kwargs_started:
continue
try:
arg = self.arguments[param_name]
except KeyError:
pass
else:
if param.kind == _VAR_KEYWORD:
# **kwargs
kwargs.update(arg)
else:
# plain keyword argument
kwargs[param_name] = arg
return kwargs
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
return (
issubclass(other.__class__, BoundArguments)
and self.signature == other.signature
and self.arguments == other.arguments
)
def __ne__(self, other):
return not self.__eq__(other)
class Signature(object):
"""A Signature object represents the overall signature of a function.
It stores a Parameter object for each parameter accepted by the
function, as well as information specific to the function itself.
A Signature object has the following public attributes and methods:
* parameters : OrderedDict
An ordered mapping of parameters' names to the corresponding
Parameter objects (keyword-only arguments are in the same order
as listed in `code.co_varnames`).
* return_annotation : object
The annotation for the return type of the function if specified.
If the function has no annotation for its return type, this
attribute is not set.
* bind(*args, **kwargs) -> BoundArguments
Creates a mapping from positional and keyword arguments to
parameters.
* bind_partial(*args, **kwargs) -> BoundArguments
Creates a partial mapping from positional and keyword arguments
to parameters (simulating 'functools.partial' behavior.)
"""
__slots__ = ("_return_annotation", "_parameters")
_parameter_cls = Parameter
_bound_arguments_cls = BoundArguments
empty = _empty
def __init__(
self, parameters=None, return_annotation=_empty, __validate_parameters__=True
):
"""Constructs Signature from the given list of Parameter
objects and 'return_annotation'. All arguments are optional.
"""
if parameters is None:
params = OrderedDict()
else:
if __validate_parameters__:
params = OrderedDict()
top_kind = _POSITIONAL_ONLY
for idx, param in enumerate(parameters):
kind = param.kind
if kind < top_kind:
msg = "wrong parameter order: {0} before {1}"
msg = msg.format(top_kind, param.kind)
raise ValueError(msg)
else:
top_kind = kind
name = param.name
if name is None:
name = str(idx)
param = param.replace(name=name)
if name in params:
msg = "duplicate parameter name: {!r}".format(name)
raise ValueError(msg)
params[name] = param
else:
params = OrderedDict(((param.name, param) for param in parameters))
self._parameters = params
self._return_annotation = return_annotation
@classmethod
def from_function(cls, func):
"""Constructs Signature for the given python function"""
if not isinstance(func, types.FunctionType):
raise TypeError("{!r} is not a Python function".format(func))
Parameter = cls._parameter_cls
# Parameter information.
func_code = func.__code__
pos_count = func_code.co_argcount
arg_names = func_code.co_varnames
positional = tuple(arg_names[:pos_count])
keyword_only_count = getattr(func_code, "co_kwonlyargcount", 0)
keyword_only = arg_names[pos_count : (pos_count + keyword_only_count)]
annotations = getattr(func, "__annotations__", {})
defaults = func.__defaults__
kwdefaults = getattr(func, "__kwdefaults__", None)
if defaults:
pos_default_count = len(defaults)
else:
pos_default_count = 0
parameters = []
# Non-keyword-only parameters w/o defaults.
non_default_count = pos_count - pos_default_count
for name in positional[:non_default_count]:
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(name, annotation=annotation, kind=_POSITIONAL_OR_KEYWORD)
)
# ... w/ defaults.
for offset, name in enumerate(positional[non_default_count:]):
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(
name,
annotation=annotation,
kind=_POSITIONAL_OR_KEYWORD,
default=defaults[offset],
)
)
# *args
if func_code.co_flags & 0x04:
name = arg_names[pos_count + keyword_only_count]
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(name, annotation=annotation, kind=_VAR_POSITIONAL)
)
# Keyword-only parameters.
for name in keyword_only:
default = _empty
if kwdefaults is not None:
default = kwdefaults.get(name, _empty)
annotation = annotations.get(name, _empty)
parameters.append(
Parameter(
name, annotation=annotation, kind=_KEYWORD_ONLY, default=default
)
)
# **kwargs
if func_code.co_flags & 0x08:
index = pos_count + keyword_only_count
if func_code.co_flags & 0x04:
index += 1
name = arg_names[index]
annotation = annotations.get(name, _empty)
parameters.append(Parameter(name, annotation=annotation, kind=_VAR_KEYWORD))
return cls(
parameters,
return_annotation=annotations.get("return", _empty),
__validate_parameters__=False,
)
@property
def parameters(self):
try:
return types.MappingProxyType(self._parameters)
except AttributeError:
return OrderedDict(self._parameters.items())
@property
def return_annotation(self):
return self._return_annotation
def replace(self, parameters=_void, return_annotation=_void):
"""Creates a customized copy of the Signature.
Pass 'parameters' and/or 'return_annotation' arguments
to override them in the new copy.
"""
if parameters is _void:
parameters = self.parameters.values()
if return_annotation is _void:
return_annotation = self._return_annotation
return type(self)(parameters, return_annotation=return_annotation)
def __hash__(self):
msg = "unhashable type: '{}'".format(self.__class__.__name__)
raise TypeError(msg)
def __eq__(self, other):
if (
not issubclass(type(other), Signature)
or self.return_annotation != other.return_annotation
or len(self.parameters) != len(other.parameters)
):
return False
other_positions = {
param: idx for idx, param in enumerate(other.parameters.keys())
}
for idx, (param_name, param) in enumerate(self.parameters.items()):
if param.kind == _KEYWORD_ONLY:
try:
other_param = other.parameters[param_name]
except KeyError:
return False
else:
if param != other_param:
return False
else:
try:
other_idx = other_positions[param_name]
except KeyError:
return False
else:
if idx != other_idx or param != other.parameters[param_name]:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def _bind(self, args, kwargs, partial=False):
"""Private method. Don't use directly."""
arguments = OrderedDict()
parameters = iter(self.parameters.values())
parameters_ex = ()
arg_vals = iter(args)
if partial:
# Support for binding arguments to 'functools.partial' objects.
# See 'functools.partial' case in 'signature()' implementation
# for details.
for param_name, param in self.parameters.items():
if param._partial_kwarg and param_name not in kwargs:
# Simulating 'functools.partial' behavior
kwargs[param_name] = param.default
while True:
# Let's iterate through the positional arguments and corresponding
# parameters
try:
arg_val = next(arg_vals)
except StopIteration:
# No more positional arguments
try:
param = next(parameters)
except StopIteration:
# No more parameters. That's it. Just need to check that
# we have no `kwargs` after this while loop
break
else:
if param.kind == _VAR_POSITIONAL:
# That's OK, just empty *args. Let's start parsing
# kwargs
break
elif param.name in kwargs:
if param.kind == _POSITIONAL_ONLY:
msg = (
"{arg!r} parameter is positional only, "
"but was passed as a keyword"
)
msg = msg.format(arg=param.name)
raise TypeError(msg)
parameters_ex = (param,)
break
elif param.kind == _VAR_KEYWORD or param.default is not _empty:
# That's fine too - we have a default value for this
# parameter. So, lets start parsing `kwargs`, starting
# with the current parameter
parameters_ex = (param,)
break
else:
if partial:
parameters_ex = (param,)
break
else:
msg = "{arg!r} parameter lacking default value"
msg = msg.format(arg=param.name)
raise TypeError(msg)
else:
# We have a positional argument to process
try:
param = next(parameters)
except StopIteration:
raise TypeError("too many positional arguments")
else:
if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
# Looks like we have no parameter for this positional
# argument
raise TypeError("too many positional arguments")
if param.kind == _VAR_POSITIONAL:
# We have an '*args'-like argument, let's fill it with
# all positional arguments we have left and move on to
# the next phase
values = [arg_val]
values.extend(arg_vals)
arguments[param.name] = tuple(values)
break
if param.name in kwargs:
raise TypeError(
"multiple values for argument "
"{arg!r}".format(arg=param.name)
)
arguments[param.name] = arg_val
# Now, we iterate through the remaining parameters to process
# keyword arguments
kwargs_param = None
for param in itertools.chain(parameters_ex, parameters):
if param.kind == _POSITIONAL_ONLY:
# This should never happen in case of a properly built
# Signature object (but let's have this check here
# to ensure correct behaviour just in case)
raise TypeError(
"{arg!r} parameter is positional only, "
"but was passed as a keyword".format(arg=param.name)
)
if param.kind == _VAR_KEYWORD:
# Memorize that we have a '**kwargs'-like parameter
kwargs_param = param
continue
param_name = param.name
try:
arg_val = kwargs.pop(param_name)
except KeyError:
# We have no value for this parameter. It's fine though,
# if it has a default value, or it is an '*args'-like
# parameter, left alone by the processing of positional
# arguments.
if (
not partial
and param.kind != _VAR_POSITIONAL
and param.default is _empty
):
raise TypeError(
"{arg!r} parameter lacking default value".format(arg=param_name)
)
else:
arguments[param_name] = arg_val
if kwargs:
if kwargs_param is not None:
# Process our '**kwargs'-like parameter
arguments[kwargs_param.name] = kwargs
else:
raise TypeError("too many keyword arguments")
return self._bound_arguments_cls(self, arguments)
def bind(self, *args, **kwargs):
"""Get a BoundArguments object, that maps the passed `args`
and `kwargs` to the function's signature. Raises `TypeError`
if the passed arguments can not be bound.
"""
return self._bind(args, kwargs)
def bind_partial(self, *args, **kwargs):
"""Get a BoundArguments object, that partially maps the
passed `args` and `kwargs` to the function's signature.
Raises `TypeError` if the passed arguments can not be bound.
"""
return self._bind(args, kwargs, partial=True)
def __str__(self):
result = []
render_kw_only_separator = True
for idx, param in enumerate(self.parameters.values()):
formatted = str(param)
kind = param.kind
if kind == _VAR_POSITIONAL:
# OK, we have an '*args'-like parameter, so we won't need
# a '*' to separate keyword-only arguments
render_kw_only_separator = False
elif kind == _KEYWORD_ONLY and render_kw_only_separator:
# We have a keyword-only parameter to render and we haven't
# rendered an '*args'-like parameter before, so add a '*'
# separator to the parameters list ("foo(arg1, *, arg2)" case)
result.append("*")
# This condition should be only triggered once, so
# reset the flag
render_kw_only_separator = False
result.append(formatted)
rendered = "({})".format(", ".join(result))
if self.return_annotation is not _empty:
anno = formatannotation(self.return_annotation)
rendered += " -> {}".format(anno)
return rendered

View File

@ -1,8 +1,8 @@
import re import re
from collections import Iterable, OrderedDict from collections.abc import Iterable
from functools import partial from functools import partial
from graphql_relay import connection_from_list from graphql_relay import connection_from_array
from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union
from ..types.field import Field from ..types.field import Field
@ -41,6 +41,17 @@ class PageInfo(ObjectType):
) )
# noinspection PyPep8Naming
def page_info_adapter(startCursor, endCursor, hasPreviousPage, hasNextPage):
"""Adapter for creating PageInfo instances"""
return PageInfo(
start_cursor=startCursor,
end_cursor=endCursor,
has_previous_page=hasPreviousPage,
has_next_page=hasNextPage,
)
class ConnectionOptions(ObjectTypeOptions): class ConnectionOptions(ObjectTypeOptions):
node = None node = None
@ -66,7 +77,7 @@ class Connection(ObjectType):
edge_class = getattr(cls, "Edge", None) edge_class = getattr(cls, "Edge", None)
_node = node _node = node
class EdgeBase(object): class EdgeBase:
node = Field(_node, description="The item at the end of the edge") node = Field(_node, description="The item at the end of the edge")
cursor = String(required=True, description="A cursor for use in pagination") cursor = String(required=True, description="A cursor for use in pagination")
@ -86,31 +97,29 @@ class Connection(ObjectType):
options["name"] = name options["name"] = name
_meta.node = node _meta.node = node
_meta.fields = OrderedDict( _meta.fields = {
[ "page_info": Field(
( PageInfo,
"page_info", name="pageInfo",
Field( required=True,
PageInfo, description="Pagination data for this connection.",
name="pageInfo", ),
required=True, "edges": Field(
description="Pagination data for this connection.", NonNull(List(edge)),
), description="Contains the nodes in this connection.",
), ),
( }
"edges",
Field(
NonNull(List(edge)),
description="Contains the nodes in this connection.",
),
),
]
)
return super(Connection, cls).__init_subclass_with_meta__( return super(Connection, cls).__init_subclass_with_meta__(
_meta=_meta, **options _meta=_meta, **options
) )
# noinspection PyPep8Naming
def connection_adapter(cls, edges, pageInfo):
"""Adapter for creating Connection instances"""
return cls(edges=edges, page_info=pageInfo)
class IterableConnectionField(Field): class IterableConnectionField(Field):
def __init__(self, type, *args, **kwargs): def __init__(self, type, *args, **kwargs):
kwargs.setdefault("before", String()) kwargs.setdefault("before", String())
@ -133,7 +142,7 @@ class IterableConnectionField(Field):
) )
assert issubclass(connection_type, Connection), ( assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".' '{} type has to be a subclass of Connection. Received "{}".'
).format(self.__class__.__name__, connection_type) ).format(self.__class__.__name__, connection_type)
return type return type
@ -143,15 +152,15 @@ class IterableConnectionField(Field):
return resolved return resolved
assert isinstance(resolved, Iterable), ( assert isinstance(resolved, Iterable), (
"Resolved value from the connection field have to be iterable or instance of {}. " "Resolved value from the connection field has to be an iterable or instance of {}. "
'Received "{}"' 'Received "{}"'
).format(connection_type, resolved) ).format(connection_type, resolved)
connection = connection_from_list( connection = connection_from_array(
resolved, resolved,
args, args,
connection_type=connection_type, connection_type=partial(connection_adapter, connection_type),
edge_type=connection_type.Edge, edge_type=connection_type.Edge,
pageinfo_type=PageInfo, page_info_type=page_info_adapter,
) )
connection.iterable = resolved connection.iterable = resolved
return connection return connection

View File

@ -1,5 +1,4 @@
import re import re
from collections import OrderedDict
from ..types import Field, InputObjectType, String from ..types import Field, InputObjectType, String
from ..types.mutation import Mutation from ..types.mutation import Mutation
@ -30,12 +29,10 @@ class ClientIDMutation(Mutation):
cls.Input = type( cls.Input = type(
"{}Input".format(base_name), "{}Input".format(base_name),
bases, bases,
OrderedDict( dict(input_fields, client_mutation_id=String(name="clientMutationId")),
input_fields, client_mutation_id=String(name="clientMutationId")
),
) )
arguments = OrderedDict( arguments = dict(
input=cls.Input(required=True) input=cls.Input(required=True)
# 'client_mutation_id': String(name='clientMutationId') # 'client_mutation_id': String(name='clientMutationId')
) )

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from functools import partial from functools import partial
from inspect import isclass from inspect import isclass
@ -72,9 +71,7 @@ class AbstractNode(Interface):
@classmethod @classmethod
def __init_subclass_with_meta__(cls, **options): def __init_subclass_with_meta__(cls, **options):
_meta = InterfaceOptions(cls) _meta = InterfaceOptions(cls)
_meta.fields = OrderedDict( _meta.fields = {"id": GlobalID(cls, description="The ID of the object")}
id=GlobalID(cls, description="The ID of the object.")
)
super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options) super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options)

View File

@ -1,4 +1,4 @@
import pytest from pytest import raises
from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String from ...types import Argument, Field, Int, List, NonNull, ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo from ..connection import Connection, ConnectionField, PageInfo
@ -24,7 +24,7 @@ def test_connection():
assert MyObjectConnection._meta.name == "MyObjectConnection" assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ["page_info", "edges", "extra"] assert list(fields) == ["page_info", "edges", "extra"]
edge_field = fields["edges"] edge_field = fields["edges"]
pageinfo_field = fields["page_info"] pageinfo_field = fields["page_info"]
@ -39,7 +39,7 @@ def test_connection():
def test_connection_inherit_abstracttype(): def test_connection_inherit_abstracttype():
class BaseConnection(object): class BaseConnection:
extra = String() extra = String()
class MyObjectConnection(BaseConnection, Connection): class MyObjectConnection(BaseConnection, Connection):
@ -48,13 +48,13 @@ def test_connection_inherit_abstracttype():
assert MyObjectConnection._meta.name == "MyObjectConnection" assert MyObjectConnection._meta.name == "MyObjectConnection"
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ["page_info", "edges", "extra"] assert list(fields) == ["page_info", "edges", "extra"]
def test_connection_name(): def test_connection_name():
custom_name = "MyObjectCustomNameConnection" custom_name = "MyObjectCustomNameConnection"
class BaseConnection(object): class BaseConnection:
extra = String() extra = String()
class MyObjectConnection(BaseConnection, Connection): class MyObjectConnection(BaseConnection, Connection):
@ -76,7 +76,7 @@ def test_edge():
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == "MyObjectEdge" assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ["node", "cursor", "other"] assert list(edge_fields) == ["node", "cursor", "other"]
assert isinstance(edge_fields["node"], Field) assert isinstance(edge_fields["node"], Field)
assert edge_fields["node"].type == MyObject assert edge_fields["node"].type == MyObject
@ -86,7 +86,7 @@ def test_edge():
def test_edge_with_bases(): def test_edge_with_bases():
class BaseEdge(object): class BaseEdge:
extra = String() extra = String()
class MyObjectConnection(Connection): class MyObjectConnection(Connection):
@ -99,7 +99,7 @@ def test_edge_with_bases():
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == "MyObjectEdge" assert Edge._meta.name == "MyObjectEdge"
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ["node", "cursor", "extra", "other"] assert list(edge_fields) == ["node", "cursor", "extra", "other"]
assert isinstance(edge_fields["node"], Field) assert isinstance(edge_fields["node"], Field)
assert edge_fields["node"].type == MyObject assert edge_fields["node"].type == MyObject
@ -122,7 +122,7 @@ def test_edge_with_nonnull_node():
def test_pageinfo(): def test_pageinfo():
assert PageInfo._meta.name == "PageInfo" assert PageInfo._meta.name == "PageInfo"
fields = PageInfo._meta.fields fields = PageInfo._meta.fields
assert list(fields.keys()) == [ assert list(fields) == [
"has_next_page", "has_next_page",
"has_previous_page", "has_previous_page",
"start_cursor", "start_cursor",
@ -146,7 +146,7 @@ def test_connectionfield():
def test_connectionfield_node_deprecated(): def test_connectionfield_node_deprecated():
field = ConnectionField(MyObject) field = ConnectionField(MyObject)
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
field.type field.type
assert "ConnectionFields now need a explicit ConnectionType for Nodes." in str( assert "ConnectionFields now need a explicit ConnectionType for Nodes." in str(

View File

@ -1,7 +1,6 @@
from collections import OrderedDict from pytest import mark
from graphql_relay.utils import base64 from graphql_relay.utils import base64
from promise import Promise
from ...types import ObjectType, Schema, String from ...types import ObjectType, Schema, String
from ..connection import Connection, ConnectionField, PageInfo from ..connection import Connection, ConnectionField, PageInfo
@ -25,15 +24,15 @@ class LetterConnection(Connection):
class Query(ObjectType): class Query(ObjectType):
letters = ConnectionField(LetterConnection) letters = ConnectionField(LetterConnection)
connection_letters = ConnectionField(LetterConnection) connection_letters = ConnectionField(LetterConnection)
promise_letters = ConnectionField(LetterConnection) async_letters = ConnectionField(LetterConnection)
node = Node.Field() node = Node.Field()
def resolve_letters(self, info, **args): def resolve_letters(self, info, **args):
return list(letters.values()) return list(letters.values())
def resolve_promise_letters(self, info, **args): async def resolve_async_letters(self, info, **args):
return Promise.resolve(list(letters.values())) return list(letters.values())
def resolve_connection_letters(self, info, **args): def resolve_connection_letters(self, info, **args):
return LetterConnection( return LetterConnection(
@ -46,9 +45,7 @@ class Query(ObjectType):
schema = Schema(Query) schema = Schema(Query)
letters = OrderedDict() letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter_chars)}
for i, letter in enumerate(letter_chars):
letters[letter] = Letter(id=i, letter=letter)
def edges(selected_letters): def edges(selected_letters):
@ -66,11 +63,11 @@ def cursor_for(ltr):
return base64("arrayconnection:%s" % letter.id) return base64("arrayconnection:%s" % letter.id)
def execute(args=""): async def execute(args=""):
if args: if args:
args = "(" + args + ")" args = "(" + args + ")"
return schema.execute( return await schema.execute_async(
""" """
{ {
letters%s { letters%s {
@ -94,8 +91,8 @@ def execute(args=""):
) )
def check(args, letters, has_previous_page=False, has_next_page=False): async def check(args, letters, has_previous_page=False, has_next_page=False):
result = execute(args) result = await execute(args)
expected_edges = edges(letters) expected_edges = edges(letters)
expected_page_info = { expected_page_info = {
"hasPreviousPage": has_previous_page, "hasPreviousPage": has_previous_page,
@ -110,96 +107,118 @@ def check(args, letters, has_previous_page=False, has_next_page=False):
} }
def test_returns_all_elements_without_filters(): @mark.asyncio
check("", "ABCDE") async def test_returns_all_elements_without_filters():
await check("", "ABCDE")
def test_respects_a_smaller_first(): @mark.asyncio
check("first: 2", "AB", has_next_page=True) async def test_respects_a_smaller_first():
await check("first: 2", "AB", has_next_page=True)
def test_respects_an_overly_large_first(): @mark.asyncio
check("first: 10", "ABCDE") async def test_respects_an_overly_large_first():
await check("first: 10", "ABCDE")
def test_respects_a_smaller_last(): @mark.asyncio
check("last: 2", "DE", has_previous_page=True) async def test_respects_a_smaller_last():
await check("last: 2", "DE", has_previous_page=True)
def test_respects_an_overly_large_last(): @mark.asyncio
check("last: 10", "ABCDE") async def test_respects_an_overly_large_last():
await check("last: 10", "ABCDE")
def test_respects_first_and_after(): @mark.asyncio
check('first: 2, after: "{}"'.format(cursor_for("B")), "CD", has_next_page=True) async def test_respects_first_and_after():
await check(
'first: 2, after: "{}"'.format(cursor_for("B")), "CD", has_next_page=True
)
def test_respects_first_and_after_with_long_first(): @mark.asyncio
check('first: 10, after: "{}"'.format(cursor_for("B")), "CDE") async def test_respects_first_and_after_with_long_first():
await check('first: 10, after: "{}"'.format(cursor_for("B")), "CDE")
def test_respects_last_and_before(): @mark.asyncio
check('last: 2, before: "{}"'.format(cursor_for("D")), "BC", has_previous_page=True) async def test_respects_last_and_before():
await check(
'last: 2, before: "{}"'.format(cursor_for("D")), "BC", has_previous_page=True
)
def test_respects_last_and_before_with_long_last(): @mark.asyncio
check('last: 10, before: "{}"'.format(cursor_for("D")), "ABC") async def test_respects_last_and_before_with_long_last():
await check('last: 10, before: "{}"'.format(cursor_for("D")), "ABC")
def test_respects_first_and_after_and_before_too_few(): @mark.asyncio
check( async def test_respects_first_and_after_and_before_too_few():
await check(
'first: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'first: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BC", "BC",
has_next_page=True, has_next_page=True,
) )
def test_respects_first_and_after_and_before_too_many(): @mark.asyncio
check( async def test_respects_first_and_after_and_before_too_many():
await check(
'first: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'first: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD", "BCD",
) )
def test_respects_first_and_after_and_before_exactly_right(): @mark.asyncio
check( async def test_respects_first_and_after_and_before_exactly_right():
await check(
'first: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'first: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD", "BCD",
) )
def test_respects_last_and_after_and_before_too_few(): @mark.asyncio
check( async def test_respects_last_and_after_and_before_too_few():
await check(
'last: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'last: 2, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"CD", "CD",
has_previous_page=True, has_previous_page=True,
) )
def test_respects_last_and_after_and_before_too_many(): @mark.asyncio
check( async def test_respects_last_and_after_and_before_too_many():
await check(
'last: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'last: 4, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD", "BCD",
) )
def test_respects_last_and_after_and_before_exactly_right(): @mark.asyncio
check( async def test_respects_last_and_after_and_before_exactly_right():
await check(
'last: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")), 'last: 3, after: "{}", before: "{}"'.format(cursor_for("A"), cursor_for("E")),
"BCD", "BCD",
) )
def test_returns_no_elements_if_first_is_0(): @mark.asyncio
check("first: 0", "", has_next_page=True) async def test_returns_no_elements_if_first_is_0():
await check("first: 0", "", has_next_page=True)
def test_returns_all_elements_if_cursors_are_invalid(): @mark.asyncio
check('before: "invalid" after: "invalid"', "ABCDE") async def test_returns_all_elements_if_cursors_are_invalid():
await check('before: "invalid" after: "invalid"', "ABCDE")
def test_returns_all_elements_if_cursors_are_on_the_outside(): @mark.asyncio
check( async def test_returns_all_elements_if_cursors_are_on_the_outside():
await check(
'before: "{}" after: "{}"'.format( 'before: "{}" after: "{}"'.format(
base64("arrayconnection:%s" % 6), base64("arrayconnection:%s" % -1) base64("arrayconnection:%s" % 6), base64("arrayconnection:%s" % -1)
), ),
@ -207,8 +226,9 @@ def test_returns_all_elements_if_cursors_are_on_the_outside():
) )
def test_returns_no_elements_if_cursors_cross(): @mark.asyncio
check( async def test_returns_no_elements_if_cursors_cross():
await check(
'before: "{}" after: "{}"'.format( 'before: "{}" after: "{}"'.format(
base64("arrayconnection:%s" % 2), base64("arrayconnection:%s" % 4) base64("arrayconnection:%s" % 2), base64("arrayconnection:%s" % 4)
), ),
@ -216,8 +236,9 @@ def test_returns_no_elements_if_cursors_cross():
) )
def test_connection_type_nodes(): @mark.asyncio
result = schema.execute( async def test_connection_type_nodes():
result = await schema.execute_async(
""" """
{ {
connectionLetters { connectionLetters {
@ -248,11 +269,12 @@ def test_connection_type_nodes():
} }
def test_connection_promise(): @mark.asyncio
result = schema.execute( async def test_connection_async():
result = await schema.execute_async(
""" """
{ {
promiseLetters(first:1) { asyncLetters(first:1) {
edges { edges {
node { node {
id id
@ -270,7 +292,7 @@ def test_connection_promise():
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"promiseLetters": { "asyncLetters": {
"edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}], "edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}],
"pageInfo": {"hasPreviousPage": False, "hasNextPage": True}, "pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
} }

View File

@ -17,7 +17,7 @@ class User(ObjectType):
name = String() name = String()
class Info(object): class Info:
def __init__(self, parent_type): def __init__(self, parent_type):
self.parent_type = GrapheneObjectType( self.parent_type = GrapheneObjectType(
graphene_type=parent_type, graphene_type=parent_type,

View File

@ -1,5 +1,4 @@
import pytest from pytest import mark, raises
from promise import Promise
from ...types import ( from ...types import (
ID, ID,
@ -15,7 +14,7 @@ from ...types.scalars import String
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
class SharedFields(object): class SharedFields:
shared = String() shared = String()
@ -37,7 +36,7 @@ class SaySomething(ClientIDMutation):
return SaySomething(phrase=str(what)) return SaySomething(phrase=str(what))
class FixedSaySomething(object): class FixedSaySomething:
__slots__ = ("phrase",) __slots__ = ("phrase",)
def __init__(self, phrase): def __init__(self, phrase):
@ -55,15 +54,15 @@ class SaySomethingFixed(ClientIDMutation):
return FixedSaySomething(phrase=str(what)) return FixedSaySomething(phrase=str(what))
class SaySomethingPromise(ClientIDMutation): class SaySomethingAsync(ClientIDMutation):
class Input: class Input:
what = String() what = String()
phrase = String() phrase = String()
@staticmethod @staticmethod
def mutate_and_get_payload(self, info, what, client_mutation_id=None): async def mutate_and_get_payload(self, info, what, client_mutation_id=None):
return Promise.resolve(SaySomething(phrase=str(what))) return SaySomething(phrase=str(what))
# MyEdge = MyNode.Connection.Edge # MyEdge = MyNode.Connection.Edge
@ -81,11 +80,11 @@ class OtherMutation(ClientIDMutation):
@staticmethod @staticmethod
def mutate_and_get_payload( def mutate_and_get_payload(
self, info, shared="", additional_field="", client_mutation_id=None self, info, shared, additional_field, client_mutation_id=None
): ):
edge_type = MyEdge edge_type = MyEdge
return OtherMutation( return OtherMutation(
name=shared + additional_field, name=(shared or "") + (additional_field or ""),
my_node_edge=edge_type(cursor="1", node=MyNode(name="name")), my_node_edge=edge_type(cursor="1", node=MyNode(name="name")),
) )
@ -97,7 +96,7 @@ class RootQuery(ObjectType):
class Mutation(ObjectType): class Mutation(ObjectType):
say = SaySomething.Field() say = SaySomething.Field()
say_fixed = SaySomethingFixed.Field() say_fixed = SaySomethingFixed.Field()
say_promise = SaySomethingPromise.Field() say_async = SaySomethingAsync.Field()
other = OtherMutation.Field() other = OtherMutation.Field()
@ -105,7 +104,7 @@ schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload(): def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo: with raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation): class MyMutation(ClientIDMutation):
pass pass
@ -118,12 +117,12 @@ def test_no_mutate_and_get_payload():
def test_mutation(): def test_mutation():
fields = SaySomething._meta.fields fields = SaySomething._meta.fields
assert list(fields.keys()) == ["phrase", "client_mutation_id"] assert list(fields) == ["phrase", "client_mutation_id"]
assert SaySomething._meta.name == "SaySomethingPayload" assert SaySomething._meta.name == "SaySomethingPayload"
assert isinstance(fields["phrase"], Field) assert isinstance(fields["phrase"], Field)
field = SaySomething.Field() field = SaySomething.Field()
assert field.type == SaySomething assert field.type == SaySomething
assert list(field.args.keys()) == ["input"] assert list(field.args) == ["input"]
assert isinstance(field.args["input"], Argument) assert isinstance(field.args["input"], Argument)
assert isinstance(field.args["input"].type, NonNull) assert isinstance(field.args["input"].type, NonNull)
assert field.args["input"].type.of_type == SaySomething.Input assert field.args["input"].type.of_type == SaySomething.Input
@ -136,7 +135,7 @@ def test_mutation_input():
Input = SaySomething.Input Input = SaySomething.Input
assert issubclass(Input, InputObjectType) assert issubclass(Input, InputObjectType)
fields = Input._meta.fields fields = Input._meta.fields
assert list(fields.keys()) == ["what", "client_mutation_id"] assert list(fields) == ["what", "client_mutation_id"]
assert isinstance(fields["what"], InputField) assert isinstance(fields["what"], InputField)
assert fields["what"].type == String assert fields["what"].type == String
assert isinstance(fields["client_mutation_id"], InputField) assert isinstance(fields["client_mutation_id"], InputField)
@ -145,11 +144,11 @@ def test_mutation_input():
def test_subclassed_mutation(): def test_subclassed_mutation():
fields = OtherMutation._meta.fields fields = OtherMutation._meta.fields
assert list(fields.keys()) == ["name", "my_node_edge", "client_mutation_id"] assert list(fields) == ["name", "my_node_edge", "client_mutation_id"]
assert isinstance(fields["name"], Field) assert isinstance(fields["name"], Field)
field = OtherMutation.Field() field = OtherMutation.Field()
assert field.type == OtherMutation assert field.type == OtherMutation
assert list(field.args.keys()) == ["input"] assert list(field.args) == ["input"]
assert isinstance(field.args["input"], Argument) assert isinstance(field.args["input"], Argument)
assert isinstance(field.args["input"].type, NonNull) assert isinstance(field.args["input"].type, NonNull)
assert field.args["input"].type.of_type == OtherMutation.Input assert field.args["input"].type.of_type == OtherMutation.Input
@ -159,7 +158,7 @@ def test_subclassed_mutation_input():
Input = OtherMutation.Input Input = OtherMutation.Input
assert issubclass(Input, InputObjectType) assert issubclass(Input, InputObjectType)
fields = Input._meta.fields fields = Input._meta.fields
assert list(fields.keys()) == ["shared", "additional_field", "client_mutation_id"] assert list(fields) == ["shared", "additional_field", "client_mutation_id"]
assert isinstance(fields["shared"], InputField) assert isinstance(fields["shared"], InputField)
assert fields["shared"].type == String assert fields["shared"].type == String
assert isinstance(fields["additional_field"], InputField) assert isinstance(fields["additional_field"], InputField)
@ -185,12 +184,13 @@ def test_node_query_fixed():
) )
def test_node_query_promise(): @mark.asyncio
executed = schema.execute( async def test_node_query_async():
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }' executed = await schema.execute_async(
'mutation a { sayAsync(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
) )
assert not executed.errors assert not executed.errors
assert executed.data == {"sayPromise": {"phrase": "hello"}} assert executed.data == {"sayAsync": {"phrase": "hello"}}
def test_edge_query(): def test_edge_query():

View File

@ -1,12 +1,12 @@
from collections import OrderedDict
from graphql_relay import to_global_id from graphql_relay import to_global_id
from graphql.pyutils import dedent
from ...types import ObjectType, Schema, String from ...types import ObjectType, Schema, String
from ..node import Node, is_node from ..node import Node, is_node
class SharedNodeFields(object): class SharedNodeFields:
shared = String() shared = String()
something_else = String() something_else = String()
@ -70,17 +70,13 @@ def test_subclassed_node_query():
% to_global_id("MyOtherNode", 1) % to_global_id("MyOtherNode", 1)
) )
assert not executed.errors assert not executed.errors
assert executed.data == OrderedDict( assert executed.data == {
{ "node": {
"node": OrderedDict( "shared": "1",
[ "extraField": "extra field info.",
("shared", "1"), "somethingElse": "----",
("extraField", "extra field info."),
("somethingElse", "----"),
]
)
} }
) }
def test_node_requesting_non_node(): def test_node_requesting_non_node():
@ -124,7 +120,7 @@ def test_node_field_only_type_wrong():
% Node.to_global_id("MyOtherNode", 1) % Node.to_global_id("MyOtherNode", 1)
) )
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert str(executed.errors[0]) == "Must receive a MyNode id." assert str(executed.errors[0]).startswith("Must receive a MyNode id.")
assert executed.data == {"onlyNode": None} assert executed.data == {"onlyNode": None}
@ -143,39 +139,48 @@ def test_node_field_only_lazy_type_wrong():
% Node.to_global_id("MyOtherNode", 1) % Node.to_global_id("MyOtherNode", 1)
) )
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert str(executed.errors[0]) == "Must receive a MyNode id." assert str(executed.errors[0]).startswith("Must receive a MyNode id.")
assert executed.data == {"onlyNodeLazy": None} assert executed.data == {"onlyNodeLazy": None}
def test_str_schema(): def test_str_schema():
assert ( assert str(schema) == dedent(
str(schema) '''
== """ schema {
schema { query: RootQuery
query: RootQuery }
}
type MyNode implements Node { type MyNode implements Node {
id: ID! """The ID of the object"""
name: String id: ID!
} name: String
}
type MyOtherNode implements Node { type MyOtherNode implements Node {
id: ID! """The ID of the object"""
shared: String id: ID!
somethingElse: String shared: String
extraField: String somethingElse: String
} extraField: String
}
interface Node { """An object with an ID"""
id: ID! interface Node {
} """The ID of the object"""
id: ID!
}
type RootQuery { type RootQuery {
first: String first: String
node(id: ID!): Node
onlyNode(id: ID!): MyNode """The ID of the object"""
onlyNodeLazy(id: ID!): MyNode node(id: ID!): Node
}
""".lstrip() """The ID of the object"""
onlyNode(id: ID!): MyNode
"""The ID of the object"""
onlyNodeLazy(id: ID!): MyNode
}
'''
) )

View File

@ -1,4 +1,5 @@
from graphql import graphql from graphql import graphql_sync
from graphql.pyutils import dedent
from ...types import Interface, ObjectType, Schema from ...types import Interface, ObjectType, Schema
from ...types.scalars import Int, String from ...types.scalars import Int, String
@ -15,7 +16,7 @@ class CustomNode(Node):
@staticmethod @staticmethod
def get_node_from_global_id(info, id, only_type=None): def get_node_from_global_id(info, id, only_type=None):
assert info.schema == schema assert info.schema is graphql_schema
if id in user_data: if id in user_data:
return user_data.get(id) return user_data.get(id)
else: else:
@ -23,14 +24,14 @@ class CustomNode(Node):
class BasePhoto(Interface): class BasePhoto(Interface):
width = Int() width = Int(description="The width of the photo in pixels")
class User(ObjectType): class User(ObjectType):
class Meta: class Meta:
interfaces = [CustomNode] interfaces = [CustomNode]
name = String() name = String(description="The full name of the user")
class Photo(ObjectType): class Photo(ObjectType):
@ -48,37 +49,47 @@ class RootQuery(ObjectType):
schema = Schema(query=RootQuery, types=[User, Photo]) schema = Schema(query=RootQuery, types=[User, Photo])
graphql_schema = schema.graphql_schema
def test_str_schema_correct(): def test_str_schema_correct():
assert ( assert str(schema) == dedent(
str(schema) '''
== """schema { schema {
query: RootQuery query: RootQuery
} }
interface BasePhoto { interface BasePhoto {
width: Int """The width of the photo in pixels"""
} width: Int
}
interface Node { interface Node {
id: ID! """The ID of the object"""
} id: ID!
}
type Photo implements Node, BasePhoto { type Photo implements Node & BasePhoto {
id: ID! """The ID of the object"""
width: Int id: ID!
}
type RootQuery { """The width of the photo in pixels"""
node(id: ID!): Node width: Int
} }
type User implements Node { type RootQuery {
id: ID! """The ID of the object"""
name: String node(id: ID!): Node
} }
"""
type User implements Node {
"""The ID of the object"""
id: ID!
"""The full name of the user"""
name: String
}
'''
) )
@ -91,7 +102,7 @@ def test_gets_the_correct_id_for_users():
} }
""" """
expected = {"node": {"id": "1"}} expected = {"node": {"id": "1"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -105,7 +116,7 @@ def test_gets_the_correct_id_for_photos():
} }
""" """
expected = {"node": {"id": "4"}} expected = {"node": {"id": "4"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -122,7 +133,7 @@ def test_gets_the_correct_name_for_users():
} }
""" """
expected = {"node": {"id": "1", "name": "John Doe"}} expected = {"node": {"id": "1", "name": "John Doe"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -139,7 +150,7 @@ def test_gets_the_correct_width_for_photos():
} }
""" """
expected = {"node": {"id": "4", "width": 400}} expected = {"node": {"id": "4", "width": 400}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -154,7 +165,7 @@ def test_gets_the_correct_typename_for_users():
} }
""" """
expected = {"node": {"id": "1", "__typename": "User"}} expected = {"node": {"id": "1", "__typename": "User"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -169,7 +180,7 @@ def test_gets_the_correct_typename_for_photos():
} }
""" """
expected = {"node": {"id": "4", "__typename": "Photo"}} expected = {"node": {"id": "4", "__typename": "Photo"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -186,7 +197,7 @@ def test_ignores_photo_fragments_on_user():
} }
""" """
expected = {"node": {"id": "1"}} expected = {"node": {"id": "1"}}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -200,7 +211,7 @@ def test_returns_null_for_bad_ids():
} }
""" """
expected = {"node": None} expected = {"node": None}
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -239,7 +250,7 @@ def test_have_correct_node_interface():
], ],
} }
} }
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -291,6 +302,6 @@ def test_has_correct_node_root_field():
} }
} }
} }
result = graphql(schema, query) result = graphql_sync(graphql_schema, query)
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected

View File

@ -8,24 +8,19 @@ from graphene.types.schema import Schema
def default_format_error(error): def default_format_error(error):
if isinstance(error, GraphQLError): if isinstance(error, GraphQLError):
return format_graphql_error(error) return format_graphql_error(error)
return {"message": str(error)} return {"message": str(error)}
def format_execution_result(execution_result, format_error): def format_execution_result(execution_result, format_error):
if execution_result: if execution_result:
response = {} response = {}
if execution_result.errors: if execution_result.errors:
response["errors"] = [format_error(e) for e in execution_result.errors] response["errors"] = [format_error(e) for e in execution_result.errors]
response["data"] = execution_result.data
if not execution_result.invalid:
response["data"] = execution_result.data
return response return response
class Client(object): class Client:
def __init__(self, schema, format_error=None, **execute_options): def __init__(self, schema, format_error=None, **execute_options):
assert isinstance(schema, Schema) assert isinstance(schema, Schema)
self.schema = schema self.schema = schema

View File

@ -21,7 +21,7 @@ class CreatePostResult(graphene.Union):
class CreatePost(graphene.Mutation): class CreatePost(graphene.Mutation):
class Input: class Arguments:
text = graphene.String(required=True) text = graphene.String(required=True)
result = graphene.Field(CreatePostResult) result = graphene.Field(CreatePostResult)

View File

@ -1,6 +1,6 @@
# https://github.com/graphql-python/graphene/issues/356 # https://github.com/graphql-python/graphene/issues/356
import pytest from pytest import raises
import graphene import graphene
from graphene import relay from graphene import relay
@ -23,10 +23,11 @@ def test_issue():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
things = relay.ConnectionField(MyUnion) things = relay.ConnectionField(MyUnion)
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
graphene.Schema(query=Query) graphene.Schema(query=Query)
assert str(exc_info.value) == ( assert str(exc_info.value) == (
"IterableConnectionField type have to be a subclass of Connection. " "Query fields cannot be resolved:"
'Received "MyUnion".' " IterableConnectionField type has to be a subclass of Connection."
' Received "MyUnion".'
) )

View File

@ -1,5 +1,5 @@
# flake8: noqa # flake8: noqa
from graphql import ResolveInfo from graphql import GraphQLResolveInfo as ResolveInfo
from .objecttype import ObjectType from .objecttype import ObjectType
from .interface import Interface from .interface import Interface

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from itertools import chain from itertools import chain
from .dynamic import Dynamic from .dynamic import Dynamic
@ -81,7 +80,7 @@ def to_arguments(args, extra_args=None):
else: else:
extra_args = [] extra_args = []
iter_arguments = chain(args.items(), extra_args) iter_arguments = chain(args.items(), extra_args)
arguments = OrderedDict() arguments = {}
for default_name, arg in iter_arguments: for default_name, arg in iter_arguments:
if isinstance(arg, Dynamic): if isinstance(arg, Dynamic):
arg = arg.get_type() arg = arg.get_type()

View File

@ -4,7 +4,7 @@ from ..utils.subclass_with_meta import SubclassWithMeta
from ..utils.trim_docstring import trim_docstring from ..utils.trim_docstring import trim_docstring
class BaseOptions(object): class BaseOptions:
name = None # type: str name = None # type: str
description = None # type: str description = None # type: str

View File

@ -1,4 +1,4 @@
class Context(object): class Context:
""" """
Context can be used to make a convenient container for attributes to provide Context can be used to make a convenient container for attributes to provide
for execution for resolvers of a GraphQL operation like a query. for execution for resolvers of a GraphQL operation like a query.

View File

@ -3,7 +3,8 @@ from __future__ import absolute_import
import datetime import datetime
from aniso8601 import parse_date, parse_datetime, parse_time from aniso8601 import parse_date, parse_datetime, parse_time
from graphql.language import ast from graphql.error import INVALID
from graphql.language import StringValueNode
from .scalars import Scalar from .scalars import Scalar
@ -26,7 +27,7 @@ class Date(Scalar):
@classmethod @classmethod
def parse_literal(cls, node): def parse_literal(cls, node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return cls.parse_value(node.value) return cls.parse_value(node.value)
@staticmethod @staticmethod
@ -37,7 +38,7 @@ class Date(Scalar):
elif isinstance(value, str): elif isinstance(value, str):
return parse_date(value) return parse_date(value)
except ValueError: except ValueError:
return None return INVALID
class DateTime(Scalar): class DateTime(Scalar):
@ -56,7 +57,7 @@ class DateTime(Scalar):
@classmethod @classmethod
def parse_literal(cls, node): def parse_literal(cls, node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return cls.parse_value(node.value) return cls.parse_value(node.value)
@staticmethod @staticmethod
@ -67,7 +68,7 @@ class DateTime(Scalar):
elif isinstance(value, str): elif isinstance(value, str):
return parse_datetime(value) return parse_datetime(value)
except ValueError: except ValueError:
return None return INVALID
class Time(Scalar): class Time(Scalar):
@ -86,7 +87,7 @@ class Time(Scalar):
@classmethod @classmethod
def parse_literal(cls, node): def parse_literal(cls, node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return cls.parse_value(node.value) return cls.parse_value(node.value)
@classmethod @classmethod
@ -97,4 +98,4 @@ class Time(Scalar):
elif isinstance(value, str): elif isinstance(value, str):
return parse_time(value) return parse_time(value)
except ValueError: except ValueError:
return None return INVALID

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
from decimal import Decimal as _Decimal from decimal import Decimal as _Decimal
from graphql.language import ast from graphql.language.ast import StringValueNode
from .scalars import Scalar from .scalars import Scalar
@ -23,7 +23,7 @@ class Decimal(Scalar):
@classmethod @classmethod
def parse_literal(cls, node): def parse_literal(cls, node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return cls.parse_value(node.value) return cls.parse_value(node.value)
@staticmethod @staticmethod

View File

@ -8,7 +8,7 @@ from graphql import (
) )
class GrapheneGraphQLType(object): class GrapheneGraphQLType:
""" """
A class for extending the base GraphQLType with the related A class for extending the base GraphQLType with the related
graphene_type graphene_type

View File

@ -1,7 +1,7 @@
from collections import OrderedDict from enum import Enum as PyEnum
from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta
from ..pyutils.compat import Enum as PyEnum
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
@ -22,13 +22,13 @@ class EnumOptions(BaseOptions):
class EnumMeta(SubclassWithMeta_Meta): class EnumMeta(SubclassWithMeta_Meta):
def __new__(cls, name, bases, classdict, **options): def __new__(cls, name, bases, classdict, **options):
enum_members = OrderedDict(classdict, __eq__=eq_enum) enum_members = dict(classdict, __eq__=eq_enum)
# We remove the Meta attribute from the class to not collide # We remove the Meta attribute from the class to not collide
# with the enum values. # with the enum values.
enum_members.pop("Meta", None) enum_members.pop("Meta", None)
enum = PyEnum(cls.__name__, enum_members) enum = PyEnum(cls.__name__, enum_members)
return SubclassWithMeta_Meta.__new__( return SubclassWithMeta_Meta.__new__(
cls, name, bases, OrderedDict(classdict, __enum__=enum), **options cls, name, bases, dict(classdict, __enum__=enum), **options
) )
def get(cls, value): def get(cls, value):
@ -38,7 +38,7 @@ class EnumMeta(SubclassWithMeta_Meta):
return cls._meta.enum[value] return cls._meta.enum[value]
def __prepare__(name, bases, **kwargs): # noqa: N805 def __prepare__(name, bases, **kwargs): # noqa: N805
return OrderedDict() return {}
def __call__(cls, *args, **kwargs): # noqa: N805 def __call__(cls, *args, **kwargs): # noqa: N805
if cls is Enum: if cls is Enum:

View File

@ -1,5 +1,5 @@
import inspect import inspect
from collections import Mapping, OrderedDict from collections.abc import Mapping
from functools import partial from functools import partial
from .argument import Argument, to_arguments from .argument import Argument, to_arguments
@ -100,7 +100,7 @@ class Field(MountedType):
self.name = name self.name = name
self._type = type self._type = type
self.args = to_arguments(args or OrderedDict(), extra_args) self.args = to_arguments(args or {}, extra_args)
if source: if source:
resolver = partial(source_resolver, source) resolver = partial(source_resolver, source)
self.resolver = resolver self.resolver = resolver

View File

@ -1,12 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from graphql.language.ast import ( from graphql.language.ast import (
BooleanValue, BooleanValueNode,
FloatValue, FloatValueNode,
IntValue, IntValueNode,
ListValue, ListValueNode,
ObjectValue, ObjectValueNode,
StringValue, StringValueNode,
) )
from graphene.types.scalars import MAX_INT, MIN_INT from graphene.types.scalars import MAX_INT, MIN_INT
@ -30,17 +30,17 @@ class GenericScalar(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, (StringValue, BooleanValue)): if isinstance(ast, (StringValueNode, BooleanValueNode)):
return ast.value return ast.value
elif isinstance(ast, IntValue): elif isinstance(ast, IntValueNode):
num = int(ast.value) num = int(ast.value)
if MIN_INT <= num <= MAX_INT: if MIN_INT <= num <= MAX_INT:
return num return num
elif isinstance(ast, FloatValue): elif isinstance(ast, FloatValueNode):
return float(ast.value) return float(ast.value)
elif isinstance(ast, ListValue): elif isinstance(ast, ListValueNode):
return [GenericScalar.parse_literal(value) for value in ast.values] return [GenericScalar.parse_literal(value) for value in ast.values]
elif isinstance(ast, ObjectValue): elif isinstance(ast, ObjectValueNode):
return { return {
field.name.value: GenericScalar.parse_literal(field.value) field.name.value: GenericScalar.parse_literal(field.value)
for field in ast.fields for field in ast.fields

View File

@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .inputfield import InputField from .inputfield import InputField
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
@ -22,7 +20,7 @@ class InputObjectTypeContainer(dict, BaseType):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs) dict.__init__(self, *args, **kwargs)
for key in self._meta.fields.keys(): for key in self._meta.fields:
setattr(self, key, self.get(key, None)) setattr(self, key, self.get(key, None))
def __init_subclass__(cls, *args, **kwargs): def __init_subclass__(cls, *args, **kwargs):
@ -70,7 +68,7 @@ class InputObjectType(UnmountedType, BaseType):
if not _meta: if not _meta:
_meta = InputObjectTypeOptions(cls) _meta = InputObjectTypeOptions(cls)
fields = OrderedDict() fields = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField)) fields.update(yank_fields_from_attrs(base.__dict__, _as=InputField))

View File

@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .field import Field from .field import Field
from .utils import yank_fields_from_attrs from .utils import yank_fields_from_attrs
@ -51,7 +49,7 @@ class Interface(BaseType):
if not _meta: if not _meta:
_meta = InterfaceOptions(cls) _meta = InterfaceOptions(cls)
fields = OrderedDict() fields = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field)) fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
import json import json
from graphql.language import ast from graphql.language.ast import StringValueNode
from .scalars import Scalar from .scalars import Scalar
@ -21,7 +21,7 @@ class JSONString(Scalar):
@staticmethod @staticmethod
def parse_literal(node): def parse_literal(node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return json.loads(node.value) return json.loads(node.value)
@staticmethod @staticmethod

View File

@ -1,5 +1,3 @@
from collections import OrderedDict
from ..utils.deprecated import warn_deprecation from ..utils.deprecated import warn_deprecation
from ..utils.get_unbound_function import get_unbound_function from ..utils.get_unbound_function import get_unbound_function
from ..utils.props import props from ..utils.props import props
@ -90,7 +88,7 @@ class Mutation(ObjectType):
if not output: if not output:
# If output is defined, we don't need to get the fields # If output is defined, we don't need to get the fields
fields = OrderedDict() fields = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
fields.update(yank_fields_from_attrs(base.__dict__, _as=Field)) fields.update(yank_fields_from_attrs(base.__dict__, _as=Field))
output = cls output = cls
@ -103,7 +101,7 @@ class Mutation(ObjectType):
warn_deprecation( warn_deprecation(
( (
"Please use {name}.Arguments instead of {name}.Input." "Please use {name}.Arguments instead of {name}.Input."
"Input is now only used in ClientMutationID.\n" " Input is now only used in ClientMutationID.\n"
"Read more:" "Read more:"
" https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input" " https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#mutation-input"
).format(name=cls.__name__) ).format(name=cls.__name__)

View File

@ -1,5 +1,3 @@
from collections import OrderedDict
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .field import Field from .field import Field
from .interface import Interface from .interface import Interface
@ -100,7 +98,7 @@ class ObjectType(BaseType):
if not _meta: if not _meta:
_meta = ObjectTypeOptions(cls) _meta = ObjectTypeOptions(cls)
fields = OrderedDict() fields = {}
for interface in interfaces: for interface in interfaces:
assert issubclass(interface, Interface), ( assert issubclass(interface, Interface), (

View File

@ -1,6 +1,11 @@
from typing import Any from typing import Any
from graphql.language.ast import BooleanValue, FloatValue, IntValue, StringValue from graphql.language.ast import (
BooleanValueNode,
FloatValueNode,
IntValueNode,
StringValueNode,
)
from .base import BaseOptions, BaseType from .base import BaseOptions, BaseType
from .unmountedtype import UnmountedType from .unmountedtype import UnmountedType
@ -71,7 +76,7 @@ class Int(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, IntValue): if isinstance(ast, IntValueNode):
num = int(ast.value) num = int(ast.value)
if MIN_INT <= num <= MAX_INT: if MIN_INT <= num <= MAX_INT:
return num return num
@ -97,7 +102,7 @@ class Float(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, (FloatValue, IntValue)): if isinstance(ast, (FloatValueNode, IntValueNode)):
return float(ast.value) return float(ast.value)
@ -119,7 +124,7 @@ class String(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, StringValue): if isinstance(ast, StringValueNode):
return ast.value return ast.value
@ -133,7 +138,7 @@ class Boolean(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, BooleanValue): if isinstance(ast, BooleanValueNode):
return ast.value return ast.value
@ -151,5 +156,5 @@ class ID(Scalar):
@staticmethod @staticmethod
def parse_literal(ast): def parse_literal(ast):
if isinstance(ast, (StringValue, IntValue)): if isinstance(ast, (StringValueNode, IntValueNode)):
return ast.value return ast.value

View File

@ -1,104 +1,133 @@
import inspect import inspect
from functools import partial
from graphql import GraphQLObjectType, GraphQLSchema, graphql, is_type from graphql import (
from graphql.type.directives import ( default_type_resolver,
GraphQLDirective, get_introspection_query,
GraphQLIncludeDirective, graphql,
GraphQLSkipDirective, graphql_sync,
introspection_types,
is_type,
print_schema,
GraphQLArgument,
GraphQLBoolean,
GraphQLEnumValue,
GraphQLField,
GraphQLFloat,
GraphQLID,
GraphQLInputField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLString,
INVALID,
) )
from graphql.type.introspection import IntrospectionSchema
from graphql.utils.introspection_query import introspection_query
from graphql.utils.schema_printer import print_schema
from .definitions import GrapheneGraphQLType from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function
from .definitions import (
GrapheneEnumType,
GrapheneGraphQLType,
GrapheneInputObjectType,
GrapheneInterfaceType,
GrapheneObjectType,
GrapheneScalarType,
GrapheneUnionType,
)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
from .inputobjecttype import InputObjectType
from .interface import Interface
from .objecttype import ObjectType from .objecttype import ObjectType
from .typemap import TypeMap, is_graphene_type from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull
from .union import Union
from .utils import get_field_as
introspection_query = get_introspection_query()
IntrospectionSchema = introspection_types["__Schema"]
def assert_valid_root_type(_type): def assert_valid_root_type(type_):
if _type is None: if type_ is None:
return return
is_graphene_objecttype = inspect.isclass(_type) and issubclass(_type, ObjectType) is_graphene_objecttype = inspect.isclass(type_) and issubclass(type_, ObjectType)
is_graphql_objecttype = isinstance(_type, GraphQLObjectType) is_graphql_objecttype = isinstance(type_, GraphQLObjectType)
assert is_graphene_objecttype or is_graphql_objecttype, ( assert is_graphene_objecttype or is_graphql_objecttype, (
"Type {} is not a valid ObjectType." "Type {} is not a valid ObjectType."
).format(_type) ).format(type_)
class Schema(GraphQLSchema): def is_graphene_type(type_):
""" if isinstance(type_, (List, NonNull)):
Graphene Schema can execute operations (query, mutation, subscription) against the defined return True
types. if inspect.isclass(type_) and issubclass(
type_, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
):
return True
For advanced purposes, the schema can be used to lookup type definitions and answer questions
about the types through introspection.
Args: def resolve_type(resolve_type_func, map_, type_name, root, info, _type):
query (ObjectType): Root query *ObjectType*. Describes entry point for fields to *read* type_ = resolve_type_func(root, info)
data in your Schema.
mutation (ObjectType, optional): Root mutation *ObjectType*. Describes entry point for if not type_:
fields to *create, update or delete* data in your API. return_type = map_[type_name]
subscription (ObjectType, optional): Root subscription *ObjectType*. Describes entry point return default_type_resolver(root, info, return_type)
for fields to receive continuous updates.
directives (List[GraphQLDirective], optional): List of custom directives to include in if inspect.isclass(type_) and issubclass(type_, ObjectType):
GraphQL schema. Defaults to only include directives definved by GraphQL spec (@include graphql_type = map_.get(type_._meta.name)
and @skip) [GraphQLIncludeDirective, GraphQLSkipDirective]. assert graphql_type, "Can't find type {} in schema".format(type_._meta.name)
types (List[GraphQLType], optional): List of any types to include in schema that assert graphql_type.graphene_type == type_, (
may not be introspected through root types. "The type {} does not match with the associated graphene type {}."
auto_camelcase (bool): Fieldnames will be transformed in Schema's TypeMap from snake_case ).format(type_, graphql_type.graphene_type)
to camelCase (preferred by GraphQL standard). Default True. return graphql_type
"""
return type_
def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types)
class GrapheneGraphQLSchema(GraphQLSchema):
"""A GraphQLSchema that can deal with Graphene types as well."""
def __init__( def __init__(
self, self,
query=None, query=None,
mutation=None, mutation=None,
subscription=None, subscription=None,
directives=None,
types=None, types=None,
directives=None,
auto_camelcase=True, auto_camelcase=True,
): ):
assert_valid_root_type(query) assert_valid_root_type(query)
assert_valid_root_type(mutation) assert_valid_root_type(mutation)
assert_valid_root_type(subscription) assert_valid_root_type(subscription)
self._query = query
self._mutation = mutation
self._subscription = subscription
self.types = types
self.auto_camelcase = auto_camelcase self.auto_camelcase = auto_camelcase
if directives is None: super().__init__(query, mutation, subscription, types, directives)
directives = [GraphQLIncludeDirective, GraphQLSkipDirective]
assert all( if query:
isinstance(d, GraphQLDirective) for d in directives self.query_type = self.get_type(
), "Schema directives must be List[GraphQLDirective] if provided but got: {}.".format( query.name if isinstance(query, GraphQLObjectType) else query._meta.name
directives )
) if mutation:
self._directives = directives self.mutation_type = self.get_type(
self.build_typemap() mutation.name
if isinstance(mutation, GraphQLObjectType)
def get_query_type(self): else mutation._meta.name
return self.get_graphql_type(self._query) )
if subscription:
def get_mutation_type(self): self.subscription_type = self.get_type(
return self.get_graphql_type(self._mutation) subscription.name
if isinstance(subscription, GraphQLObjectType)
def get_subscription_type(self): else subscription._meta.name
return self.get_graphql_type(self._subscription) )
def __getattr__(self, type_name):
"""
This function let the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
"""
_type = super(Schema, self).get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
def get_graphql_type(self, _type): def get_graphql_type(self, _type):
if not _type: if not _type:
@ -114,56 +143,383 @@ class Schema(GraphQLSchema):
return graphql_type return graphql_type
raise Exception("{} is not a valid GraphQL type.".format(_type)) raise Exception("{} is not a valid GraphQL type.".format(_type))
def execute(self, *args, **kwargs): # noinspection PyMethodOverriding
""" def type_map_reducer(self, map_, type_):
Use the `graphql` function from `graphql-core` to provide the result for a query string. if not type_:
Most of the time this method will be called by one of the Graphene :ref:`Integrations` return map_
via a web request. if inspect.isfunction(type_):
type_ = type_()
if is_graphene_type(type_):
return self.graphene_reducer(map_, type_)
return super().type_map_reducer(map_, type_)
Args: def graphene_reducer(self, map_, type_):
request_string (str or Document): GraphQL request (query, mutation or subscription) in if isinstance(type_, (List, NonNull)):
string or parsed AST form from `graphql-core`. return self.type_map_reducer(map_, type_.of_type)
root (Any, optional): Value to use as the parent value object when resolving root if type_._meta.name in map_:
types. _type = map_[type_._meta.name]
context (Any, optional): Value to be made avaiable to all resolvers via if isinstance(_type, GrapheneGraphQLType):
`info.context`. Can be used to share authorization, dataloaders or other assert _type.graphene_type == type_, (
information needed to resolve an operation. "Found different types with the same name in the schema: {}, {}."
variables (dict, optional): If variables are used in the request string, they can be ).format(_type.graphene_type, type_)
provided in dictionary form mapping the variable name to the variable value. return map_
operation_name (str, optional): If mutiple operations are provided in the
request_string, an operation name must be provided for the result to be provided.
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
defined in `graphql-core`.
backend (GraphQLCoreBackend, optional): Override the default GraphQLCoreBackend.
**execute_options (Any): Depends on backend selected. Default backend has several
options such as: validate, allow_subscriptions, return_promise, executor.
Returns: if issubclass(type_, ObjectType):
:obj:`ExecutionResult` containing any data and errors for the operation. internal_type = self.construct_objecttype(map_, type_)
""" elif issubclass(type_, InputObjectType):
return graphql(self, *args, **kwargs) internal_type = self.construct_inputobjecttype(map_, type_)
elif issubclass(type_, Interface):
internal_type = self.construct_interface(map_, type_)
elif issubclass(type_, Scalar):
internal_type = self.construct_scalar(type_)
elif issubclass(type_, Enum):
internal_type = self.construct_enum(type_)
elif issubclass(type_, Union):
internal_type = self.construct_union(map_, type_)
else:
raise Exception("Expected Graphene type, but received: {}.".format(type_))
def introspect(self): return super().type_map_reducer(map_, internal_type)
instrospection = self.execute(introspection_query)
if instrospection.errors: @staticmethod
raise instrospection.errors[0] def construct_scalar(type_):
return instrospection.data # We have a mapping to the original GraphQL types
# so there are no collisions.
_scalars = {
String: GraphQLString,
Int: GraphQLInt,
Float: GraphQLFloat,
Boolean: GraphQLBoolean,
ID: GraphQLID,
}
if type_ in _scalars:
return _scalars[type_]
return GrapheneScalarType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
serialize=getattr(type_, "serialize", None),
parse_value=getattr(type_, "parse_value", None),
parse_literal=getattr(type_, "parse_literal", None),
)
@staticmethod
def construct_enum(type_):
values = {}
for name, value in type_._meta.enum.__members__.items():
description = getattr(value, "description", None)
deprecation_reason = getattr(value, "deprecation_reason", None)
if not description and callable(type_._meta.description):
description = type_._meta.description(value)
if not deprecation_reason and callable(type_._meta.deprecation_reason):
deprecation_reason = type_._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue(
value=value.value,
description=description,
deprecation_reason=deprecation_reason,
)
type_description = (
type_._meta.description(None)
if callable(type_._meta.description)
else type_._meta.description
)
return GrapheneEnumType(
graphene_type=type_,
values=values,
name=type_._meta.name,
description=type_description,
)
def construct_objecttype(self, map_, type_):
if type_._meta.name in map_:
_type = map_[type_._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type_, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type_)
return _type
def interfaces():
interfaces = []
for interface in type_._meta.interfaces:
self.graphene_reducer(map_, interface)
internal_type = map_[interface._meta.name]
assert internal_type.graphene_type == interface
interfaces.append(internal_type)
return interfaces
if type_._meta.possible_types:
is_type_of = partial(
is_type_of_from_possible_types, type_._meta.possible_types
)
else:
is_type_of = type_.is_type_of
return GrapheneObjectType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
fields=partial(self.construct_fields_for_type, map_, type_),
is_type_of=is_type_of,
interfaces=interfaces,
)
def construct_interface(self, map_, type_):
if type_._meta.name in map_:
_type = map_[type_._meta.name]
if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type_, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type_)
return _type
_resolve_type = None
if type_.resolve_type:
_resolve_type = partial(
resolve_type, type_.resolve_type, map_, type_._meta.name
)
return GrapheneInterfaceType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
fields=partial(self.construct_fields_for_type, map_, type_),
resolve_type=_resolve_type,
)
def construct_inputobjecttype(self, map_, type_):
return GrapheneInputObjectType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
out_type=type_._meta.container,
fields=partial(
self.construct_fields_for_type, map_, type_, is_input_type=True
),
)
def construct_union(self, map_, type_):
_resolve_type = None
if type_.resolve_type:
_resolve_type = partial(
resolve_type, type_.resolve_type, map_, type_._meta.name
)
def types():
union_types = []
for objecttype in type_._meta.types:
self.graphene_reducer(map_, objecttype)
internal_type = map_[objecttype._meta.name]
assert internal_type.graphene_type == objecttype
union_types.append(internal_type)
return union_types
return GrapheneUnionType(
graphene_type=type_,
name=type_._meta.name,
description=type_._meta.description,
types=types,
resolve_type=_resolve_type,
)
def get_name(self, name):
if self.auto_camelcase:
return to_camel_case(name)
return name
def construct_fields_for_type(self, map_, type_, is_input_type=False):
fields = {}
for name, field in type_._meta.fields.items():
if isinstance(field, Dynamic):
field = get_field_as(field.get_type(self), _as=Field)
if not field:
continue
map_ = self.type_map_reducer(map_, field.type)
field_type = self.get_field_type(map_, field.type)
if is_input_type:
_field = GraphQLInputField(
field_type,
default_value=field.default_value,
out_name=name,
description=field.description,
)
else:
args = {}
for arg_name, arg in field.args.items():
map_ = self.type_map_reducer(map_, arg.type)
arg_type = self.get_field_type(map_, arg.type)
processed_arg_name = arg.name or self.get_name(arg_name)
args[processed_arg_name] = GraphQLArgument(
arg_type,
out_name=arg_name,
description=arg.description,
default_value=INVALID
if isinstance(arg.type, NonNull)
else arg.default_value,
)
_field = GraphQLField(
field_type,
args=args,
resolve=field.get_resolver(
self.get_resolver_for_type(type_, name, field.default_value)
),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
field_name = field.name or self.get_name(name)
fields[field_name] = _field
return fields
def get_resolver_for_type(self, type_, name, default_value):
if not issubclass(type_, ObjectType):
return
resolver = getattr(type_, "resolve_{}".format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type_._meta.interfaces:
if name not in interface._meta.fields:
continue
interface_resolver = getattr(interface, "resolve_{}".format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if is not decorated with classmethod
if resolver:
return get_unbound_function(resolver)
default_resolver = type_._meta.default_resolver or get_default_resolver()
return partial(default_resolver, name, default_value)
def get_field_type(self, map_, type_):
if isinstance(type_, List):
return GraphQLList(self.get_field_type(map_, type_.of_type))
if isinstance(type_, NonNull):
return GraphQLNonNull(self.get_field_type(map_, type_.of_type))
return map_.get(type_._meta.name)
class Schema:
"""Schema Definition.
A Graphene Schema can execute operations (query, mutation, subscription) against the defined
types. For advanced purposes, the schema can be used to lookup type definitions and answer
questions about the types through introspection.
Args:
query (ObjectType): Root query *ObjectType*. Describes entry point for fields to *read*
data in your Schema.
mutation (ObjectType, optional): Root mutation *ObjectType*. Describes entry point for
fields to *create, update or delete* data in your API.
subscription (ObjectType, optional): Root subscription *ObjectType*. Describes entry point
for fields to receive continuous updates.
directives (List[GraphQLDirective], optional): List of custom directives to include in the
GraphQL schema. Defaults to only include directives defined by GraphQL spec (@include
and @skip) [GraphQLIncludeDirective, GraphQLSkipDirective].
types (List[GraphQLType], optional): List of any types to include in schema that
may not be introspected through root types.
auto_camelcase (bool): Fieldnames will be transformed in Schema's TypeMap from snake_case
to camelCase (preferred by GraphQL standard). Default True.
"""
def __init__(
self,
query=None,
mutation=None,
subscription=None,
types=None,
directives=None,
auto_camelcase=True,
):
self.query = query
self.mutation = mutation
self.subscription = subscription
self.graphql_schema = GrapheneGraphQLSchema(
query,
mutation,
subscription,
types,
directives,
auto_camelcase=auto_camelcase,
)
def __str__(self): def __str__(self):
return print_schema(self) return print_schema(self.graphql_schema)
def __getattr__(self, type_name):
"""
This function let the developer select a type in a given schema
by accessing its attrs.
Example: using schema.Query for accessing the "Query" type in the Schema
"""
_type = self.graphql_schema.get_type(type_name)
if _type is None:
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
if isinstance(_type, GrapheneGraphQLType):
return _type.graphene_type
return _type
def lazy(self, _type): def lazy(self, _type):
return lambda: self.get_type(_type) return lambda: self.get_type(_type)
def build_typemap(self): def execute(self, *args, **kwargs):
initial_types = [ """Execute a GraphQL query on the schema.
self._query,
self._mutation, Use the `graphql_sync` function from `graphql-core` to provide the result
self._subscription, for a query string. Most of the time this method will be called by one of the Graphene
IntrospectionSchema, :ref:`Integrations` via a web request.
]
if self.types: Args:
initial_types += self.types request_string (str or Document): GraphQL request (query, mutation or subscription)
self._type_map = TypeMap( as string or parsed AST form from `graphql-core`.
initial_types, auto_camelcase=self.auto_camelcase, schema=self root_value (Any, optional): Value to use as the parent value object when resolving
) root types.
context_value (Any, optional): Value to be made avaiable to all resolvers via
`info.context`. Can be used to share authorization, dataloaders or other
information needed to resolve an operation.
variable_values (dict, optional): If variables are used in the request string, they can
be provided in dictionary form mapping the variable name to the variable value.
operation_name (str, optional): If multiple operations are provided in the
request_string, an operation name must be provided for the result to be provided.
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
defined in `graphql-core`.
Returns:
:obj:`ExecutionResult` containing any data and errors for the operation.
"""
kwargs = normalize_execute_kwargs(kwargs)
return graphql_sync(self.graphql_schema, *args, **kwargs)
async def execute_async(self, *args, **kwargs):
"""Execute a GraphQL query on the schema asynchronously.
Same as `execute`, but uses `graphql` instead of `graphql_sync`.
"""
kwargs = normalize_execute_kwargs(kwargs)
return await graphql(self.graphql_schema, *args, **kwargs)
def introspect(self):
introspection = self.execute(introspection_query)
if introspection.errors:
raise introspection.errors[0]
return introspection.data
def normalize_execute_kwargs(kwargs):
"""Replace alias names in keyword arguments for graphql()"""
if "root" in kwargs and "root_value" not in kwargs:
kwargs["root_value"] = kwargs.pop("root")
if "context" in kwargs and "context_value" not in kwargs:
kwargs["context_value"] = kwargs.pop("context")
if "variables" in kwargs and "variable_values" not in kwargs:
kwargs["variable_values"] = kwargs.pop("variables")
if "operation" in kwargs and "operation_name" not in kwargs:
kwargs["operation_name"] = kwargs.pop("operation")
return kwargs

View File

@ -1,4 +1,5 @@
from .. import abstracttype from pytest import deprecated_call
from ..abstracttype import AbstractType from ..abstracttype import AbstractType
from ..field import Field from ..field import Field
from ..objecttype import ObjectType from ..objecttype import ObjectType
@ -14,24 +15,25 @@ class MyScalar(UnmountedType):
return MyType return MyType
def test_abstract_objecttype_warn_deprecation(mocker): def test_abstract_objecttype_warn_deprecation():
mocker.patch.object(abstracttype, "warn_deprecation") with deprecated_call():
class MyAbstractType(AbstractType): # noinspection PyUnusedLocal
field1 = MyScalar() class MyAbstractType(AbstractType):
field1 = MyScalar()
assert abstracttype.warn_deprecation.called
def test_generate_objecttype_inherit_abstracttype(): def test_generate_objecttype_inherit_abstracttype():
class MyAbstractType(AbstractType): with deprecated_call():
field1 = MyScalar()
class MyObjectType(ObjectType, MyAbstractType): class MyAbstractType(AbstractType):
field2 = MyScalar() field1 = MyScalar()
class MyObjectType(ObjectType, MyAbstractType):
field2 = MyScalar()
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"] assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field] assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]

View File

@ -1,6 +1,6 @@
from functools import partial from functools import partial
import pytest from pytest import raises
from ..argument import Argument, to_arguments from ..argument import Argument, to_arguments
from ..field import Field from ..field import Field
@ -43,7 +43,7 @@ def test_to_arguments():
def test_to_arguments_raises_if_field(): def test_to_arguments_raises_if_field():
args = {"arg_string": Field(String)} args = {"arg_string": Field(String)}
with pytest.raises(ValueError) as exc_info: with raises(ValueError) as exc_info:
to_arguments(args) to_arguments(args)
assert str(exc_info.value) == ( assert str(exc_info.value) == (
@ -55,7 +55,7 @@ def test_to_arguments_raises_if_field():
def test_to_arguments_raises_if_inputfield(): def test_to_arguments_raises_if_inputfield():
args = {"arg_string": InputField(String)} args = {"arg_string": InputField(String)}
with pytest.raises(ValueError) as exc_info: with raises(ValueError) as exc_info:
to_arguments(args) to_arguments(args)
assert str(exc_info.value) == ( assert str(exc_info.value) == (

View File

@ -2,7 +2,8 @@ import datetime
import pytz import pytz
from graphql import GraphQLError from graphql import GraphQLError
import pytest
from pytest import fixture, mark
from ..datetime import Date, DateTime, Time from ..datetime import Date, DateTime, Time
from ..objecttype import ObjectType from ..objecttype import ObjectType
@ -27,13 +28,13 @@ class Query(ObjectType):
schema = Schema(query=Query) schema = Schema(query=Query)
@pytest.fixture @fixture
def sample_datetime(): def sample_datetime():
utc_datetime = datetime.datetime(2019, 5, 25, 5, 30, 15, 10, pytz.utc) utc_datetime = datetime.datetime(2019, 5, 25, 5, 30, 15, 10, pytz.utc)
return utc_datetime return utc_datetime
@pytest.fixture @fixture
def sample_time(sample_datetime): def sample_time(sample_datetime):
time = datetime.time( time = datetime.time(
sample_datetime.hour, sample_datetime.hour,
@ -45,7 +46,7 @@ def sample_time(sample_datetime):
return time return time
@pytest.fixture @fixture
def sample_date(sample_datetime): def sample_date(sample_datetime):
date = sample_datetime.date() date = sample_datetime.date()
return date return date
@ -76,12 +77,16 @@ def test_time_query(sample_time):
def test_bad_datetime_query(): def test_bad_datetime_query():
not_a_date = "Some string that's not a date" not_a_date = "Some string that's not a datetime"
result = schema.execute("""{ datetime(in: "%s") }""" % not_a_date) result = schema.execute("""{ datetime(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1 assert result.errors and len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError) error = result.errors[0]
assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type DateTime, found "Some string that\'s not a datetime".'
)
assert result.data is None assert result.data is None
@ -90,18 +95,24 @@ def test_bad_date_query():
result = schema.execute("""{ date(in: "%s") }""" % not_a_date) result = schema.execute("""{ date(in: "%s") }""" % not_a_date)
assert len(result.errors) == 1 error = result.errors[0]
assert isinstance(result.errors[0], GraphQLError) assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type Date, found "Some string that\'s not a date".'
)
assert result.data is None assert result.data is None
def test_bad_time_query(): def test_bad_time_query():
not_a_date = "Some string that's not a date" not_a_date = "Some string that's not a time"
result = schema.execute("""{ time(at: "%s") }""" % not_a_date) result = schema.execute("""{ time(at: "%s") }""" % not_a_date)
assert len(result.errors) == 1 error = result.errors[0]
assert isinstance(result.errors[0], GraphQLError) assert isinstance(error, GraphQLError)
assert error.message == (
'Expected type Time, found "Some string that\'s not a time".'
)
assert result.data is None assert result.data is None
@ -163,7 +174,7 @@ def test_time_query_variable(sample_time):
assert result.data == {"time": isoformat} assert result.data == {"time": isoformat}
@pytest.mark.xfail( @mark.xfail(
reason="creating the error message fails when un-parsable object is not JSON serializable." reason="creating the error message fails when un-parsable object is not JSON serializable."
) )
def test_bad_variables(sample_date, sample_datetime, sample_time): def test_bad_variables(sample_date, sample_datetime, sample_time):
@ -174,11 +185,11 @@ def test_bad_variables(sample_date, sample_datetime, sample_time):
), ),
variables={"input": input_}, variables={"input": input_},
) )
assert len(result.errors) == 1
# when `input` is not JSON serializable formatting the error message in # when `input` is not JSON serializable formatting the error message in
# `graphql.utils.is_valid_value` line 79 fails with a TypeError # `graphql.utils.is_valid_value` line 79 fails with a TypeError
assert isinstance(result.errors, list)
assert len(result.errors) == 1
assert isinstance(result.errors[0], GraphQLError) assert isinstance(result.errors[0], GraphQLError)
print(result.errors[0])
assert result.data is None assert result.data is None
not_a_date = dict() not_a_date = dict()

View File

@ -69,7 +69,8 @@ class MyInputObjectType(InputObjectType):
def test_defines_a_query_only_schema(): def test_defines_a_query_only_schema():
blog_schema = Schema(Query) blog_schema = Schema(Query)
assert blog_schema.get_query_type().graphene_type == Query assert blog_schema.query == Query
assert blog_schema.graphql_schema.query_type.graphene_type == Query
article_field = Query._meta.fields["article"] article_field = Query._meta.fields["article"]
assert article_field.type == Article assert article_field.type == Article
@ -95,7 +96,8 @@ def test_defines_a_query_only_schema():
def test_defines_a_mutation_schema(): def test_defines_a_mutation_schema():
blog_schema = Schema(Query, mutation=Mutation) blog_schema = Schema(Query, mutation=Mutation)
assert blog_schema.get_mutation_type().graphene_type == Mutation assert blog_schema.mutation == Mutation
assert blog_schema.graphql_schema.mutation_type.graphene_type == Mutation
write_mutation = Mutation._meta.fields["write_article"] write_mutation = Mutation._meta.fields["write_article"]
assert write_mutation.type == Article assert write_mutation.type == Article
@ -105,7 +107,8 @@ def test_defines_a_mutation_schema():
def test_defines_a_subscription_schema(): def test_defines_a_subscription_schema():
blog_schema = Schema(Query, subscription=Subscription) blog_schema = Schema(Query, subscription=Subscription)
assert blog_schema.get_subscription_type().graphene_type == Subscription assert blog_schema.subscription == Subscription
assert blog_schema.graphql_schema.subscription_type.graphene_type == Subscription
subscription = Subscription._meta.fields["article_subscribe"] subscription = Subscription._meta.fields["article_subscribe"]
assert subscription.type == Article assert subscription.type == Article
@ -126,8 +129,9 @@ def test_includes_nested_input_objects_in_the_map():
subscribe_to_something = Field(Article, input=Argument(SomeInputObject)) subscribe_to_something = Field(Article, input=Argument(SomeInputObject))
schema = Schema(query=Query, mutation=SomeMutation, subscription=SomeSubscription) schema = Schema(query=Query, mutation=SomeMutation, subscription=SomeSubscription)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["NestedInputObject"].graphene_type is NestedInputObject assert type_map["NestedInputObject"].graphene_type is NestedInputObject
def test_includes_interfaces_thunk_subtypes_in_the_type_map(): def test_includes_interfaces_thunk_subtypes_in_the_type_map():
@ -142,8 +146,9 @@ def test_includes_interfaces_thunk_subtypes_in_the_type_map():
iface = Field(lambda: SomeInterface) iface = Field(lambda: SomeInterface)
schema = Schema(query=Query, types=[SomeSubtype]) schema = Schema(query=Query, types=[SomeSubtype])
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype assert type_map["SomeSubtype"].graphene_type is SomeSubtype
def test_includes_types_in_union(): def test_includes_types_in_union():
@ -161,9 +166,10 @@ def test_includes_types_in_union():
union = Field(MyUnion) union = Field(MyUnion)
schema = Schema(query=Query) schema = Schema(query=Query)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["OtherType"].graphene_type is OtherType assert type_map["OtherType"].graphene_type is OtherType
assert schema.get_type_map()["SomeType"].graphene_type is SomeType assert type_map["SomeType"].graphene_type is SomeType
def test_maps_enum(): def test_maps_enum():
@ -181,9 +187,10 @@ def test_maps_enum():
union = Field(MyUnion) union = Field(MyUnion)
schema = Schema(query=Query) schema = Schema(query=Query)
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["OtherType"].graphene_type is OtherType assert type_map["OtherType"].graphene_type is OtherType
assert schema.get_type_map()["SomeType"].graphene_type is SomeType assert type_map["SomeType"].graphene_type is SomeType
def test_includes_interfaces_subtypes_in_the_type_map(): def test_includes_interfaces_subtypes_in_the_type_map():
@ -198,8 +205,9 @@ def test_includes_interfaces_subtypes_in_the_type_map():
iface = Field(SomeInterface) iface = Field(SomeInterface)
schema = Schema(query=Query, types=[SomeSubtype]) schema = Schema(query=Query, types=[SomeSubtype])
type_map = schema.graphql_schema.type_map
assert schema.get_type_map()["SomeSubtype"].graphene_type is SomeSubtype assert type_map["SomeSubtype"].graphene_type is SomeSubtype
def test_stringifies_simple_types(): def test_stringifies_simple_types():
@ -281,7 +289,7 @@ def test_stringifies_simple_types():
def test_does_not_mutate_passed_field_definitions(): def test_does_not_mutate_passed_field_definitions():
class CommonFields(object): class CommonFields:
field1 = String() field1 = String()
field2 = String(id=String()) field2 = String(id=String())
@ -293,7 +301,7 @@ def test_does_not_mutate_passed_field_definitions():
assert TestObject1._meta.fields == TestObject2._meta.fields assert TestObject1._meta.fields == TestObject2._meta.fields
class CommonFields(object): class CommonFields:
field1 = String() field1 = String()
field2 = String() field2 = String()

View File

@ -80,36 +80,19 @@ def test_enum_from_builtin_enum_accepts_lambda_description():
class Query(ObjectType): class Query(ObjectType):
foo = Episode() foo = Episode()
schema = Schema(query=Query) schema = Schema(query=Query).graphql_schema
GraphQLPyEpisode = schema._type_map["PyEpisode"].values episode = schema.get_type("PyEpisode")
assert schema._type_map["PyEpisode"].description == "StarWars Episodes" assert episode.description == "StarWars Episodes"
assert ( assert [
GraphQLPyEpisode[0].name == "NEWHOPE" (name, value.description, value.deprecation_reason)
and GraphQLPyEpisode[0].description == "New Hope Episode" for name, value in episode.values.items()
) ] == [
assert ( ("NEWHOPE", "New Hope Episode", "meh"),
GraphQLPyEpisode[1].name == "EMPIRE" ("EMPIRE", "Other", None),
and GraphQLPyEpisode[1].description == "Other" ("JEDI", "Other", None),
) ]
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].description == "Other"
)
assert (
GraphQLPyEpisode[0].name == "NEWHOPE"
and GraphQLPyEpisode[0].deprecation_reason == "meh"
)
assert (
GraphQLPyEpisode[1].name == "EMPIRE"
and GraphQLPyEpisode[1].deprecation_reason is None
)
assert (
GraphQLPyEpisode[2].name == "JEDI"
and GraphQLPyEpisode[2].deprecation_reason is None
)
def test_enum_from_python3_enum_uses_enum_doc(): def test_enum_from_python3_enum_uses_enum_doc():

View File

@ -1,6 +1,6 @@
from functools import partial from functools import partial
import pytest from pytest import raises
from ..argument import Argument from ..argument import Argument
from ..field import Field from ..field import Field
@ -9,7 +9,7 @@ from ..structures import NonNull
from .utils import MyLazyType from .utils import MyLazyType
class MyInstance(object): class MyInstance:
value = "value" value = "value"
value_func = staticmethod(lambda: "value_func") value_func = staticmethod(lambda: "value_func")
@ -85,7 +85,7 @@ def test_field_with_string_type():
def test_field_not_source_and_resolver(): def test_field_not_source_and_resolver():
MyType = object() MyType = object()
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
Field(MyType, source="value", resolver=lambda: None) Field(MyType, source="value", resolver=lambda: None)
assert ( assert (
str(exc_info.value) str(exc_info.value)
@ -122,7 +122,7 @@ def test_field_name_as_argument():
def test_field_source_argument_as_kw(): def test_field_source_argument_as_kw():
MyType = object() MyType = object()
field = Field(MyType, b=NonNull(True), c=Argument(None), a=NonNull(False)) field = Field(MyType, b=NonNull(True), c=Argument(None), a=NonNull(False))
assert list(field.args.keys()) == ["b", "c", "a"] assert list(field.args) == ["b", "c", "a"]
assert isinstance(field.args["b"], Argument) assert isinstance(field.args["b"], Argument)
assert isinstance(field.args["b"].type, NonNull) assert isinstance(field.args["b"].type, NonNull)
assert field.args["b"].type.of_type is True assert field.args["b"].type.of_type is True

View File

@ -8,7 +8,7 @@ from ..schema import Schema
from ..unmountedtype import UnmountedType from ..unmountedtype import UnmountedType
class MyType(object): class MyType:
pass pass
@ -50,7 +50,7 @@ def test_ordered_fields_in_inputobjecttype():
field = MyScalar() field = MyScalar()
asa = InputField(MyType) asa = InputField(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"] assert list(MyInputObjectType._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_inputobjecttype_unmountedtype(): def test_generate_inputobjecttype_unmountedtype():
@ -78,13 +78,13 @@ def test_generate_inputobjecttype_as_argument():
def test_generate_inputobjecttype_inherit_abstracttype(): def test_generate_inputobjecttype_inherit_abstracttype():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar(MyType) field1 = MyScalar(MyType)
class MyInputObjectType(InputObjectType, MyAbstractType): class MyInputObjectType(InputObjectType, MyAbstractType):
field2 = MyScalar(MyType) field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"] assert list(MyInputObjectType._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [ assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField, InputField,
InputField, InputField,
@ -92,13 +92,13 @@ def test_generate_inputobjecttype_inherit_abstracttype():
def test_generate_inputobjecttype_inherit_abstracttype_reversed(): def test_generate_inputobjecttype_inherit_abstracttype_reversed():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar(MyType) field1 = MyScalar(MyType)
class MyInputObjectType(MyAbstractType, InputObjectType): class MyInputObjectType(MyAbstractType, InputObjectType):
field2 = MyScalar(MyType) field2 = MyScalar(MyType)
assert list(MyInputObjectType._meta.fields.keys()) == ["field1", "field2"] assert list(MyInputObjectType._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [ assert [type(x) for x in MyInputObjectType._meta.fields.values()] == [
InputField, InputField,
InputField, InputField,
@ -133,5 +133,6 @@ def test_inputobjecttype_of_input():
} }
""" """
) )
assert not result.errors assert not result.errors
assert result.data == {"isChild": True} assert result.data == {"isChild": True}

View File

@ -3,7 +3,7 @@ from ..interface import Interface
from ..unmountedtype import UnmountedType from ..unmountedtype import UnmountedType
class MyType(object): class MyType:
pass pass
@ -45,7 +45,7 @@ def test_ordered_fields_in_interface():
field = MyScalar() field = MyScalar()
asa = Field(MyType) asa = Field(MyType)
assert list(MyInterface._meta.fields.keys()) == ["b", "a", "field", "asa"] assert list(MyInterface._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_interface_unmountedtype(): def test_generate_interface_unmountedtype():
@ -57,13 +57,13 @@ def test_generate_interface_unmountedtype():
def test_generate_interface_inherit_abstracttype(): def test_generate_interface_inherit_abstracttype():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar() field1 = MyScalar()
class MyInterface(Interface, MyAbstractType): class MyInterface(Interface, MyAbstractType):
field2 = MyScalar() field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"] assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
@ -75,16 +75,16 @@ def test_generate_interface_inherit_interface():
field2 = MyScalar() field2 = MyScalar()
assert MyInterface._meta.name == "MyInterface" assert MyInterface._meta.name == "MyInterface"
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"] assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]
def test_generate_interface_inherit_abstracttype_reversed(): def test_generate_interface_inherit_abstracttype_reversed():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar() field1 = MyScalar()
class MyInterface(MyAbstractType, Interface): class MyInterface(MyAbstractType, Interface):
field2 = MyScalar() field2 = MyScalar()
assert list(MyInterface._meta.fields.keys()) == ["field1", "field2"] assert list(MyInterface._meta.fields) == ["field1", "field2"]
assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field] assert [type(x) for x in MyInterface._meta.fields.values()] == [Field, Field]

View File

@ -1,4 +1,4 @@
import pytest from pytest import raises
from ..argument import Argument from ..argument import Argument
from ..dynamic import Dynamic from ..dynamic import Dynamic
@ -46,7 +46,7 @@ def test_generate_mutation_with_meta():
def test_mutation_raises_exception_if_no_mutate(): def test_mutation_raises_exception_if_no_mutate():
with pytest.raises(AssertionError) as excinfo: with raises(AssertionError) as excinfo:
class MyMutation(Mutation): class MyMutation(Mutation):
pass pass

View File

@ -1,4 +1,4 @@
import pytest from pytest import raises
from ..field import Field from ..field import Field
from ..interface import Interface from ..interface import Interface
@ -91,7 +91,7 @@ def test_generate_objecttype_with_private_attributes():
m = MyObjectType(_private_state="custom") m = MyObjectType(_private_state="custom")
assert m._private_state == "custom" assert m._private_state == "custom"
with pytest.raises(TypeError): with raises(TypeError):
MyObjectType(_other_private_state="Wrong") MyObjectType(_other_private_state="Wrong")
@ -102,11 +102,11 @@ def test_ordered_fields_in_objecttype():
field = MyScalar() field = MyScalar()
asa = Field(MyType) asa = Field(MyType)
assert list(MyObjectType._meta.fields.keys()) == ["b", "a", "field", "asa"] assert list(MyObjectType._meta.fields) == ["b", "a", "field", "asa"]
def test_generate_objecttype_inherit_abstracttype(): def test_generate_objecttype_inherit_abstracttype():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar() field1 = MyScalar()
class MyObjectType(ObjectType, MyAbstractType): class MyObjectType(ObjectType, MyAbstractType):
@ -115,12 +115,12 @@ def test_generate_objecttype_inherit_abstracttype():
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"] assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field] assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
def test_generate_objecttype_inherit_abstracttype_reversed(): def test_generate_objecttype_inherit_abstracttype_reversed():
class MyAbstractType(object): class MyAbstractType:
field1 = MyScalar() field1 = MyScalar()
class MyObjectType(MyAbstractType, ObjectType): class MyObjectType(MyAbstractType, ObjectType):
@ -129,7 +129,7 @@ def test_generate_objecttype_inherit_abstracttype_reversed():
assert MyObjectType._meta.description is None assert MyObjectType._meta.description is None
assert MyObjectType._meta.interfaces == () assert MyObjectType._meta.interfaces == ()
assert MyObjectType._meta.name == "MyObjectType" assert MyObjectType._meta.name == "MyObjectType"
assert list(MyObjectType._meta.fields.keys()) == ["field1", "field2"] assert list(MyObjectType._meta.fields) == ["field1", "field2"]
assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field] assert list(map(type, MyObjectType._meta.fields.values())) == [Field, Field]
@ -142,15 +142,11 @@ def test_generate_objecttype_unmountedtype():
def test_parent_container_get_fields(): def test_parent_container_get_fields():
assert list(Container._meta.fields.keys()) == ["field1", "field2"] assert list(Container._meta.fields) == ["field1", "field2"]
def test_parent_container_interface_get_fields(): def test_parent_container_interface_get_fields():
assert list(ContainerWithInterface._meta.fields.keys()) == [ assert list(ContainerWithInterface._meta.fields) == ["ifield", "field1", "field2"]
"ifield",
"field1",
"field2",
]
def test_objecttype_as_container_only_args(): def test_objecttype_as_container_only_args():
@ -177,14 +173,14 @@ def test_objecttype_as_container_all_kwargs():
def test_objecttype_as_container_extra_args(): def test_objecttype_as_container_extra_args():
with pytest.raises(IndexError) as excinfo: with raises(IndexError) as excinfo:
Container("1", "2", "3") Container("1", "2", "3")
assert "Number of args exceeds number of fields" == str(excinfo.value) assert "Number of args exceeds number of fields" == str(excinfo.value)
def test_objecttype_as_container_invalid_kwargs(): def test_objecttype_as_container_invalid_kwargs():
with pytest.raises(TypeError) as excinfo: with raises(TypeError) as excinfo:
Container(unexisting_field="3") Container(unexisting_field="3")
assert "'unexisting_field' is an invalid keyword argument for Container" == str( assert "'unexisting_field' is an invalid keyword argument for Container" == str(
@ -218,7 +214,7 @@ def test_objecttype_with_possible_types():
def test_objecttype_with_possible_types_and_is_type_of_should_raise(): def test_objecttype_with_possible_types_and_is_type_of_should_raise():
with pytest.raises(AssertionError) as excinfo: with raises(AssertionError) as excinfo:
class MyObjectType(ObjectType): class MyObjectType(ObjectType):
class Meta: class Meta:

View File

@ -1,7 +1,13 @@
import json import json
from functools import partial from functools import partial
from graphql import GraphQLError, ResolveInfo, Source, execute, parse from graphql import (
GraphQLError,
GraphQLResolveInfo as ResolveInfo,
Source,
execute,
parse,
)
from ..context import Context from ..context import Context
from ..dynamic import Dynamic from ..dynamic import Dynamic
@ -28,7 +34,7 @@ def test_query():
def test_query_source(): def test_query_source():
class Root(object): class Root:
_hello = "World" _hello = "World"
def hello(self): def hello(self):
@ -45,10 +51,10 @@ def test_query_source():
def test_query_union(): def test_query_union():
class one_object(object): class one_object:
pass pass
class two_object(object): class two_object:
pass pass
class One(ObjectType): class One(ObjectType):
@ -83,10 +89,10 @@ def test_query_union():
def test_query_interface(): def test_query_interface():
class one_object(object): class one_object:
pass pass
class two_object(object): class two_object:
pass pass
class MyInterface(Interface): class MyInterface(Interface):
@ -175,7 +181,7 @@ def test_query_wrong_default_value():
assert len(executed.errors) == 1 assert len(executed.errors) == 1
assert ( assert (
executed.errors[0].message executed.errors[0].message
== GraphQLError('Expected value of type "MyType" but got: str.').message == GraphQLError("Expected value of type 'MyType' but got: 'hello'.").message
) )
assert executed.data == {"hello": None} assert executed.data == {"hello": None}
@ -223,11 +229,11 @@ def test_query_arguments():
result = test_schema.execute("{ test }", None) result = test_schema.execute("{ test }", None)
assert not result.errors assert not result.errors
assert result.data == {"test": "[null,{}]"} assert result.data == {"test": '[null,{"a_str":null,"a_int":null}]'}
result = test_schema.execute('{ test(aStr: "String!") }', "Source!") result = test_schema.execute('{ test(aStr: "String!") }', "Source!")
assert not result.errors assert not result.errors
assert result.data == {"test": '["Source!",{"a_str":"String!"}]'} assert result.data == {"test": '["Source!",{"a_str":"String!","a_int":null}]'}
result = test_schema.execute('{ test(aInt: -123, aStr: "String!") }', "Source!") result = test_schema.execute('{ test(aInt: -123, aStr: "String!") }', "Source!")
assert not result.errors assert not result.errors
@ -252,18 +258,21 @@ def test_query_input_field():
result = test_schema.execute("{ test }", None) result = test_schema.execute("{ test }", None)
assert not result.errors assert not result.errors
assert result.data == {"test": "[null,{}]"} assert result.data == {"test": '[null,{"a_input":null}]'}
result = test_schema.execute('{ test(aInput: {aField: "String!"} ) }', "Source!") result = test_schema.execute('{ test(aInput: {aField: "String!"} ) }', "Source!")
assert not result.errors assert not result.errors
assert result.data == {"test": '["Source!",{"a_input":{"a_field":"String!"}}]'} assert result.data == {
"test": '["Source!",{"a_input":{"a_field":"String!","recursive_field":null}}]'
}
result = test_schema.execute( result = test_schema.execute(
'{ test(aInput: {recursiveField: {aField: "String!"}}) }', "Source!" '{ test(aInput: {recursiveField: {aField: "String!"}}) }', "Source!"
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"test": '["Source!",{"a_input":{"recursive_field":{"a_field":"String!"}}}]' "test": '["Source!",{"a_input":{"a_field":null,"recursive_field":'
'{"a_field":"String!","recursive_field":null}}}]'
} }
@ -279,8 +288,7 @@ def test_query_middlewares():
return "other" return "other"
def reversed_middleware(next, *args, **kwargs): def reversed_middleware(next, *args, **kwargs):
p = next(*args, **kwargs) return next(*args, **kwargs)[::-1]
return p.then(lambda x: x[::-1])
hello_schema = Schema(Query) hello_schema = Schema(Query)
@ -342,10 +350,11 @@ def test_big_list_query_compiled_query_benchmark(benchmark):
return big_list return big_list
hello_schema = Schema(Query) hello_schema = Schema(Query)
graphql_schema = hello_schema.graphql_schema
source = Source("{ allInts }") source = Source("{ allInts }")
query_ast = parse(source) query_ast = parse(source)
big_list_query = partial(execute, hello_schema, query_ast) big_list_query = partial(execute, graphql_schema, query_ast)
result = benchmark(big_list_query) result = benchmark(big_list_query)
assert not result.errors assert not result.errors
assert result.data == {"allInts": list(big_list)} assert result.data == {"allInts": list(big_list)}

View File

@ -13,7 +13,7 @@ info = None
demo_dict = {"attr": "value"} demo_dict = {"attr": "value"}
class demo_obj(object): class demo_obj:
attr = "value" attr = "value"

View File

@ -1,4 +1,6 @@
import pytest from pytest import raises
from graphql.pyutils import dedent
from ..field import Field from ..field import Field
from ..objecttype import ObjectType from ..objecttype import ObjectType
@ -15,8 +17,8 @@ class Query(ObjectType):
def test_schema(): def test_schema():
schema = Schema(Query) schema = Schema(Query).graphql_schema
assert schema.get_query_type() == schema.get_graphql_type(Query) assert schema.query_type == schema.get_graphql_type(Query)
def test_schema_get_type(): def test_schema_get_type():
@ -27,7 +29,7 @@ def test_schema_get_type():
def test_schema_get_type_error(): def test_schema_get_type_error():
schema = Schema(Query) schema = Schema(Query)
with pytest.raises(AttributeError) as exc_info: with raises(AttributeError) as exc_info:
schema.X schema.X
assert str(exc_info.value) == 'Type "X" not found in the Schema' assert str(exc_info.value) == 'Type "X" not found in the Schema'
@ -35,20 +37,16 @@ def test_schema_get_type_error():
def test_schema_str(): def test_schema_str():
schema = Schema(Query) schema = Schema(Query)
assert ( assert str(schema) == dedent(
str(schema) """
== """schema { type MyOtherType {
query: Query field: String
} }
type MyOtherType { type Query {
field: String inner: MyOtherType
} }
"""
type Query {
inner: MyOtherType
}
"""
) )

View File

@ -1,6 +1,6 @@
from functools import partial from functools import partial
import pytest from pytest import raises
from ..scalars import String from ..scalars import String
from ..structures import List, NonNull from ..structures import List, NonNull
@ -14,7 +14,7 @@ def test_list():
def test_list_with_unmounted_type(): def test_list_with_unmounted_type():
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
List(String()) List(String())
assert ( assert (
@ -82,7 +82,7 @@ def test_nonnull_inherited_works_list():
def test_nonnull_inherited_dont_work_nonnull(): def test_nonnull_inherited_dont_work_nonnull():
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
NonNull(NonNull(String)) NonNull(NonNull(String))
assert ( assert (
@ -92,7 +92,7 @@ def test_nonnull_inherited_dont_work_nonnull():
def test_nonnull_with_unmounted_type(): def test_nonnull_with_unmounted_type():
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
NonNull(String()) NonNull(String())
assert ( assert (

View File

@ -1,10 +1,11 @@
import pytest from pytest import raises
from graphql.type import ( from graphql.type import (
GraphQLArgument, GraphQLArgument,
GraphQLEnumType, GraphQLEnumType,
GraphQLEnumValue, GraphQLEnumValue,
GraphQLField, GraphQLField,
GraphQLInputObjectField, GraphQLInputField,
GraphQLInputObjectType, GraphQLInputObjectType,
GraphQLInterfaceType, GraphQLInterfaceType,
GraphQLObjectType, GraphQLObjectType,
@ -20,7 +21,13 @@ from ..interface import Interface
from ..objecttype import ObjectType from ..objecttype import ObjectType
from ..scalars import Int, String from ..scalars import Int, String
from ..structures import List, NonNull from ..structures import List, NonNull
from ..typemap import TypeMap, resolve_type from ..schema import GrapheneGraphQLSchema, resolve_type
def create_type_map(types, auto_camelcase=True):
query = GraphQLObjectType("Query", {})
schema = GrapheneGraphQLSchema(query, types=types, auto_camelcase=auto_camelcase)
return schema.type_map
def test_enum(): def test_enum():
@ -39,22 +46,18 @@ def test_enum():
if self == MyEnum.foo: if self == MyEnum.foo:
return "Is deprecated" return "Is deprecated"
typemap = TypeMap([MyEnum]) type_map = create_type_map([MyEnum])
assert "MyEnum" in typemap assert "MyEnum" in type_map
graphql_enum = typemap["MyEnum"] graphql_enum = type_map["MyEnum"]
assert isinstance(graphql_enum, GraphQLEnumType) assert isinstance(graphql_enum, GraphQLEnumType)
assert graphql_enum.name == "MyEnum" assert graphql_enum.name == "MyEnum"
assert graphql_enum.description == "Description" assert graphql_enum.description == "Description"
values = graphql_enum.values assert graphql_enum.values == {
assert values == [ "foo": GraphQLEnumValue(
GraphQLEnumValue( value=1, description="Description foo=1", deprecation_reason="Is deprecated"
name="foo",
value=1,
description="Description foo=1",
deprecation_reason="Is deprecated",
), ),
GraphQLEnumValue(name="bar", value=2, description="Description bar=2"), "bar": GraphQLEnumValue(value=2, description="Description bar=2"),
] }
def test_objecttype(): def test_objecttype():
@ -70,15 +73,15 @@ def test_objecttype():
def resolve_foo(self, bar): def resolve_foo(self, bar):
return bar return bar
typemap = TypeMap([MyObjectType]) type_map = create_type_map([MyObjectType])
assert "MyObjectType" in typemap assert "MyObjectType" in type_map
graphql_type = typemap["MyObjectType"] graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType" assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description" assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["foo", "gizmo"] assert list(fields) == ["foo", "gizmo"]
foo_field = fields["foo"] foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.description == "Field description" assert foo_field.description == "Field description"
@ -100,13 +103,13 @@ def test_dynamic_objecttype():
bar = Dynamic(lambda: Field(String)) bar = Dynamic(lambda: Field(String))
own = Field(lambda: MyObjectType) own = Field(lambda: MyObjectType)
typemap = TypeMap([MyObjectType]) type_map = create_type_map([MyObjectType])
assert "MyObjectType" in typemap assert "MyObjectType" in type_map
assert list(MyObjectType._meta.fields.keys()) == ["bar", "own"] assert list(MyObjectType._meta.fields) == ["bar", "own"]
graphql_type = typemap["MyObjectType"] graphql_type = type_map["MyObjectType"]
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["bar", "own"] assert list(fields) == ["bar", "own"]
assert fields["bar"].type == GraphQLString assert fields["bar"].type == GraphQLString
assert fields["own"].type == graphql_type assert fields["own"].type == graphql_type
@ -125,21 +128,21 @@ def test_interface():
def resolve_foo(self, args, info): def resolve_foo(self, args, info):
return args.get("bar") return args.get("bar")
typemap = TypeMap([MyInterface]) type_map = create_type_map([MyInterface])
assert "MyInterface" in typemap assert "MyInterface" in type_map
graphql_type = typemap["MyInterface"] graphql_type = type_map["MyInterface"]
assert isinstance(graphql_type, GraphQLInterfaceType) assert isinstance(graphql_type, GraphQLInterfaceType)
assert graphql_type.name == "MyInterface" assert graphql_type.name == "MyInterface"
assert graphql_type.description == "Description" assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["foo", "gizmo", "own"] assert list(fields) == ["foo", "gizmo", "own"]
assert fields["own"].type == graphql_type assert fields["own"].type == graphql_type
assert list(fields["gizmo"].args.keys()) == ["firstArg", "oth_arg"] assert list(fields["gizmo"].args) == ["firstArg", "oth_arg"]
foo_field = fields["foo"] foo_field = fields["foo"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.description == "Field description" assert foo_field.description == "Field description"
assert not foo_field.resolver # Resolver not attached in interfaces assert not foo_field.resolve # Resolver not attached in interfaces
assert foo_field.args == { assert foo_field.args == {
"bar": GraphQLArgument( "bar": GraphQLArgument(
GraphQLString, GraphQLString,
@ -169,23 +172,23 @@ def test_inputobject():
def resolve_foo_bar(self, args, info): def resolve_foo_bar(self, args, info):
return args.get("bar") return args.get("bar")
typemap = TypeMap([MyInputObjectType]) type_map = create_type_map([MyInputObjectType])
assert "MyInputObjectType" in typemap assert "MyInputObjectType" in type_map
graphql_type = typemap["MyInputObjectType"] graphql_type = type_map["MyInputObjectType"]
assert isinstance(graphql_type, GraphQLInputObjectType) assert isinstance(graphql_type, GraphQLInputObjectType)
assert graphql_type.name == "MyInputObjectType" assert graphql_type.name == "MyInputObjectType"
assert graphql_type.description == "Description" assert graphql_type.description == "Description"
other_graphql_type = typemap["OtherObjectType"] other_graphql_type = type_map["OtherObjectType"]
inner_graphql_type = typemap["MyInnerObjectType"] inner_graphql_type = type_map["MyInnerObjectType"]
container = graphql_type.create_container( container = graphql_type.out_type(
{ {
"bar": "oh!", "bar": "oh!",
"baz": inner_graphql_type.create_container( "baz": inner_graphql_type.out_type(
{ {
"some_other_field": [ "some_other_field": [
other_graphql_type.create_container({"thingy": 1}), other_graphql_type.out_type({"thingy": 1}),
other_graphql_type.create_container({"thingy": 2}), other_graphql_type.out_type({"thingy": 2}),
] ]
} }
), ),
@ -201,11 +204,11 @@ def test_inputobject():
assert container.baz.some_other_field[1].thingy == 2 assert container.baz.some_other_field[1].thingy == 2
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["fooBar", "gizmo", "baz", "own"] assert list(fields) == ["fooBar", "gizmo", "baz", "own"]
own_field = fields["own"] own_field = fields["own"]
assert own_field.type == graphql_type assert own_field.type == graphql_type
foo_field = fields["fooBar"] foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLInputObjectField) assert isinstance(foo_field, GraphQLInputField)
assert foo_field.description == "Field description" assert foo_field.description == "Field description"
@ -215,19 +218,19 @@ def test_objecttype_camelcase():
foo_bar = String(bar_foo=String()) foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType]) type_map = create_type_map([MyObjectType])
assert "MyObjectType" in typemap assert "MyObjectType" in type_map
graphql_type = typemap["MyObjectType"] graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType" assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description" assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["fooBar"] assert list(fields) == ["fooBar"]
foo_field = fields["fooBar"] foo_field = fields["fooBar"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.args == { assert foo_field.args == {
"barFoo": GraphQLArgument(GraphQLString, out_name="bar_foo") "barFoo": GraphQLArgument(GraphQLString, default_value=None, out_name="bar_foo")
} }
@ -237,19 +240,21 @@ def test_objecttype_camelcase_disabled():
foo_bar = String(bar_foo=String()) foo_bar = String(bar_foo=String())
typemap = TypeMap([MyObjectType], auto_camelcase=False) type_map = create_type_map([MyObjectType], auto_camelcase=False)
assert "MyObjectType" in typemap assert "MyObjectType" in type_map
graphql_type = typemap["MyObjectType"] graphql_type = type_map["MyObjectType"]
assert isinstance(graphql_type, GraphQLObjectType) assert isinstance(graphql_type, GraphQLObjectType)
assert graphql_type.name == "MyObjectType" assert graphql_type.name == "MyObjectType"
assert graphql_type.description == "Description" assert graphql_type.description == "Description"
fields = graphql_type.fields fields = graphql_type.fields
assert list(fields.keys()) == ["foo_bar"] assert list(fields) == ["foo_bar"]
foo_field = fields["foo_bar"] foo_field = fields["foo_bar"]
assert isinstance(foo_field, GraphQLField) assert isinstance(foo_field, GraphQLField)
assert foo_field.args == { assert foo_field.args == {
"bar_foo": GraphQLArgument(GraphQLString, out_name="bar_foo") "bar_foo": GraphQLArgument(
GraphQLString, default_value=None, out_name="bar_foo"
)
} }
@ -262,8 +267,8 @@ def test_objecttype_with_possible_types():
foo_bar = String() foo_bar = String()
typemap = TypeMap([MyObjectType]) type_map = create_type_map([MyObjectType])
graphql_type = typemap["MyObjectType"] graphql_type = type_map["MyObjectType"]
assert graphql_type.is_type_of assert graphql_type.is_type_of
assert graphql_type.is_type_of({}, None) is True assert graphql_type.is_type_of({}, None) is True
assert graphql_type.is_type_of(MyObjectType(), None) is False assert graphql_type.is_type_of(MyObjectType(), None) is False
@ -279,8 +284,8 @@ def test_resolve_type_with_missing_type():
def resolve_type_func(root, info): def resolve_type_func(root, info):
return MyOtherObjectType return MyOtherObjectType
typemap = TypeMap([MyObjectType]) type_map = create_type_map([MyObjectType])
with pytest.raises(AssertionError) as excinfo: with raises(AssertionError) as excinfo:
resolve_type(resolve_type_func, typemap, "MyOtherObjectType", {}, {}) resolve_type(resolve_type_func, type_map, "MyOtherObjectType", {}, {}, None)
assert "MyOtherObjectTyp" in str(excinfo.value) assert "MyOtherObjectTyp" in str(excinfo.value)

View File

@ -1,4 +1,4 @@
import pytest from pytest import raises
from ..field import Field from ..field import Field
from ..objecttype import ObjectType from ..objecttype import ObjectType
@ -38,7 +38,7 @@ def test_generate_union_with_meta():
def test_generate_union_with_no_types(): def test_generate_union_with_no_types():
with pytest.raises(Exception) as exc_info: with raises(Exception) as exc_info:
class MyUnion(Union): class MyUnion(Union):
pass pass

View File

@ -1,337 +0,0 @@
import inspect
from collections import OrderedDict
from functools import partial
from graphql import (
GraphQLArgument,
GraphQLBoolean,
GraphQLField,
GraphQLFloat,
GraphQLID,
GraphQLInputObjectField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLString,
)
from graphql.execution.executor import get_default_resolve_type_fn
from graphql.type import GraphQLEnumValue
from graphql.type.typemap import GraphQLTypeMap
from ..utils.get_unbound_function import get_unbound_function
from ..utils.str_converters import to_camel_case
from .definitions import (
GrapheneEnumType,
GrapheneGraphQLType,
GrapheneInputObjectType,
GrapheneInterfaceType,
GrapheneObjectType,
GrapheneScalarType,
GrapheneUnionType,
)
from .dynamic import Dynamic
from .enum import Enum
from .field import Field
from .inputobjecttype import InputObjectType
from .interface import Interface
from .objecttype import ObjectType
from .resolver import get_default_resolver
from .scalars import ID, Boolean, Float, Int, Scalar, String
from .structures import List, NonNull
from .union import Union
from .utils import get_field_as
def is_graphene_type(_type):
if isinstance(_type, (List, NonNull)):
return True
if inspect.isclass(_type) and issubclass(
_type, (ObjectType, InputObjectType, Scalar, Interface, Union, Enum)
):
return True
def resolve_type(resolve_type_func, map, type_name, root, info):
_type = resolve_type_func(root, info)
if not _type:
return_type = map[type_name]
return get_default_resolve_type_fn(root, info, return_type)
if inspect.isclass(_type) and issubclass(_type, ObjectType):
graphql_type = map.get(_type._meta.name)
assert graphql_type, "Can't find type {} in schema".format(_type._meta.name)
assert graphql_type.graphene_type == _type, (
"The type {} does not match with the associated graphene type {}."
).format(_type, graphql_type.graphene_type)
return graphql_type
return _type
def is_type_of_from_possible_types(possible_types, root, info):
return isinstance(root, possible_types)
class TypeMap(GraphQLTypeMap):
def __init__(self, types, auto_camelcase=True, schema=None):
self.auto_camelcase = auto_camelcase
self.schema = schema
super(TypeMap, self).__init__(types)
def reducer(self, map, type):
if not type:
return map
if inspect.isfunction(type):
type = type()
if is_graphene_type(type):
return self.graphene_reducer(map, type)
return GraphQLTypeMap.reducer(map, type)
def graphene_reducer(self, map, type):
if isinstance(type, (List, NonNull)):
return self.reducer(map, type.of_type)
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return map
if issubclass(type, ObjectType):
internal_type = self.construct_objecttype(map, type)
elif issubclass(type, InputObjectType):
internal_type = self.construct_inputobjecttype(map, type)
elif issubclass(type, Interface):
internal_type = self.construct_interface(map, type)
elif issubclass(type, Scalar):
internal_type = self.construct_scalar(map, type)
elif issubclass(type, Enum):
internal_type = self.construct_enum(map, type)
elif issubclass(type, Union):
internal_type = self.construct_union(map, type)
else:
raise Exception("Expected Graphene type, but received: {}.".format(type))
return GraphQLTypeMap.reducer(map, internal_type)
def construct_scalar(self, map, type):
# We have a mapping to the original GraphQL types
# so there are no collisions.
_scalars = {
String: GraphQLString,
Int: GraphQLInt,
Float: GraphQLFloat,
Boolean: GraphQLBoolean,
ID: GraphQLID,
}
if type in _scalars:
return _scalars[type]
return GrapheneScalarType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
serialize=getattr(type, "serialize", None),
parse_value=getattr(type, "parse_value", None),
parse_literal=getattr(type, "parse_literal", None),
)
def construct_enum(self, map, type):
values = OrderedDict()
for name, value in type._meta.enum.__members__.items():
description = getattr(value, "description", None)
deprecation_reason = getattr(value, "deprecation_reason", None)
if not description and callable(type._meta.description):
description = type._meta.description(value)
if not deprecation_reason and callable(type._meta.deprecation_reason):
deprecation_reason = type._meta.deprecation_reason(value)
values[name] = GraphQLEnumValue(
name=name,
value=value.value,
description=description,
deprecation_reason=deprecation_reason,
)
type_description = (
type._meta.description(None)
if callable(type._meta.description)
else type._meta.description
)
return GrapheneEnumType(
graphene_type=type,
values=values,
name=type._meta.name,
description=type_description,
)
def construct_objecttype(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneGraphQLType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return _type
def interfaces():
interfaces = []
for interface in type._meta.interfaces:
self.graphene_reducer(map, interface)
internal_type = map[interface._meta.name]
assert internal_type.graphene_type == interface
interfaces.append(internal_type)
return interfaces
if type._meta.possible_types:
is_type_of = partial(
is_type_of_from_possible_types, type._meta.possible_types
)
else:
is_type_of = type.is_type_of
return GrapheneObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type),
is_type_of=is_type_of,
interfaces=interfaces,
)
def construct_interface(self, map, type):
if type._meta.name in map:
_type = map[type._meta.name]
if isinstance(_type, GrapheneInterfaceType):
assert _type.graphene_type == type, (
"Found different types with the same name in the schema: {}, {}."
).format(_type.graphene_type, type)
return _type
_resolve_type = None
if type.resolve_type:
_resolve_type = partial(
resolve_type, type.resolve_type, map, type._meta.name
)
return GrapheneInterfaceType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
fields=partial(self.construct_fields_for_type, map, type),
resolve_type=_resolve_type,
)
def construct_inputobjecttype(self, map, type):
return GrapheneInputObjectType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
container_type=type._meta.container,
fields=partial(
self.construct_fields_for_type, map, type, is_input_type=True
),
)
def construct_union(self, map, type):
_resolve_type = None
if type.resolve_type:
_resolve_type = partial(
resolve_type, type.resolve_type, map, type._meta.name
)
def types():
union_types = []
for objecttype in type._meta.types:
self.graphene_reducer(map, objecttype)
internal_type = map[objecttype._meta.name]
assert internal_type.graphene_type == objecttype
union_types.append(internal_type)
return union_types
return GrapheneUnionType(
graphene_type=type,
name=type._meta.name,
description=type._meta.description,
types=types,
resolve_type=_resolve_type,
)
def get_name(self, name):
if self.auto_camelcase:
return to_camel_case(name)
return name
def construct_fields_for_type(self, map, type, is_input_type=False):
fields = OrderedDict()
for name, field in type._meta.fields.items():
if isinstance(field, Dynamic):
field = get_field_as(field.get_type(self.schema), _as=Field)
if not field:
continue
map = self.reducer(map, field.type)
field_type = self.get_field_type(map, field.type)
if is_input_type:
_field = GraphQLInputObjectField(
field_type,
default_value=field.default_value,
out_name=name,
description=field.description,
)
else:
args = OrderedDict()
for arg_name, arg in field.args.items():
map = self.reducer(map, arg.type)
arg_type = self.get_field_type(map, arg.type)
processed_arg_name = arg.name or self.get_name(arg_name)
args[processed_arg_name] = GraphQLArgument(
arg_type,
out_name=arg_name,
description=arg.description,
default_value=arg.default_value,
)
_field = GraphQLField(
field_type,
args=args,
resolver=field.get_resolver(
self.get_resolver_for_type(type, name, field.default_value)
),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
field_name = field.name or self.get_name(name)
fields[field_name] = _field
return fields
def get_resolver_for_type(self, type, name, default_value):
if not issubclass(type, ObjectType):
return
resolver = getattr(type, "resolve_{}".format(name), None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in type._meta.interfaces:
if name not in interface._meta.fields:
continue
interface_resolver = getattr(interface, "resolve_{}".format(name), None)
if interface_resolver:
break
resolver = interface_resolver
# Only if is not decorated with classmethod
if resolver:
return get_unbound_function(resolver)
default_resolver = type._meta.default_resolver or get_default_resolver()
return partial(default_resolver, name, default_value)
def get_field_type(self, map, type):
if isinstance(type, List):
return GraphQLList(self.get_field_type(map, type.of_type))
if isinstance(type, NonNull):
return GraphQLNonNull(self.get_field_type(map, type.of_type))
return map.get(type._meta.name)

View File

@ -1,5 +1,4 @@
import inspect import inspect
from collections import OrderedDict
from functools import partial from functools import partial
from ..utils.module_loading import import_string from ..utils.module_loading import import_string
@ -33,7 +32,7 @@ def yank_fields_from_attrs(attrs, _as=None, sort=True):
if sort: if sort:
fields_with_names = sorted(fields_with_names, key=lambda f: f[1]) fields_with_names = sorted(fields_with_names, key=lambda f: f[1])
return OrderedDict(fields_with_names) return dict(fields_with_names)
def get_type(_type): def get_type(_type):

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from uuid import UUID as _UUID from uuid import UUID as _UUID
from graphql.language import ast from graphql.language.ast import StringValueNode
from .scalars import Scalar from .scalars import Scalar
@ -24,7 +24,7 @@ class UUID(Scalar):
@staticmethod @staticmethod
def parse_literal(node): def parse_literal(node):
if isinstance(node, ast.StringValue): if isinstance(node, StringValueNode):
return _UUID(node.value) return _UUID(node.value)
@staticmethod @staticmethod

View File

@ -1,5 +1,5 @@
import json import json
from collections import Mapping from collections.abc import Mapping
def to_key(value): def to_key(value):

View File

@ -1,4 +1,4 @@
from collections import Mapping, OrderedDict from collections.abc import Mapping
def deflate(node, index=None, path=None): def deflate(node, index=None, path=None):
@ -16,10 +16,9 @@ def deflate(node, index=None, path=None):
else: else:
index[cache_key] = True index[cache_key] = True
field_names = node.keys() result = {}
result = OrderedDict()
for field_name in field_names: for field_name in node:
value = node[field_name] value = node[field_name]
new_path = path + [field_name] new_path = path + [field_name]

View File

@ -2,7 +2,7 @@ from functools import total_ordering
@total_ordering @total_ordering
class OrderedType(object): class OrderedType:
creation_counter = 1 creation_counter = 1
def __init__(self, _creation_counter=None): def __init__(self, _creation_counter=None):

View File

@ -2,7 +2,7 @@ class _OldClass:
pass pass
class _NewClass(object): class _NewClass:
pass pass

View File

@ -43,7 +43,7 @@ class SubclassWithMeta(metaclass=SubclassWithMeta_Meta):
assert not options, ( assert not options, (
"Abstract types can only contain the abstract attribute. " "Abstract types can only contain the abstract attribute. "
"Received: abstract, {option_keys}" "Received: abstract, {option_keys}"
).format(option_keys=", ".join(options.keys())) ).format(option_keys=", ".join(options))
else: else:
super_class = super(cls, cls) super_class = super(cls, cls)
if hasattr(super_class, "__init_subclass_with_meta__"): if hasattr(super_class, "__init_subclass_with_meta__"):

View File

@ -1,10 +1,9 @@
import pytest from pytest import mark
from collections import OrderedDict
from ..crunch import crunch from ..crunch import crunch
@pytest.mark.parametrize( @mark.parametrize(
"description,uncrunched,crunched", "description,uncrunched,crunched",
[ [
["number primitive", 0, [0]], ["number primitive", 0, [0]],
@ -28,28 +27,22 @@ from ..crunch import crunch
["single-item object", {"a": None}, [None, {"a": 0}]], ["single-item object", {"a": None}, [None, {"a": 0}]],
[ [
"multi-item all distinct object", "multi-item all distinct object",
OrderedDict([("a", None), ("b", 0), ("c", True), ("d", "string")]), {"a": None, "b": 0, "c": True, "d": "string"},
[None, 0, True, "string", {"a": 0, "b": 1, "c": 2, "d": 3}], [None, 0, True, "string", {"a": 0, "b": 1, "c": 2, "d": 3}],
], ],
[ [
"multi-item repeated object", "multi-item repeated object",
OrderedDict([("a", True), ("b", True), ("c", True), ("d", True)]), {"a": True, "b": True, "c": True, "d": True},
[True, {"a": 0, "b": 0, "c": 0, "d": 0}], [True, {"a": 0, "b": 0, "c": 0, "d": 0}],
], ],
[ [
"complex array", "complex array",
[OrderedDict([("a", True), ("b", [1, 2, 3])]), [1, 2, 3]], [{"a": True, "b": [1, 2, 3]}, [1, 2, 3]],
[True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, [5, 4]], [True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, [5, 4]],
], ],
[ [
"complex object", "complex object",
OrderedDict( {"a": True, "b": [1, 2, 3], "c": {"a": True, "b": [1, 2, 3]}},
[
("a", True),
("b", [1, 2, 3]),
("c", OrderedDict([("a", True), ("b", [1, 2, 3])])),
]
),
[True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, {"a": 0, "b": 4, "c": 5}], [True, 1, 2, 3, [1, 2, 3], {"a": 0, "b": 4}, {"a": 0, "b": 4, "c": 5}],
], ],
], ],

View File

@ -150,8 +150,8 @@ def test_example_end_to_end():
result = schema.execute(query) result = schema.execute(query)
assert not result.errors assert not result.errors
result.data = deflate(result.data) data = deflate(result.data)
assert result.data == { assert data == {
"events": [ "events": [
{ {
"__typename": "Event", "__typename": "Event",

View File

@ -1,4 +1,4 @@
import pytest from pytest import raises
from .. import deprecated from .. import deprecated
from ..deprecated import deprecated as deprecated_decorator from ..deprecated import deprecated as deprecated_decorator
@ -71,5 +71,5 @@ def test_deprecated_class_text(mocker):
def test_deprecated_other_object(mocker): def test_deprecated_other_object(mocker):
mocker.patch.object(deprecated, "warn_deprecation") mocker.patch.object(deprecated, "warn_deprecation")
with pytest.raises(TypeError): with raises(TypeError):
deprecated_decorator({}) deprecated_decorator({})

View File

@ -2,7 +2,7 @@ from ..trim_docstring import trim_docstring
def test_trim_docstring(): def test_trim_docstring():
class WellDocumentedObject(object): class WellDocumentedObject:
""" """
This object is very well-documented. It has multiple lines in its This object is very well-documented. It has multiple lines in its
description. description.
@ -16,7 +16,7 @@ def test_trim_docstring():
"description.\n\nMultiple paragraphs too" "description.\n\nMultiple paragraphs too"
) )
class UndocumentedObject(object): class UndocumentedObject:
pass pass
assert trim_docstring(UndocumentedObject.__doc__) is None assert trim_docstring(UndocumentedObject.__doc__) is None

View File

@ -1,28 +1,15 @@
""" """
This file is used mainly as a bridge for thenable abstractions. This file is used mainly as a bridge for thenable abstractions.
This includes:
- Promises
- Asyncio Coroutines
""" """
try: from inspect import isawaitable
from promise import Promise, is_thenable # type: ignore
except ImportError:
class Promise(object): # type: ignore
pass
def is_thenable(obj): # type: ignore
return False
try: def await_and_execute(obj, on_resolve):
from inspect import isawaitable async def build_resolve_async():
from .thenables_asyncio import await_and_execute return on_resolve(await obj)
except ImportError:
def isawaitable(obj): # type: ignore return build_resolve_async()
return False
def maybe_thenable(obj, on_resolve): def maybe_thenable(obj, on_resolve):
@ -31,12 +18,8 @@ def maybe_thenable(obj, on_resolve):
returning the same type of object inputed. returning the same type of object inputed.
If the object is not thenable, it should return on_resolve(obj) If the object is not thenable, it should return on_resolve(obj)
""" """
if isawaitable(obj) and not isinstance(obj, Promise): if isawaitable(obj):
return await_and_execute(obj, on_resolve) return await_and_execute(obj, on_resolve)
if is_thenable(obj): # If it's not awaitable, return the function executed over the object
return Promise.resolve(obj).then(on_resolve)
# If it's not awaitable not a Promise, return
# the function executed over the object
return on_resolve(obj) return on_resolve(obj)

View File

@ -1,5 +0,0 @@
def await_and_execute(obj, on_resolve):
async def build_resolve_async():
return on_resolve(await obj)
return build_resolve_async()

View File

@ -80,15 +80,11 @@ setup(
keywords="api graphql protocol rest relay graphene", keywords="api graphql protocol rest relay graphene",
packages=find_packages(exclude=["tests", "tests.*", "examples"]), packages=find_packages(exclude=["tests", "tests.*", "examples"]),
install_requires=[ install_requires=[
"graphql-core>=2.1,<3", "graphql-core>=3.0.0a0,<4",
"graphql-relay>=2,<3", "graphql-relay>=3.0.0a0,<4",
"aniso8601>=3,<=7", "aniso8601>=6,<8",
], ],
tests_require=tests_require, tests_require=tests_require,
extras_require={ extras_require={"test": tests_require},
"test": tests_require,
"django": ["graphene-django"],
"sqlalchemy": ["graphene-sqlalchemy"],
},
cmdclass={"test": PyTest}, cmdclass={"test": PyTest},
) )

View File

@ -1,7 +1,4 @@
import pytest from pytest import mark
from collections import OrderedDict
from graphql.execution.executors.asyncio import AsyncioExecutor
from graphql_relay.utils import base64 from graphql_relay.utils import base64
@ -27,14 +24,14 @@ class LetterConnection(Connection):
class Query(ObjectType): class Query(ObjectType):
letters = ConnectionField(LetterConnection) letters = ConnectionField(LetterConnection)
connection_letters = ConnectionField(LetterConnection) connection_letters = ConnectionField(LetterConnection)
promise_letters = ConnectionField(LetterConnection) async_letters = ConnectionField(LetterConnection)
node = Node.Field() node = Node.Field()
def resolve_letters(self, info, **args): def resolve_letters(self, info, **args):
return list(letters.values()) return list(letters.values())
async def resolve_promise_letters(self, info, **args): async def resolve_async_letters(self, info, **args):
return list(letters.values()) return list(letters.values())
def resolve_connection_letters(self, info, **args): def resolve_connection_letters(self, info, **args):
@ -48,9 +45,7 @@ class Query(ObjectType):
schema = Schema(Query) schema = Schema(Query)
letters = OrderedDict() letters = {letter: Letter(id=i, letter=letter) for i, letter in enumerate(letter_chars)}
for i, letter in enumerate(letter_chars):
letters[letter] = Letter(id=i, letter=letter)
def edges(selected_letters): def edges(selected_letters):
@ -96,12 +91,12 @@ def execute(args=""):
) )
@pytest.mark.asyncio @mark.asyncio
async def test_connection_promise(): async def test_connection_async():
result = await schema.execute( result = await schema.execute_async(
""" """
{ {
promiseLetters(first:1) { asyncLetters(first:1) {
edges { edges {
node { node {
id id
@ -114,14 +109,12 @@ async def test_connection_promise():
} }
} }
} }
""", """
executor=AsyncioExecutor(),
return_promise=True,
) )
assert not result.errors assert not result.errors
assert result.data == { assert result.data == {
"promiseLetters": { "asyncLetters": {
"edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}], "edges": [{"node": {"id": "TGV0dGVyOjA=", "letter": "A"}}],
"pageInfo": {"hasPreviousPage": False, "hasNextPage": True}, "pageInfo": {"hasPreviousPage": False, "hasNextPage": True},
} }

View File

@ -1,5 +1,4 @@
import pytest from pytest import mark
from graphql.execution.executors.asyncio import AsyncioExecutor
from graphene.types import ID, Field, ObjectType, Schema from graphene.types import ID, Field, ObjectType, Schema
from graphene.types.scalars import String from graphene.types.scalars import String
@ -43,11 +42,11 @@ class OtherMutation(ClientIDMutation):
@staticmethod @staticmethod
def mutate_and_get_payload( def mutate_and_get_payload(
self, info, shared="", additional_field="", client_mutation_id=None self, info, shared, additional_field, client_mutation_id=None
): ):
edge_type = MyEdge edge_type = MyEdge
return OtherMutation( return OtherMutation(
name=shared + additional_field, name=(shared or "") + (additional_field or ""),
my_node_edge=edge_type(cursor="1", node=MyNode(name="name")), my_node_edge=edge_type(cursor="1", node=MyNode(name="name")),
) )
@ -64,23 +63,19 @@ class Mutation(ObjectType):
schema = Schema(query=RootQuery, mutation=Mutation) schema = Schema(query=RootQuery, mutation=Mutation)
@pytest.mark.asyncio @mark.asyncio
async def test_node_query_promise(): async def test_node_query_promise():
executed = await schema.execute( executed = await schema.execute_async(
'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }', 'mutation a { sayPromise(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
executor=AsyncioExecutor(),
return_promise=True,
) )
assert not executed.errors assert not executed.errors
assert executed.data == {"sayPromise": {"phrase": "hello"}} assert executed.data == {"sayPromise": {"phrase": "hello"}}
@pytest.mark.asyncio @mark.asyncio
async def test_edge_query(): async def test_edge_query():
executed = await schema.execute( executed = await schema.execute_async(
'mutation a { other(input: {clientMutationId:"1"}) { clientMutationId, myNodeEdge { cursor node { name }} } }', 'mutation a { other(input: {clientMutationId:"1"}) { clientMutationId, myNodeEdge { cursor node { name }} } }'
executor=AsyncioExecutor(),
return_promise=True,
) )
assert not executed.errors assert not executed.errors
assert dict(executed.data) == { assert dict(executed.data) == {

View File

@ -22,14 +22,16 @@ commands =
[testenv:mypy] [testenv:mypy]
basepython=python3.7 basepython=python3.7
deps = deps =
mypy mypy>=0.720
commands = commands =
mypy graphene mypy graphene
[testenv:flake8] [testenv:flake8]
deps = flake8 basepython=python3.6
deps =
flake8>=3.7,<4
commands = commands =
pip install -e . pip install --pre -e .
flake8 graphene flake8 graphene
[pytest] [pytest]