Added support for subscription (#1107)

* Added support for subscription

* Added pre-commit hooks for black and formatted changed files

* Checked with flake8

* Integrated changes from master.

Co-authored-by: Rob Blackbourn <rblackbourn@bhdgsystematic.com>
Co-authored-by: Rob Blackbourn <rtb@beast.jetblack.net>
This commit is contained in:
Rob Blackbourn 2020-03-14 16:48:12 +00:00 committed by GitHub
parent 88f79b2850
commit 1cf303a27b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 4 deletions

View File

@ -17,6 +17,43 @@ For executing a query against a schema, you can directly call the ``execute`` me
``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred. ``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred.
For executing a subscription, you can directly call the ``subscribe`` method on it.
This method is async and must be awaited.
.. code:: python
import asyncio
from datetime import datetime
from graphene import ObjectType, String, Schema, Field
# All schema require a query.
class Query(ObjectType):
hello = String()
def resolve_hello(root, info):
return 'Hello, world!'
class Subscription(ObjectType):
time_of_day = Field(String)
async def subscribe_time_of_day(root, info):
while True:
yield { 'time_of_day': datetime.now().isoformat()}
await asyncio.sleep(1)
SCHEMA = Schema(query=Query, subscription=Subscription)
async def main(schema):
subscription = 'subscription { timeOfDay }'
result = await schema.subscribe(subscription)
async for item in result:
print(item.data['timeOfDay'])
asyncio.run(main(SCHEMA))
The ``result`` is an async iterator which yields items in the same manner as a query.
.. _SchemaExecuteContext: .. _SchemaExecuteContext:
Context Context

View File

@ -7,7 +7,9 @@ from graphql import (
graphql, graphql,
graphql_sync, graphql_sync,
introspection_types, introspection_types,
parse,
print_schema, print_schema,
subscribe,
GraphQLArgument, GraphQLArgument,
GraphQLBoolean, GraphQLBoolean,
GraphQLEnumValue, GraphQLEnumValue,
@ -309,13 +311,19 @@ class TypeMap(dict):
if isinstance(arg.type, NonNull) if isinstance(arg.type, NonNull)
else arg.default_value, else arg.default_value,
) )
resolve = field.get_resolver(
self.get_resolver(graphene_type, name, field.default_value)
)
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolve=resolve, resolve=field.get_resolver(
self.get_resolver_for_type(
graphene_type, "resolve_{}", name, field.default_value
)
),
subscribe=field.get_resolver(
self.get_resolver_for_type(
graphene_type, "subscribe_{}", name, field.default_value
)
),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description, description=field.description,
) )
@ -323,6 +331,32 @@ class TypeMap(dict):
fields[field_name] = _field fields[field_name] = _field
return fields return fields
def get_resolver_for_type(self, graphene_type, pattern, name, default_value):
if not issubclass(graphene_type, ObjectType):
return
func_name = pattern.format(name)
resolver = getattr(graphene_type, func_name, None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
# find it in each of the interfaces
interface_resolver = None
for interface in graphene_type._meta.interfaces:
if name not in interface._meta.fields:
continue
interface_resolver = getattr(interface, func_name, None)
if interface_resolver:
break
resolver = interface_resolver
# Only if is not decorated with classmethod
if resolver:
return get_unbound_function(resolver)
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
return partial(default_resolver, name, default_value)
def resolve_type(self, resolve_type_func, type_name, root, info, _type): def resolve_type(self, resolve_type_func, type_name, root, info, _type):
type_ = resolve_type_func(root, info) type_ = resolve_type_func(root, info)
@ -468,6 +502,11 @@ class Schema:
kwargs = normalize_execute_kwargs(kwargs) kwargs = normalize_execute_kwargs(kwargs)
return await graphql(self.graphql_schema, *args, **kwargs) return await graphql(self.graphql_schema, *args, **kwargs)
async def subscribe(self, query, *args, **kwargs):
document = parse(query)
kwargs = normalize_execute_kwargs(kwargs)
return await subscribe(self.graphql_schema, document, *args, **kwargs)
def introspect(self): def introspect(self):
introspection = self.execute(introspection_query) introspection = self.execute(introspection_query)
if introspection.errors: if introspection.errors:

View File

@ -0,0 +1,33 @@
from pytest import mark
from graphene import ObjectType, Int, String, Schema, Field
class Query(ObjectType):
hello = String()
def resolve_hello(root, info):
return "Hello, world!"
class Subscription(ObjectType):
count_to_ten = Field(Int)
async def subscribe_count_to_ten(root, info):
count = 0
while count < 10:
count += 1
yield {"count_to_ten": count}
schema = Schema(query=Query, subscription=Subscription)
@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToTen"]
assert count == 10