diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 36ce6f25..4921aa3e 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -24,7 +24,10 @@ from graphql import ( GraphQLSchema, GraphQLString, Undefined, + ExecutionContext, + OperationType, ) +from graphql.subscription.map_async_iterator import MapAsyncIterator from ..utils.str_converters import to_camel_case from ..utils.get_unbound_function import get_unbound_function @@ -76,6 +79,10 @@ def is_type_of_from_possible_types(possible_types, root, _info): return isinstance(root, possible_types) +def map_payload_to_object(name, payload): + return {name: payload} + + class TypeMap(dict): def __init__( self, @@ -307,22 +314,19 @@ class TypeMap(dict): if isinstance(arg.type, NonNull) else arg.default_value, ) + + # TODO only look for subscribe function if in Subscription type _field = GraphQLField( field_type, args=args, resolve=field.get_resolver( - self.get_resolver_for_type( - graphene_type, f"resolve_{name}", name, field.default_value + self.get_resolver(graphene_type, name, field.default_value) + ), + subscribe=field.get_resolver( + self.get_subscribe_resolver( + graphene_type, name, field.default_value, ) ), - # subscribe=field.get_resolver( - # self.get_resolver_for_type( - # graphene_type, - # f"subscribe_{name}", - # name, - # field.default_value, - # ) - # ), deprecation_reason=field.deprecation_reason, description=field.description, ) @@ -330,9 +334,30 @@ class TypeMap(dict): fields[field_name] = _field return fields - def get_resolver_for_type(self, graphene_type, func_name, name, default_value): + def get_subscribe_resolver(self, graphene_type, name, default_value): if not issubclass(graphene_type, ObjectType): return + func_name = f"subscribe_{name}" + resolver = getattr(graphene_type, func_name, None) + + if not resolver: + # TODO + return None + + resolver = get_unbound_function(resolver) + + # TODO wrap resolver + def wrapped_resolver(*args, **kwargs): + result = resolver(*args, **kwargs) + return MapAsyncIterator(result, partial(map_payload_to_object, name)) + + return wrapped_resolver + + def get_resolver(self, graphene_type, name, default_value): + if not issubclass(graphene_type, ObjectType): + return + + func_name = f"resolver_{name}" resolver = getattr(graphene_type, func_name, None) if not resolver: # If we don't find the resolver in the ObjectType class, then try to @@ -475,18 +500,32 @@ class Schema: kwargs = normalize_execute_kwargs(kwargs) return await graphql(self.graphql_schema, *args, **kwargs) - async def subscribe(self, query, *args, **kwargs): + async def subscribe(self, query, **kwargs): document = parse(query) kwargs = normalize_execute_kwargs(kwargs) - subscription = self.graphql_schema.subscription_type - return await subscribe( + context = ExecutionContext.build( self.graphql_schema, document, - *args, - subscribe_field_resolver=subscription.graphene_type._meta.resolver, - **kwargs, + root_value=kwargs.get("root_value", None), + context_value=kwargs.get("context_value", None), + raw_variable_values=kwargs.get("variable_values", None), + operation_name=kwargs.get("operation_name", None), ) + if context.operation.operation != OperationType.SUBSCRIPTION: + raise RuntimeError("Subscription requires a subscribe operation type") + + type_ = self.graphql_schema.subscription_type + fields = context.collect_fields( + type_, context.operation.selection_set, {}, set() + ) + + if len(fields.keys()) > 1: + raise RuntimeError("Can't select more than 1 field to subscribe to") + + # subscription_type = self.graphql_schema.subscription_type + return await subscribe(self.graphql_schema, document, **kwargs,) + def introspect(self): introspection = self.execute(introspection_query) if introspection.errors: diff --git a/tests_asyncio/test_subscribe.py b/tests_asyncio/test_subscribe.py index f7ae9898..2fd3fdb8 100644 --- a/tests_asyncio/test_subscribe.py +++ b/tests_asyncio/test_subscribe.py @@ -1,10 +1,16 @@ +from datetime import datetime from pytest import mark -from graphene import ObjectType, Int, String, Schema +from graphene import ObjectType, Int, String, Schema, DateTime from graphene.types.objecttype import ObjectTypeOptions from graphene.utils.get_unbound_function import get_unbound_function +MYPY = False +if MYPY: + from typing import Callable # NOQA + + class Query(ObjectType): a = String() @@ -16,32 +22,22 @@ class SubscriptionOptions(ObjectTypeOptions): class Subscription(ObjectType): @classmethod def __init_subclass_with_meta__( - cls, - resolver=None, - _meta=None, - **options, + cls, _meta=None, **options, ): if not _meta: _meta = SubscriptionOptions(cls) - if not resolver: - subscribe = getattr(cls, "subscribe", None) - assert subscribe, "The Subscribe class must define a subscribe method" - resolver = get_unbound_function(subscribe) - - _meta.resolver = resolver - super().__init_subclass_with_meta__(_meta=_meta, **options) class MySubscription(Subscription): - count_to_ten = Int() + count_to_ten = Int(yes=Int()) - async def subscribe(root, info): + async def subscribe_count_to_ten(root, info, **kwargs): count = 0 while count < 10: count += 1 - yield {"count_to_ten": count} + yield count schema = Schema(query=Query, subscription=MySubscription) @@ -49,7 +45,11 @@ schema = Schema(query=Query, subscription=MySubscription) @mark.asyncio async def test_subscription(): - subscription = "subscription { countToTen }" + subscription = """ + subscription { + countToTen + } + """ result = await schema.subscribe(subscription) count = 0 async for item in result: