From de81875fe7801e297ac2db7a67f64208b5a94ae7 Mon Sep 17 00:00:00 2001 From: Syrus Date: Sun, 26 Jul 2020 21:58:41 -0700 Subject: [PATCH] Added full support for subscriptions --- graphene/relay/connection.py | 4 +- graphene/relay/node.py | 4 +- graphene/relay/tests/test_global_id.py | 4 +- graphene/types/field.py | 20 +++++- graphene/types/schema.py | 73 ++++++++++++++------ graphene/types/tests/test_subscribe_async.py | 23 +++++- 6 files changed, 100 insertions(+), 28 deletions(-) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index cfb6cb63..1a4684e5 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -171,8 +171,8 @@ class IterableConnectionField(Field): on_resolve = partial(cls.resolve_connection, connection_type, args) return maybe_thenable(resolved, on_resolve) - def get_resolver(self, parent_resolver): - resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) + def wrap_resolve(self, parent_resolver): + resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver) return partial(self.connection_resolver, resolver, self.type) diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 13fb8cea..b189bc97 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -37,7 +37,7 @@ class GlobalID(Field): parent_type_name = parent_type_name or info.parent_type.name return node.to_global_id(parent_type_name, type_id) # root._meta.name - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.id_resolver, parent_resolver, @@ -60,7 +60,7 @@ class NodeField(Field): **kwargs, ) - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial(self.node_type.node_resolver, get_type(self.field_type)) diff --git a/graphene/relay/tests/test_global_id.py b/graphene/relay/tests/test_global_id.py index 2fe81300..81860d9d 100644 --- a/graphene/relay/tests/test_global_id.py +++ b/graphene/relay/tests/test_global_id.py @@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required(): def test_global_id_defaults_to_info_parent_type(): my_id = "1" gid = GlobalID() - id_resolver = gid.get_resolver(lambda *_: my_id) + id_resolver = gid.wrap_resolve(lambda *_: my_id) my_global_id = id_resolver(None, Info(User)) assert my_global_id == to_global_id(User._meta.name, my_id) @@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type(): def test_global_id_allows_setting_customer_parent_type(): my_id = "1" gid = GlobalID(parent_type=User) - id_resolver = gid.get_resolver(lambda *_: my_id) + id_resolver = gid.wrap_resolve(lambda *_: my_id) my_global_id = id_resolver(None, None) assert my_global_id == to_global_id(User._meta.name, my_id) diff --git a/graphene/types/field.py b/graphene/types/field.py index 1a1ccf93..423fe821 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -8,6 +8,7 @@ from .resolver import default_resolver from .structures import NonNull from .unmountedtype import UnmountedType from .utils import get_type +from ..utils.deprecated import warn_deprecation base_type = type @@ -114,5 +115,22 @@ class Field(MountedType): def type(self): return get_type(self._type) - def get_resolver(self, parent_resolver): + get_resolver = None + + def wrap_resolve(self, parent_resolver): + ''' + Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME} + (parent_resolver) if the Field definition has no resolver. + ''' + if self.get_resolver is not None: + warn_deprecation("The get_resolver method is being deprecated, please rename it to wrap_resolve.") + return self.get_resolver(parent_resolver) + return self.resolver or parent_resolver + + def wrap_subscribe(self, parent_subscribe): + ''' + Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME} + (parent_subscribe) if the Field definition has no subscribe. + ''' + return parent_subscribe diff --git a/graphene/types/schema.py b/graphene/types/schema.py index ce0c7439..e844d041 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -10,8 +10,11 @@ from graphql import ( parse, print_schema, subscribe, + validate, + ExecutionResult, GraphQLArgument, GraphQLBoolean, + GraphQLError, GraphQLEnumValue, GraphQLField, GraphQLFloat, @@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info): return isinstance(root, possible_types) +# We use this resolver for subscriptions +def identity_resolve(root, info): + return root + + class TypeMap(dict): def __init__( self, @@ -307,22 +315,39 @@ class TypeMap(dict): if isinstance(arg.type, NonNull) else arg.default_value, ) + subscribe = field.wrap_subscribe( + self.get_function_for_type( + graphene_type, + f"subscribe_{name}", + name, + field.default_value, + ) + ) + + # If we are in a subscription, we use (by default) an + # identity-based resolver for the root, rather than the + # default resolver for objects/dicts. + if subscribe: + field_default_resolver = identity_resolve + elif issubclass(graphene_type, ObjectType): + default_resolver = ( + graphene_type._meta.default_resolver or get_default_resolver() + ) + field_default_resolver = partial(default_resolver, name, field.default_value) + else: + field_default_resolver = None + + resolve = field.wrap_resolve( + self.get_function_for_type( + graphene_type, f"resolve_{name}", name, field.default_value + ) or field_default_resolver + ) + _field = GraphQLField( field_type, args=args, - resolve=field.get_resolver( - self.get_resolver_for_type( - graphene_type, f"resolve_{name}", name, field.default_value - ) - ), - subscribe=field.get_resolver( - self.get_resolver_for_type( - graphene_type, - f"subscribe_{name}", - name, - field.default_value, - ) - ), + resolve=resolve, + subscribe=subscribe, deprecation_reason=field.deprecation_reason, description=field.description, ) @@ -330,7 +355,8 @@ class TypeMap(dict): fields[field_name] = _field return fields - def get_resolver_for_type(self, graphene_type, func_name, name, default_value): + def get_function_for_type(self, graphene_type, func_name, name, default_value): + '''Gets a resolve or subscribe function for a given ObjectType''' if not issubclass(graphene_type, ObjectType): return resolver = getattr(graphene_type, func_name, None) @@ -350,11 +376,6 @@ class TypeMap(dict): if resolver: return get_unbound_function(resolver) - default_resolver = ( - graphene_type._meta.default_resolver or get_default_resolver() - ) - return partial(default_resolver, name, default_value) - def resolve_type(self, resolve_type_func, type_name, root, info, _type): type_ = resolve_type_func(root, info) @@ -476,7 +497,19 @@ class Schema: return await graphql(self.graphql_schema, *args, **kwargs) async def subscribe(self, query, *args, **kwargs): - document = parse(query) + """Execute a GraphQL subscription on the schema asynchronously.""" + # Do parsing + try: + document = parse(query) + except GraphQLError as error: + return ExecutionResult(data=None, errors=[error]) + + # Do validation + validation_errors = validate(self.graphql_schema, document) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + # Execute queryss kwargs = normalize_execute_kwargs(kwargs) return await subscribe(self.graphql_schema, document, *args, **kwargs) diff --git a/graphene/types/tests/test_subscribe_async.py b/graphene/types/tests/test_subscribe_async.py index bf985d58..40f8081b 100644 --- a/graphene/types/tests/test_subscribe_async.py +++ b/graphene/types/tests/test_subscribe_async.py @@ -17,7 +17,7 @@ class Subscription(ObjectType): count = 0 while count < 10: count += 1 - yield {"count_to_ten": count} + yield count schema = Schema(query=Query, subscription=Subscription) @@ -31,3 +31,24 @@ async def test_subscription(): async for item in result: count = item.data["countToTen"] assert count == 10 + + +@mark.asyncio +async def test_subscription_fails_with_invalid_query(): + # It fails if the provided query is invalid + subscription = "subscription { " + result = await schema.subscribe(subscription) + assert not result.data + assert result.errors + assert "Syntax Error: Expected Name, found " in str(result.errors[0]) + + +@mark.asyncio +async def test_subscription_fails_when_query_is_not_valid(): + # It can't subscribe to two fields at the same time, triggering a + # validation error. + subscription = "subscription { countToTen, b: countToTen }" + result = await schema.subscribe(subscription) + assert not result.data + assert result.errors + assert "Anonymous Subscription must select only one top level field." in str(result.errors[0])