diff --git a/graphene/types/schema.py b/graphene/types/schema.py index 29ead4a7..36ce6f25 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -315,14 +315,14 @@ class TypeMap(dict): graphene_type, f"resolve_{name}", name, field.default_value ) ), - subscribe=field.get_resolver( - self.get_resolver_for_type( - graphene_type, - f"subscribe_{name}", - name, - field.default_value, - ) - ), + # subscribe=field.get_resolver( + # self.get_resolver_for_type( + # graphene_type, + # f"subscribe_{name}", + # name, + # field.default_value, + # ) + # ), deprecation_reason=field.deprecation_reason, description=field.description, ) @@ -478,7 +478,14 @@ class Schema: async def subscribe(self, query, *args, **kwargs): document = parse(query) kwargs = normalize_execute_kwargs(kwargs) - return await subscribe(self.graphql_schema, document, *args, **kwargs) + subscription = self.graphql_schema.subscription_type + return await subscribe( + self.graphql_schema, + document, + *args, + subscribe_field_resolver=subscription.graphene_type._meta.resolver, + **kwargs, + ) def introspect(self): introspection = self.execute(introspection_query) diff --git a/tests_asyncio/test_subscribe.py b/tests_asyncio/test_subscribe.py index bf985d58..f7ae9898 100644 --- a/tests_asyncio/test_subscribe.py +++ b/tests_asyncio/test_subscribe.py @@ -1,26 +1,50 @@ from pytest import mark -from graphene import ObjectType, Int, String, Schema, Field +from graphene import ObjectType, Int, String, Schema +from graphene.types.objecttype import ObjectTypeOptions +from graphene.utils.get_unbound_function import get_unbound_function class Query(ObjectType): - hello = String() + a = String() - def resolve_hello(root, info): - return "Hello, world!" + +class SubscriptionOptions(ObjectTypeOptions): + pass class Subscription(ObjectType): - count_to_ten = Field(Int) + @classmethod + def __init_subclass_with_meta__( + cls, + resolver=None, + _meta=None, + **options, + ): + if not _meta: + _meta = SubscriptionOptions(cls) - async def subscribe_count_to_ten(root, info): + if not resolver: + subscribe = getattr(cls, "subscribe", None) + assert subscribe, "The Subscribe class must define a subscribe method" + resolver = get_unbound_function(subscribe) + + _meta.resolver = resolver + + super().__init_subclass_with_meta__(_meta=_meta, **options) + + +class MySubscription(Subscription): + count_to_ten = Int() + + async def subscribe(root, info): count = 0 while count < 10: count += 1 yield {"count_to_ten": count} -schema = Schema(query=Query, subscription=Subscription) +schema = Schema(query=Query, subscription=MySubscription) @mark.asyncio