feat: Add subscribe to Field

This commit is contained in:
Victor 2022-12-27 23:56:31 +03:00
parent 52143473ef
commit 780012da47
2 changed files with 23 additions and 4 deletions

View File

@ -51,6 +51,8 @@ class Field(MountedType):
value object. If not set, the default resolver method for the schema is used. value object. If not set, the default resolver method for the schema is used.
source (optional, str): attribute name to resolve for this field from the parent value source (optional, str): attribute name to resolve for this field from the parent value
object. Alternative to resolver (cannot set both source and resolver). object. Alternative to resolver (cannot set both source and resolver).
subscribe (optional, AsyncIterable): Asynchronous iterator to get the value for a Field from the parent
value object. If not set, the default resolver method for the schema is used.
deprecation_reason (optional, str): Setting this value indicates that the field is deprecation_reason (optional, str): Setting this value indicates that the field is
depreciated and may provide instruction or reason on how for clients to proceed. depreciated and may provide instruction or reason on how for clients to proceed.
required (optional, bool): indicates this field as not null in the graphql schema. Same behavior as required (optional, bool): indicates this field as not null in the graphql schema. Same behavior as
@ -69,6 +71,7 @@ class Field(MountedType):
args=None, args=None,
resolver=None, resolver=None,
source=None, source=None,
subscribe=None,
deprecation_reason=None, deprecation_reason=None,
name=None, name=None,
description=None, description=None,
@ -107,6 +110,7 @@ class Field(MountedType):
if source: if source:
resolver = partial(source_resolver, source) resolver = partial(source_resolver, source)
self.resolver = resolver self.resolver = resolver
self.subscribe = subscribe
self.deprecation_reason = deprecation_reason self.deprecation_reason = deprecation_reason
self.description = description self.description = description
self.default_value = default_value self.default_value = default_value
@ -131,8 +135,8 @@ class Field(MountedType):
return self.resolver or parent_resolver return self.resolver or parent_resolver
def wrap_subscribe(self, parent_subscribe): def wrap_subscribe(self, parent_subscribe):
"""Wraps a function subscribe.
- using the ObjectType subscribe_{FIELD_NAME} (parent_subscribe) if the Field definition has no subscribe.
- using the Field.subscribe
""" """
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME} return parent_subscribe or self.subscribe
(parent_subscribe) if the Field definition has no subscribe.
"""
return parent_subscribe

View File

@ -10,13 +10,21 @@ class Query(ObjectType):
return "Hello, world!" return "Hello, world!"
async def subscribe_count_to_five(root, info):
for count in range(1, 6):
yield count
class Subscription(ObjectType): class Subscription(ObjectType):
count_to_ten = Field(Int) count_to_ten = Field(Int)
async def subscribe_count_to_ten(root, info): async def subscribe_count_to_ten(root, info):
for count in range(1, 11): for count in range(1, 11):
yield count yield count
count_to_five = Field(Int, subscribe=subscribe_count_to_five)
schema = Schema(query=Query, subscription=Subscription) schema = Schema(query=Query, subscription=Subscription)
@ -30,6 +38,13 @@ async def test_subscription():
count = item.data["countToTen"] count = item.data["countToTen"]
assert count == 10 assert count == 10
subscription = "subscription { countToFive }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToFive"]
assert count == 5
@mark.asyncio @mark.asyncio
async def test_subscription_fails_with_invalid_query(): async def test_subscription_fails_with_invalid_query():