Fix subscription resolver

This commit is contained in:
Jonathan Kim 2020-07-11 13:37:18 +01:00
parent 40627bb713
commit 2ec251c290
2 changed files with 72 additions and 33 deletions

View File

@ -24,7 +24,10 @@ from graphql import (
GraphQLSchema,
GraphQLString,
Undefined,
ExecutionContext,
OperationType,
)
from graphql.subscription.map_async_iterator import MapAsyncIterator
from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function
@ -76,6 +79,10 @@ def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types)
def map_payload_to_object(name, payload):
return {name: payload}
class TypeMap(dict):
def __init__(
self,
@ -307,22 +314,19 @@ class TypeMap(dict):
if isinstance(arg.type, NonNull)
else arg.default_value,
)
# TODO only look for subscribe function if in Subscription type
_field = GraphQLField(
field_type,
args=args,
resolve=field.get_resolver(
self.get_resolver_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
self.get_resolver(graphene_type, name, field.default_value)
),
subscribe=field.get_resolver(
self.get_subscribe_resolver(
graphene_type, name, field.default_value,
)
),
# subscribe=field.get_resolver(
# self.get_resolver_for_type(
# graphene_type,
# f"subscribe_{name}",
# name,
# field.default_value,
# )
# ),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
@ -330,9 +334,30 @@ class TypeMap(dict):
fields[field_name] = _field
return fields
def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
def get_subscribe_resolver(self, graphene_type, name, default_value):
if not issubclass(graphene_type, ObjectType):
return
func_name = f"subscribe_{name}"
resolver = getattr(graphene_type, func_name, None)
if not resolver:
# TODO
return None
resolver = get_unbound_function(resolver)
# TODO wrap resolver
def wrapped_resolver(*args, **kwargs):
result = resolver(*args, **kwargs)
return MapAsyncIterator(result, partial(map_payload_to_object, name))
return wrapped_resolver
def get_resolver(self, graphene_type, name, default_value):
if not issubclass(graphene_type, ObjectType):
return
func_name = f"resolver_{name}"
resolver = getattr(graphene_type, func_name, None)
if not resolver:
# If we don't find the resolver in the ObjectType class, then try to
@ -475,18 +500,32 @@ class Schema:
kwargs = normalize_execute_kwargs(kwargs)
return await graphql(self.graphql_schema, *args, **kwargs)
async def subscribe(self, query, *args, **kwargs):
async def subscribe(self, query, **kwargs):
document = parse(query)
kwargs = normalize_execute_kwargs(kwargs)
subscription = self.graphql_schema.subscription_type
return await subscribe(
context = ExecutionContext.build(
self.graphql_schema,
document,
*args,
subscribe_field_resolver=subscription.graphene_type._meta.resolver,
**kwargs,
root_value=kwargs.get("root_value", None),
context_value=kwargs.get("context_value", None),
raw_variable_values=kwargs.get("variable_values", None),
operation_name=kwargs.get("operation_name", None),
)
if context.operation.operation != OperationType.SUBSCRIPTION:
raise RuntimeError("Subscription requires a subscribe operation type")
type_ = self.graphql_schema.subscription_type
fields = context.collect_fields(
type_, context.operation.selection_set, {}, set()
)
if len(fields.keys()) > 1:
raise RuntimeError("Can't select more than 1 field to subscribe to")
# subscription_type = self.graphql_schema.subscription_type
return await subscribe(self.graphql_schema, document, **kwargs,)
def introspect(self):
introspection = self.execute(introspection_query)
if introspection.errors:

View File

@ -1,10 +1,16 @@
from datetime import datetime
from pytest import mark
from graphene import ObjectType, Int, String, Schema
from graphene import ObjectType, Int, String, Schema, DateTime
from graphene.types.objecttype import ObjectTypeOptions
from graphene.utils.get_unbound_function import get_unbound_function
MYPY = False
if MYPY:
from typing import Callable # NOQA
class Query(ObjectType):
a = String()
@ -16,32 +22,22 @@ class SubscriptionOptions(ObjectTypeOptions):
class Subscription(ObjectType):
@classmethod
def __init_subclass_with_meta__(
cls,
resolver=None,
_meta=None,
**options,
cls, _meta=None, **options,
):
if not _meta:
_meta = SubscriptionOptions(cls)
if not resolver:
subscribe = getattr(cls, "subscribe", None)
assert subscribe, "The Subscribe class must define a subscribe method"
resolver = get_unbound_function(subscribe)
_meta.resolver = resolver
super().__init_subclass_with_meta__(_meta=_meta, **options)
class MySubscription(Subscription):
count_to_ten = Int()
count_to_ten = Int(yes=Int())
async def subscribe(root, info):
async def subscribe_count_to_ten(root, info, **kwargs):
count = 0
while count < 10:
count += 1
yield {"count_to_ten": count}
yield count
schema = Schema(query=Query, subscription=MySubscription)
@ -49,7 +45,11 @@ schema = Schema(query=Query, subscription=MySubscription)
@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
subscription = """
subscription {
countToTen
}
"""
result = await schema.subscribe(subscription)
count = 0
async for item in result: