Add UnforgivingExecutionContext (#1255)

This commit is contained in:
Alec Rosenbaum 2020-10-21 05:13:32 -04:00 committed by GitHub
parent a53b782bf8
commit e24ac547d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 215 additions and 3 deletions

View File

@ -28,6 +28,8 @@ from graphql import (
GraphQLString,
Undefined,
)
from graphql.execution import ExecutionContext
from graphql.execution.values import get_argument_values
from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function
@ -317,7 +319,7 @@ class TypeMap(dict):
)
subscribe = field.wrap_subscribe(
self.get_function_for_type(
graphene_type, f"subscribe_{name}", name, field.default_value,
graphene_type, f"subscribe_{name}", name, field.default_value
)
)
@ -394,6 +396,101 @@ class TypeMap(dict):
return type_
class UnforgivingExecutionContext(ExecutionContext):
"""An execution context which doesn't swallow exceptions.
The only difference between this execution context and the one it inherits from is
that ``except Exception`` is commented out within ``resolve_field_value_or_error``.
By removing that exception handling, only ``GraphQLError``'s are caught.
"""
def resolve_field_value_or_error(
self, field_def, field_nodes, resolve_fn, source, info
):
"""Resolve field to a value or an error.
Isolates the "ReturnOrAbrupt" behavior to not de-opt the resolve_field()
method. Returns the result of resolveFn or the abrupt-return Error object.
For internal use only.
"""
try:
# Build a dictionary of arguments from the field.arguments AST, using the
# variables scope to fulfill any variable references.
args = get_argument_values(field_def, field_nodes[0], self.variable_values)
# Note that contrary to the JavaScript implementation, we pass the context
# value as part of the resolve info.
result = resolve_fn(source, info, **args)
if self.is_awaitable(result):
# noinspection PyShadowingNames
async def await_result():
try:
return await result
except GraphQLError as error:
return error
# except Exception as error:
# return GraphQLError(str(error), original_error=error)
# Yes, this is commented out code. It's been intentionally
# _not_ removed to show what has changed from the original
# implementation.
return await_result()
return result
except GraphQLError as error:
return error
# except Exception as error:
# return GraphQLError(str(error), original_error=error)
# Yes, this is commented out code. It's been intentionally _not_
# removed to show what has changed from the original implementation.
def complete_value_catching_error(
self, return_type, field_nodes, info, path, result
):
"""Complete a value while catching an error.
This is a small wrapper around completeValue which detects and logs errors in
the execution context.
"""
try:
if self.is_awaitable(result):
async def await_result():
value = self.complete_value(
return_type, field_nodes, info, path, await result
)
if self.is_awaitable(value):
return await value
return value
completed = await_result()
else:
completed = self.complete_value(
return_type, field_nodes, info, path, result
)
if self.is_awaitable(completed):
# noinspection PyShadowingNames
async def await_completed():
try:
return await completed
# CHANGE WAS MADE HERE
# ``GraphQLError`` was swapped in for ``except Exception``
except GraphQLError as error:
self.handle_field_error(error, field_nodes, path, return_type)
return await_completed()
return completed
# CHANGE WAS MADE HERE
# ``GraphQLError`` was swapped in for ``except Exception``
except GraphQLError as error:
self.handle_field_error(error, field_nodes, path, return_type)
return None
class Schema:
"""Schema Definition.
@ -481,6 +578,8 @@ class Schema:
request_string, an operation name must be provided for the result to be provided.
middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as
defined in `graphql-core`.
execution_context_class (ExecutionContext, optional): The execution context class
to use when resolving queries and mutations.
Returns:
:obj:`ExecutionResult` containing any data and errors for the operation.

View File

@ -1,12 +1,13 @@
from graphql.type import GraphQLObjectType, GraphQLSchema
from pytest import raises
from graphql import GraphQLError
from pytest import mark, raises, fixture
from graphene.tests.utils import dedent
from ..field import Field
from ..objecttype import ObjectType
from ..scalars import String
from ..schema import Schema
from ..schema import Schema, UnforgivingExecutionContext
class MyOtherType(ObjectType):
@ -68,3 +69,115 @@ def test_schema_requires_query_type():
assert len(result.errors) == 1
error = result.errors[0]
assert error.message == "Query root type must be provided."
class TestUnforgivingExecutionContext:
@fixture
def schema(self):
class ErrorFieldsMixin:
sanity_field = String()
expected_error_field = String()
unexpected_value_error_field = String()
unexpected_type_error_field = String()
unexpected_attribute_error_field = String()
unexpected_key_error_field = String()
@staticmethod
def resolve_sanity_field(obj, info):
return "not an error"
@staticmethod
def resolve_expected_error_field(obj, info):
raise GraphQLError("expected error")
@staticmethod
def resolve_unexpected_value_error_field(obj, info):
raise ValueError("unexpected error")
@staticmethod
def resolve_unexpected_type_error_field(obj, info):
raise TypeError("unexpected error")
@staticmethod
def resolve_unexpected_attribute_error_field(obj, info):
raise AttributeError("unexpected error")
@staticmethod
def resolve_unexpected_key_error_field(obj, info):
return {}["fails"]
class NestedObject(ErrorFieldsMixin, ObjectType):
pass
class MyQuery(ErrorFieldsMixin, ObjectType):
nested_object = Field(NestedObject)
nested_object_error = Field(NestedObject)
@staticmethod
def resolve_nested_object(obj, info):
return object()
@staticmethod
def resolve_nested_object_error(obj, info):
raise TypeError()
schema = Schema(query=MyQuery)
return schema
def test_sanity_check(self, schema):
# this should pass with no errors (sanity check)
result = schema.execute(
"query { sanityField }",
execution_context_class=UnforgivingExecutionContext,
)
assert not result.errors
assert result.data == {"sanityField": "not an error"}
def test_nested_sanity_check(self, schema):
# this should pass with no errors (sanity check)
result = schema.execute(
r"query { nestedObject { sanityField } }",
execution_context_class=UnforgivingExecutionContext,
)
assert not result.errors
assert result.data == {"nestedObject": {"sanityField": "not an error"}}
def test_graphql_error(self, schema):
result = schema.execute(
"query { expectedErrorField }",
execution_context_class=UnforgivingExecutionContext,
)
assert len(result.errors) == 1
assert result.errors[0].message == "expected error"
assert result.data == {"expectedErrorField": None}
def test_nested_graphql_error(self, schema):
result = schema.execute(
r"query { nestedObject { expectedErrorField } }",
execution_context_class=UnforgivingExecutionContext,
)
assert len(result.errors) == 1
assert result.errors[0].message == "expected error"
assert result.data == {"nestedObject": {"expectedErrorField": None}}
@mark.parametrize(
"field,exception",
[
("unexpectedValueErrorField", ValueError),
("unexpectedTypeErrorField", TypeError),
("unexpectedAttributeErrorField", AttributeError),
("unexpectedKeyErrorField", KeyError),
("nestedObject { unexpectedValueErrorField }", ValueError),
("nestedObject { unexpectedTypeErrorField }", TypeError),
("nestedObject { unexpectedAttributeErrorField }", AttributeError),
("nestedObject { unexpectedKeyErrorField }", KeyError),
("nestedObjectError { __typename }", TypeError),
],
)
def test_unexpected_error(self, field, exception, schema):
with raises(exception):
# no result, but the exception should be propagated
schema.execute(
f"query {{ {field} }}",
execution_context_class=UnforgivingExecutionContext,
)