Modify subscription api

This commit is contained in:
Jonathan Kim 2020-07-11 11:51:04 +01:00
parent c61f0f736a
commit 40627bb713
2 changed files with 47 additions and 16 deletions

View File

@ -315,14 +315,14 @@ class TypeMap(dict):
graphene_type, f"resolve_{name}", name, field.default_value
)
),
subscribe=field.get_resolver(
self.get_resolver_for_type(
graphene_type,
f"subscribe_{name}",
name,
field.default_value,
)
),
# subscribe=field.get_resolver(
# self.get_resolver_for_type(
# graphene_type,
# f"subscribe_{name}",
# name,
# field.default_value,
# )
# ),
deprecation_reason=field.deprecation_reason,
description=field.description,
)
@ -478,7 +478,14 @@ class Schema:
async def subscribe(self, query, *args, **kwargs):
document = parse(query)
kwargs = normalize_execute_kwargs(kwargs)
return await subscribe(self.graphql_schema, document, *args, **kwargs)
subscription = self.graphql_schema.subscription_type
return await subscribe(
self.graphql_schema,
document,
*args,
subscribe_field_resolver=subscription.graphene_type._meta.resolver,
**kwargs,
)
def introspect(self):
introspection = self.execute(introspection_query)

View File

@ -1,26 +1,50 @@
from pytest import mark
from graphene import ObjectType, Int, String, Schema, Field
from graphene import ObjectType, Int, String, Schema
from graphene.types.objecttype import ObjectTypeOptions
from graphene.utils.get_unbound_function import get_unbound_function
class Query(ObjectType):
hello = String()
a = String()
def resolve_hello(root, info):
return "Hello, world!"
class SubscriptionOptions(ObjectTypeOptions):
pass
class Subscription(ObjectType):
count_to_ten = Field(Int)
@classmethod
def __init_subclass_with_meta__(
cls,
resolver=None,
_meta=None,
**options,
):
if not _meta:
_meta = SubscriptionOptions(cls)
async def subscribe_count_to_ten(root, info):
if not resolver:
subscribe = getattr(cls, "subscribe", None)
assert subscribe, "The Subscribe class must define a subscribe method"
resolver = get_unbound_function(subscribe)
_meta.resolver = resolver
super().__init_subclass_with_meta__(_meta=_meta, **options)
class MySubscription(Subscription):
count_to_ten = Int()
async def subscribe(root, info):
count = 0
while count < 10:
count += 1
yield {"count_to_ten": count}
schema = Schema(query=Query, subscription=Subscription)
schema = Schema(query=Query, subscription=MySubscription)
@mark.asyncio