mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-27 09:12:50 +03:00
Fix subscription resolver
This commit is contained in:
parent
40627bb713
commit
2ec251c290
|
@ -24,7 +24,10 @@ from graphql import (
|
||||||
GraphQLSchema,
|
GraphQLSchema,
|
||||||
GraphQLString,
|
GraphQLString,
|
||||||
Undefined,
|
Undefined,
|
||||||
|
ExecutionContext,
|
||||||
|
OperationType,
|
||||||
)
|
)
|
||||||
|
from graphql.subscription.map_async_iterator import MapAsyncIterator
|
||||||
|
|
||||||
from ..utils.str_converters import to_camel_case
|
from ..utils.str_converters import to_camel_case
|
||||||
from ..utils.get_unbound_function import get_unbound_function
|
from ..utils.get_unbound_function import get_unbound_function
|
||||||
|
@ -76,6 +79,10 @@ def is_type_of_from_possible_types(possible_types, root, _info):
|
||||||
return isinstance(root, possible_types)
|
return isinstance(root, possible_types)
|
||||||
|
|
||||||
|
|
||||||
|
def map_payload_to_object(name, payload):
|
||||||
|
return {name: payload}
|
||||||
|
|
||||||
|
|
||||||
class TypeMap(dict):
|
class TypeMap(dict):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -307,22 +314,19 @@ class TypeMap(dict):
|
||||||
if isinstance(arg.type, NonNull)
|
if isinstance(arg.type, NonNull)
|
||||||
else arg.default_value,
|
else arg.default_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO only look for subscribe function if in Subscription type
|
||||||
_field = GraphQLField(
|
_field = GraphQLField(
|
||||||
field_type,
|
field_type,
|
||||||
args=args,
|
args=args,
|
||||||
resolve=field.get_resolver(
|
resolve=field.get_resolver(
|
||||||
self.get_resolver_for_type(
|
self.get_resolver(graphene_type, name, field.default_value)
|
||||||
graphene_type, f"resolve_{name}", name, field.default_value
|
),
|
||||||
|
subscribe=field.get_resolver(
|
||||||
|
self.get_subscribe_resolver(
|
||||||
|
graphene_type, name, field.default_value,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
# subscribe=field.get_resolver(
|
|
||||||
# self.get_resolver_for_type(
|
|
||||||
# graphene_type,
|
|
||||||
# f"subscribe_{name}",
|
|
||||||
# name,
|
|
||||||
# field.default_value,
|
|
||||||
# )
|
|
||||||
# ),
|
|
||||||
deprecation_reason=field.deprecation_reason,
|
deprecation_reason=field.deprecation_reason,
|
||||||
description=field.description,
|
description=field.description,
|
||||||
)
|
)
|
||||||
|
@ -330,9 +334,30 @@ class TypeMap(dict):
|
||||||
fields[field_name] = _field
|
fields[field_name] = _field
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
|
def get_subscribe_resolver(self, graphene_type, name, default_value):
|
||||||
if not issubclass(graphene_type, ObjectType):
|
if not issubclass(graphene_type, ObjectType):
|
||||||
return
|
return
|
||||||
|
func_name = f"subscribe_{name}"
|
||||||
|
resolver = getattr(graphene_type, func_name, None)
|
||||||
|
|
||||||
|
if not resolver:
|
||||||
|
# TODO
|
||||||
|
return None
|
||||||
|
|
||||||
|
resolver = get_unbound_function(resolver)
|
||||||
|
|
||||||
|
# TODO wrap resolver
|
||||||
|
def wrapped_resolver(*args, **kwargs):
|
||||||
|
result = resolver(*args, **kwargs)
|
||||||
|
return MapAsyncIterator(result, partial(map_payload_to_object, name))
|
||||||
|
|
||||||
|
return wrapped_resolver
|
||||||
|
|
||||||
|
def get_resolver(self, graphene_type, name, default_value):
|
||||||
|
if not issubclass(graphene_type, ObjectType):
|
||||||
|
return
|
||||||
|
|
||||||
|
func_name = f"resolver_{name}"
|
||||||
resolver = getattr(graphene_type, func_name, None)
|
resolver = getattr(graphene_type, func_name, None)
|
||||||
if not resolver:
|
if not resolver:
|
||||||
# If we don't find the resolver in the ObjectType class, then try to
|
# If we don't find the resolver in the ObjectType class, then try to
|
||||||
|
@ -475,18 +500,32 @@ class Schema:
|
||||||
kwargs = normalize_execute_kwargs(kwargs)
|
kwargs = normalize_execute_kwargs(kwargs)
|
||||||
return await graphql(self.graphql_schema, *args, **kwargs)
|
return await graphql(self.graphql_schema, *args, **kwargs)
|
||||||
|
|
||||||
async def subscribe(self, query, *args, **kwargs):
|
async def subscribe(self, query, **kwargs):
|
||||||
document = parse(query)
|
document = parse(query)
|
||||||
kwargs = normalize_execute_kwargs(kwargs)
|
kwargs = normalize_execute_kwargs(kwargs)
|
||||||
subscription = self.graphql_schema.subscription_type
|
context = ExecutionContext.build(
|
||||||
return await subscribe(
|
|
||||||
self.graphql_schema,
|
self.graphql_schema,
|
||||||
document,
|
document,
|
||||||
*args,
|
root_value=kwargs.get("root_value", None),
|
||||||
subscribe_field_resolver=subscription.graphene_type._meta.resolver,
|
context_value=kwargs.get("context_value", None),
|
||||||
**kwargs,
|
raw_variable_values=kwargs.get("variable_values", None),
|
||||||
|
operation_name=kwargs.get("operation_name", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if context.operation.operation != OperationType.SUBSCRIPTION:
|
||||||
|
raise RuntimeError("Subscription requires a subscribe operation type")
|
||||||
|
|
||||||
|
type_ = self.graphql_schema.subscription_type
|
||||||
|
fields = context.collect_fields(
|
||||||
|
type_, context.operation.selection_set, {}, set()
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(fields.keys()) > 1:
|
||||||
|
raise RuntimeError("Can't select more than 1 field to subscribe to")
|
||||||
|
|
||||||
|
# subscription_type = self.graphql_schema.subscription_type
|
||||||
|
return await subscribe(self.graphql_schema, document, **kwargs,)
|
||||||
|
|
||||||
def introspect(self):
|
def introspect(self):
|
||||||
introspection = self.execute(introspection_query)
|
introspection = self.execute(introspection_query)
|
||||||
if introspection.errors:
|
if introspection.errors:
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
|
from datetime import datetime
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
|
|
||||||
from graphene import ObjectType, Int, String, Schema
|
from graphene import ObjectType, Int, String, Schema, DateTime
|
||||||
from graphene.types.objecttype import ObjectTypeOptions
|
from graphene.types.objecttype import ObjectTypeOptions
|
||||||
from graphene.utils.get_unbound_function import get_unbound_function
|
from graphene.utils.get_unbound_function import get_unbound_function
|
||||||
|
|
||||||
|
|
||||||
|
MYPY = False
|
||||||
|
if MYPY:
|
||||||
|
from typing import Callable # NOQA
|
||||||
|
|
||||||
|
|
||||||
class Query(ObjectType):
|
class Query(ObjectType):
|
||||||
a = String()
|
a = String()
|
||||||
|
|
||||||
|
@ -16,32 +22,22 @@ class SubscriptionOptions(ObjectTypeOptions):
|
||||||
class Subscription(ObjectType):
|
class Subscription(ObjectType):
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass_with_meta__(
|
def __init_subclass_with_meta__(
|
||||||
cls,
|
cls, _meta=None, **options,
|
||||||
resolver=None,
|
|
||||||
_meta=None,
|
|
||||||
**options,
|
|
||||||
):
|
):
|
||||||
if not _meta:
|
if not _meta:
|
||||||
_meta = SubscriptionOptions(cls)
|
_meta = SubscriptionOptions(cls)
|
||||||
|
|
||||||
if not resolver:
|
|
||||||
subscribe = getattr(cls, "subscribe", None)
|
|
||||||
assert subscribe, "The Subscribe class must define a subscribe method"
|
|
||||||
resolver = get_unbound_function(subscribe)
|
|
||||||
|
|
||||||
_meta.resolver = resolver
|
|
||||||
|
|
||||||
super().__init_subclass_with_meta__(_meta=_meta, **options)
|
super().__init_subclass_with_meta__(_meta=_meta, **options)
|
||||||
|
|
||||||
|
|
||||||
class MySubscription(Subscription):
|
class MySubscription(Subscription):
|
||||||
count_to_ten = Int()
|
count_to_ten = Int(yes=Int())
|
||||||
|
|
||||||
async def subscribe(root, info):
|
async def subscribe_count_to_ten(root, info, **kwargs):
|
||||||
count = 0
|
count = 0
|
||||||
while count < 10:
|
while count < 10:
|
||||||
count += 1
|
count += 1
|
||||||
yield {"count_to_ten": count}
|
yield count
|
||||||
|
|
||||||
|
|
||||||
schema = Schema(query=Query, subscription=MySubscription)
|
schema = Schema(query=Query, subscription=MySubscription)
|
||||||
|
@ -49,7 +45,11 @@ schema = Schema(query=Query, subscription=MySubscription)
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_subscription():
|
async def test_subscription():
|
||||||
subscription = "subscription { countToTen }"
|
subscription = """
|
||||||
|
subscription {
|
||||||
|
countToTen
|
||||||
|
}
|
||||||
|
"""
|
||||||
result = await schema.subscribe(subscription)
|
result = await schema.subscribe(subscription)
|
||||||
count = 0
|
count = 0
|
||||||
async for item in result:
|
async for item in result:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user