Fix subscription resolver

This commit is contained in:
Jonathan Kim 2020-07-11 13:37:18 +01:00
parent 40627bb713
commit 2ec251c290
2 changed files with 72 additions and 33 deletions

View File

@ -24,7 +24,10 @@ from graphql import (
GraphQLSchema, GraphQLSchema,
GraphQLString, GraphQLString,
Undefined, Undefined,
ExecutionContext,
OperationType,
) )
from graphql.subscription.map_async_iterator import MapAsyncIterator
from ..utils.str_converters import to_camel_case from ..utils.str_converters import to_camel_case
from ..utils.get_unbound_function import get_unbound_function from ..utils.get_unbound_function import get_unbound_function
@ -76,6 +79,10 @@ def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types) return isinstance(root, possible_types)
def map_payload_to_object(name, payload):
return {name: payload}
class TypeMap(dict): class TypeMap(dict):
def __init__( def __init__(
self, self,
@ -307,22 +314,19 @@ class TypeMap(dict):
if isinstance(arg.type, NonNull) if isinstance(arg.type, NonNull)
else arg.default_value, else arg.default_value,
) )
# TODO only look for subscribe function if in Subscription type
_field = GraphQLField( _field = GraphQLField(
field_type, field_type,
args=args, args=args,
resolve=field.get_resolver( resolve=field.get_resolver(
self.get_resolver_for_type( self.get_resolver(graphene_type, name, field.default_value)
graphene_type, f"resolve_{name}", name, field.default_value ),
subscribe=field.get_resolver(
self.get_subscribe_resolver(
graphene_type, name, field.default_value,
) )
), ),
# subscribe=field.get_resolver(
# self.get_resolver_for_type(
# graphene_type,
# f"subscribe_{name}",
# name,
# field.default_value,
# )
# ),
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
description=field.description, description=field.description,
) )
@ -330,9 +334,30 @@ class TypeMap(dict):
fields[field_name] = _field fields[field_name] = _field
return fields return fields
def get_resolver_for_type(self, graphene_type, func_name, name, default_value): def get_subscribe_resolver(self, graphene_type, name, default_value):
if not issubclass(graphene_type, ObjectType): if not issubclass(graphene_type, ObjectType):
return return
func_name = f"subscribe_{name}"
resolver = getattr(graphene_type, func_name, None)
if not resolver:
# TODO
return None
resolver = get_unbound_function(resolver)
# TODO wrap resolver
def wrapped_resolver(*args, **kwargs):
result = resolver(*args, **kwargs)
return MapAsyncIterator(result, partial(map_payload_to_object, name))
return wrapped_resolver
def get_resolver(self, graphene_type, name, default_value):
if not issubclass(graphene_type, ObjectType):
return
func_name = f"resolver_{name}"
resolver = getattr(graphene_type, func_name, None) resolver = getattr(graphene_type, func_name, None)
if not resolver: if not resolver:
# If we don't find the resolver in the ObjectType class, then try to # If we don't find the resolver in the ObjectType class, then try to
@ -475,18 +500,32 @@ class Schema:
kwargs = normalize_execute_kwargs(kwargs) kwargs = normalize_execute_kwargs(kwargs)
return await graphql(self.graphql_schema, *args, **kwargs) return await graphql(self.graphql_schema, *args, **kwargs)
async def subscribe(self, query, *args, **kwargs): async def subscribe(self, query, **kwargs):
document = parse(query) document = parse(query)
kwargs = normalize_execute_kwargs(kwargs) kwargs = normalize_execute_kwargs(kwargs)
subscription = self.graphql_schema.subscription_type context = ExecutionContext.build(
return await subscribe(
self.graphql_schema, self.graphql_schema,
document, document,
*args, root_value=kwargs.get("root_value", None),
subscribe_field_resolver=subscription.graphene_type._meta.resolver, context_value=kwargs.get("context_value", None),
**kwargs, raw_variable_values=kwargs.get("variable_values", None),
operation_name=kwargs.get("operation_name", None),
) )
if context.operation.operation != OperationType.SUBSCRIPTION:
raise RuntimeError("Subscription requires a subscribe operation type")
type_ = self.graphql_schema.subscription_type
fields = context.collect_fields(
type_, context.operation.selection_set, {}, set()
)
if len(fields.keys()) > 1:
raise RuntimeError("Can't select more than 1 field to subscribe to")
# subscription_type = self.graphql_schema.subscription_type
return await subscribe(self.graphql_schema, document, **kwargs,)
def introspect(self): def introspect(self):
introspection = self.execute(introspection_query) introspection = self.execute(introspection_query)
if introspection.errors: if introspection.errors:

View File

@ -1,10 +1,16 @@
from datetime import datetime
from pytest import mark from pytest import mark
from graphene import ObjectType, Int, String, Schema from graphene import ObjectType, Int, String, Schema, DateTime
from graphene.types.objecttype import ObjectTypeOptions from graphene.types.objecttype import ObjectTypeOptions
from graphene.utils.get_unbound_function import get_unbound_function from graphene.utils.get_unbound_function import get_unbound_function
MYPY = False
if MYPY:
from typing import Callable # NOQA
class Query(ObjectType): class Query(ObjectType):
a = String() a = String()
@ -16,32 +22,22 @@ class SubscriptionOptions(ObjectTypeOptions):
class Subscription(ObjectType): class Subscription(ObjectType):
@classmethod @classmethod
def __init_subclass_with_meta__( def __init_subclass_with_meta__(
cls, cls, _meta=None, **options,
resolver=None,
_meta=None,
**options,
): ):
if not _meta: if not _meta:
_meta = SubscriptionOptions(cls) _meta = SubscriptionOptions(cls)
if not resolver:
subscribe = getattr(cls, "subscribe", None)
assert subscribe, "The Subscribe class must define a subscribe method"
resolver = get_unbound_function(subscribe)
_meta.resolver = resolver
super().__init_subclass_with_meta__(_meta=_meta, **options) super().__init_subclass_with_meta__(_meta=_meta, **options)
class MySubscription(Subscription): class MySubscription(Subscription):
count_to_ten = Int() count_to_ten = Int(yes=Int())
async def subscribe(root, info): async def subscribe_count_to_ten(root, info, **kwargs):
count = 0 count = 0
while count < 10: while count < 10:
count += 1 count += 1
yield {"count_to_ten": count} yield count
schema = Schema(query=Query, subscription=MySubscription) schema = Schema(query=Query, subscription=MySubscription)
@ -49,7 +45,11 @@ schema = Schema(query=Query, subscription=MySubscription)
@mark.asyncio @mark.asyncio
async def test_subscription(): async def test_subscription():
subscription = "subscription { countToTen }" subscription = """
subscription {
countToTen
}
"""
result = await schema.subscribe(subscription) result = await schema.subscribe(subscription)
count = 0 count = 0
async for item in result: async for item in result: