From 1cf303a27bf1a83dc4ed14f071a91c0d02f22825 Mon Sep 17 00:00:00 2001 From: Rob Blackbourn Date: Sat, 14 Mar 2020 16:48:12 +0000 Subject: [PATCH] Added support for subscription (#1107) * Added support for subscription * Added pre-commit hooks for black and formatted changed files * Checked with flake8 * Integrated changes from master. Co-authored-by: Rob Blackbourn Co-authored-by: Rob Blackbourn --- docs/execution/execute.rst | 37 ++++++++++++++++++++++++++ graphene/types/schema.py | 47 ++++++++++++++++++++++++++++++--- tests_asyncio/test_subscribe.py | 33 +++++++++++++++++++++++ 3 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 tests_asyncio/test_subscribe.py diff --git a/docs/execution/execute.rst b/docs/execution/execute.rst index f0ea8853..cd29d72d 100644 --- a/docs/execution/execute.rst +++ b/docs/execution/execute.rst @@ -17,6 +17,43 @@ For executing a query against a schema, you can directly call the ``execute`` me ``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred. +For executing a subscription, you can directly call the ``subscribe`` method on it. +This method is async and must be awaited. + +.. code:: python + + import asyncio + from datetime import datetime + from graphene import ObjectType, String, Schema, Field + + # All schema require a query. + class Query(ObjectType): + hello = String() + + def resolve_hello(root, info): + return 'Hello, world!' + + class Subscription(ObjectType): + time_of_day = Field(String) + + async def subscribe_time_of_day(root, info): + while True: + yield { 'time_of_day': datetime.now().isoformat()} + await asyncio.sleep(1) + + SCHEMA = Schema(query=Query, subscription=Subscription) + + async def main(schema): + + subscription = 'subscription { timeOfDay }' + result = await schema.subscribe(subscription) + async for item in result: + print(item.data['timeOfDay']) + + asyncio.run(main(SCHEMA)) + +The ``result`` is an async iterator which yields items in the same manner as a query. + .. _SchemaExecuteContext: Context diff --git a/graphene/types/schema.py b/graphene/types/schema.py index f1d1337e..5228fb44 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -7,7 +7,9 @@ from graphql import ( graphql, graphql_sync, introspection_types, + parse, print_schema, + subscribe, GraphQLArgument, GraphQLBoolean, GraphQLEnumValue, @@ -309,13 +311,19 @@ class TypeMap(dict): if isinstance(arg.type, NonNull) else arg.default_value, ) - resolve = field.get_resolver( - self.get_resolver(graphene_type, name, field.default_value) - ) _field = GraphQLField( field_type, args=args, - resolve=resolve, + resolve=field.get_resolver( + self.get_resolver_for_type( + graphene_type, "resolve_{}", name, field.default_value + ) + ), + subscribe=field.get_resolver( + self.get_resolver_for_type( + graphene_type, "subscribe_{}", name, field.default_value + ) + ), deprecation_reason=field.deprecation_reason, description=field.description, ) @@ -323,6 +331,32 @@ class TypeMap(dict): fields[field_name] = _field return fields + def get_resolver_for_type(self, graphene_type, pattern, name, default_value): + if not issubclass(graphene_type, ObjectType): + return + func_name = pattern.format(name) + resolver = getattr(graphene_type, func_name, None) + if not resolver: + # If we don't find the resolver in the ObjectType class, then try to + # find it in each of the interfaces + interface_resolver = None + for interface in graphene_type._meta.interfaces: + if name not in interface._meta.fields: + continue + interface_resolver = getattr(interface, func_name, None) + if interface_resolver: + break + resolver = interface_resolver + + # Only if is not decorated with classmethod + if resolver: + return get_unbound_function(resolver) + + default_resolver = ( + graphene_type._meta.default_resolver or get_default_resolver() + ) + return partial(default_resolver, name, default_value) + def resolve_type(self, resolve_type_func, type_name, root, info, _type): type_ = resolve_type_func(root, info) @@ -468,6 +502,11 @@ class Schema: kwargs = normalize_execute_kwargs(kwargs) return await graphql(self.graphql_schema, *args, **kwargs) + async def subscribe(self, query, *args, **kwargs): + document = parse(query) + kwargs = normalize_execute_kwargs(kwargs) + return await subscribe(self.graphql_schema, document, *args, **kwargs) + def introspect(self): introspection = self.execute(introspection_query) if introspection.errors: diff --git a/tests_asyncio/test_subscribe.py b/tests_asyncio/test_subscribe.py new file mode 100644 index 00000000..bf985d58 --- /dev/null +++ b/tests_asyncio/test_subscribe.py @@ -0,0 +1,33 @@ +from pytest import mark + +from graphene import ObjectType, Int, String, Schema, Field + + +class Query(ObjectType): + hello = String() + + def resolve_hello(root, info): + return "Hello, world!" + + +class Subscription(ObjectType): + count_to_ten = Field(Int) + + async def subscribe_count_to_ten(root, info): + count = 0 + while count < 10: + count += 1 + yield {"count_to_ten": count} + + +schema = Schema(query=Query, subscription=Subscription) + + +@mark.asyncio +async def test_subscription(): + subscription = "subscription { countToTen }" + result = await schema.subscribe(subscription) + count = 0 + async for item in result: + count = item.data["countToTen"] + assert count == 10