diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 55f0bf93..4fd71769 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -28,6 +28,8 @@ from graphql import ( GraphQLString, Undefined, ) +from graphql.execution import ExecutionContext +from graphql.execution.values import get_argument_values from ..utils.str_converters import to_camel_case from ..utils.get_unbound_function import get_unbound_function @@ -317,7 +319,7 @@ class TypeMap(dict): ) subscribe = field.wrap_subscribe( self.get_function_for_type( - graphene_type, f"subscribe_{name}", name, field.default_value, + graphene_type, f"subscribe_{name}", name, field.default_value ) ) @@ -394,6 +396,101 @@ class TypeMap(dict): return type_ +class UnforgivingExecutionContext(ExecutionContext): + """An execution context which doesn't swallow exceptions. + + The only difference between this execution context and the one it inherits from is + that ``except Exception`` is commented out within ``resolve_field_value_or_error``. + By removing that exception handling, only ``GraphQLError``'s are caught. + """ + + def resolve_field_value_or_error( + self, field_def, field_nodes, resolve_fn, source, info + ): + """Resolve field to a value or an error. + + Isolates the "ReturnOrAbrupt" behavior to not de-opt the resolve_field() + method. Returns the result of resolveFn or the abrupt-return Error object. + + For internal use only. + """ + try: + # Build a dictionary of arguments from the field.arguments AST, using the + # variables scope to fulfill any variable references. + args = get_argument_values(field_def, field_nodes[0], self.variable_values) + + # Note that contrary to the JavaScript implementation, we pass the context + # value as part of the resolve info. + result = resolve_fn(source, info, **args) + if self.is_awaitable(result): + # noinspection PyShadowingNames + async def await_result(): + try: + return await result + except GraphQLError as error: + return error + # except Exception as error: + # return GraphQLError(str(error), original_error=error) + + # Yes, this is commented out code. It's been intentionally + # _not_ removed to show what has changed from the original + # implementation. + + return await_result() + return result + except GraphQLError as error: + return error + # except Exception as error: + # return GraphQLError(str(error), original_error=error) + + # Yes, this is commented out code. It's been intentionally _not_ + # removed to show what has changed from the original implementation. + + def complete_value_catching_error( + self, return_type, field_nodes, info, path, result + ): + """Complete a value while catching an error. + + This is a small wrapper around completeValue which detects and logs errors in + the execution context. + """ + try: + if self.is_awaitable(result): + + async def await_result(): + value = self.complete_value( + return_type, field_nodes, info, path, await result + ) + if self.is_awaitable(value): + return await value + return value + + completed = await_result() + else: + completed = self.complete_value( + return_type, field_nodes, info, path, result + ) + if self.is_awaitable(completed): + # noinspection PyShadowingNames + async def await_completed(): + try: + return await completed + + # CHANGE WAS MADE HERE + # ``GraphQLError`` was swapped in for ``except Exception`` + except GraphQLError as error: + self.handle_field_error(error, field_nodes, path, return_type) + + return await_completed() + return completed + + # CHANGE WAS MADE HERE + # ``GraphQLError`` was swapped in for ``except Exception`` + except GraphQLError as error: + self.handle_field_error(error, field_nodes, path, return_type) + return None + + class Schema: """Schema Definition. @@ -481,6 +578,8 @@ class Schema: request_string, an operation name must be provided for the result to be provided. middleware (List[SupportsGraphQLMiddleware]): Supply request level middleware as defined in `graphql-core`. + execution_context_class (ExecutionContext, optional): The execution context class + to use when resolving queries and mutations. Returns: :obj:`ExecutionResult` containing any data and errors for the operation. diff --git a/graphene/types/tests/test_schema.py b/graphene/types/tests/test_schema.py index fe4739c9..54c48b4f 100644 --- a/graphene/types/tests/test_schema.py +++ b/graphene/types/tests/test_schema.py @@ -1,12 +1,13 @@ from graphql.type import GraphQLObjectType, GraphQLSchema -from pytest import raises +from graphql import GraphQLError +from pytest import mark, raises, fixture from graphene.tests.utils import dedent from ..field import Field from ..objecttype import ObjectType from ..scalars import String -from ..schema import Schema +from ..schema import Schema, UnforgivingExecutionContext class MyOtherType(ObjectType): @@ -68,3 +69,115 @@ def test_schema_requires_query_type(): assert len(result.errors) == 1 error = result.errors[0] assert error.message == "Query root type must be provided." + + +class TestUnforgivingExecutionContext: + @fixture + def schema(self): + class ErrorFieldsMixin: + sanity_field = String() + expected_error_field = String() + unexpected_value_error_field = String() + unexpected_type_error_field = String() + unexpected_attribute_error_field = String() + unexpected_key_error_field = String() + + @staticmethod + def resolve_sanity_field(obj, info): + return "not an error" + + @staticmethod + def resolve_expected_error_field(obj, info): + raise GraphQLError("expected error") + + @staticmethod + def resolve_unexpected_value_error_field(obj, info): + raise ValueError("unexpected error") + + @staticmethod + def resolve_unexpected_type_error_field(obj, info): + raise TypeError("unexpected error") + + @staticmethod + def resolve_unexpected_attribute_error_field(obj, info): + raise AttributeError("unexpected error") + + @staticmethod + def resolve_unexpected_key_error_field(obj, info): + return {}["fails"] + + class NestedObject(ErrorFieldsMixin, ObjectType): + pass + + class MyQuery(ErrorFieldsMixin, ObjectType): + nested_object = Field(NestedObject) + nested_object_error = Field(NestedObject) + + @staticmethod + def resolve_nested_object(obj, info): + return object() + + @staticmethod + def resolve_nested_object_error(obj, info): + raise TypeError() + + schema = Schema(query=MyQuery) + return schema + + def test_sanity_check(self, schema): + # this should pass with no errors (sanity check) + result = schema.execute( + "query { sanityField }", + execution_context_class=UnforgivingExecutionContext, + ) + assert not result.errors + assert result.data == {"sanityField": "not an error"} + + def test_nested_sanity_check(self, schema): + # this should pass with no errors (sanity check) + result = schema.execute( + r"query { nestedObject { sanityField } }", + execution_context_class=UnforgivingExecutionContext, + ) + assert not result.errors + assert result.data == {"nestedObject": {"sanityField": "not an error"}} + + def test_graphql_error(self, schema): + result = schema.execute( + "query { expectedErrorField }", + execution_context_class=UnforgivingExecutionContext, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "expected error" + assert result.data == {"expectedErrorField": None} + + def test_nested_graphql_error(self, schema): + result = schema.execute( + r"query { nestedObject { expectedErrorField } }", + execution_context_class=UnforgivingExecutionContext, + ) + assert len(result.errors) == 1 + assert result.errors[0].message == "expected error" + assert result.data == {"nestedObject": {"expectedErrorField": None}} + + @mark.parametrize( + "field,exception", + [ + ("unexpectedValueErrorField", ValueError), + ("unexpectedTypeErrorField", TypeError), + ("unexpectedAttributeErrorField", AttributeError), + ("unexpectedKeyErrorField", KeyError), + ("nestedObject { unexpectedValueErrorField }", ValueError), + ("nestedObject { unexpectedTypeErrorField }", TypeError), + ("nestedObject { unexpectedAttributeErrorField }", AttributeError), + ("nestedObject { unexpectedKeyErrorField }", KeyError), + ("nestedObjectError { __typename }", TypeError), + ], + ) + def test_unexpected_error(self, field, exception, schema): + with raises(exception): + # no result, but the exception should be propagated + schema.execute( + f"query {{ {field} }}", + execution_context_class=UnforgivingExecutionContext, + )