mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-16 19:40:39 +03:00
Subscription revamp (#1235)
* Integrate async tests into main code * Added full support for subscriptions * Fixed syntax using black * Fixed typo
This commit is contained in:
parent
2130005406
commit
d085c8852b
6
Makefile
6
Makefile
|
@ -8,7 +8,7 @@ install-dev:
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
test:
|
test:
|
||||||
py.test graphene examples tests_asyncio
|
py.test graphene examples
|
||||||
|
|
||||||
.PHONY: docs ## Generate docs
|
.PHONY: docs ## Generate docs
|
||||||
docs: install-dev
|
docs: install-dev
|
||||||
|
@ -20,8 +20,8 @@ docs-live: install-dev
|
||||||
|
|
||||||
.PHONY: format
|
.PHONY: format
|
||||||
format:
|
format:
|
||||||
black graphene examples setup.py tests_asyncio
|
black graphene examples setup.py
|
||||||
|
|
||||||
.PHONY: lint
|
.PHONY: lint
|
||||||
lint:
|
lint:
|
||||||
flake8 graphene examples setup.py tests_asyncio
|
flake8 graphene examples setup.py
|
||||||
|
|
|
@ -171,8 +171,8 @@ class IterableConnectionField(Field):
|
||||||
on_resolve = partial(cls.resolve_connection, connection_type, args)
|
on_resolve = partial(cls.resolve_connection, connection_type, args)
|
||||||
return maybe_thenable(resolved, on_resolve)
|
return maybe_thenable(resolved, on_resolve)
|
||||||
|
|
||||||
def get_resolver(self, parent_resolver):
|
def wrap_resolve(self, parent_resolver):
|
||||||
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
|
resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
|
||||||
return partial(self.connection_resolver, resolver, self.type)
|
return partial(self.connection_resolver, resolver, self.type)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ class GlobalID(Field):
|
||||||
parent_type_name = parent_type_name or info.parent_type.name
|
parent_type_name = parent_type_name or info.parent_type.name
|
||||||
return node.to_global_id(parent_type_name, type_id) # root._meta.name
|
return node.to_global_id(parent_type_name, type_id) # root._meta.name
|
||||||
|
|
||||||
def get_resolver(self, parent_resolver):
|
def wrap_resolve(self, parent_resolver):
|
||||||
return partial(
|
return partial(
|
||||||
self.id_resolver,
|
self.id_resolver,
|
||||||
parent_resolver,
|
parent_resolver,
|
||||||
|
@ -60,7 +60,7 @@ class NodeField(Field):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_resolver(self, parent_resolver):
|
def wrap_resolve(self, parent_resolver):
|
||||||
return partial(self.node_type.node_resolver, get_type(self.field_type))
|
return partial(self.node_type.node_resolver, get_type(self.field_type))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required():
|
||||||
def test_global_id_defaults_to_info_parent_type():
|
def test_global_id_defaults_to_info_parent_type():
|
||||||
my_id = "1"
|
my_id = "1"
|
||||||
gid = GlobalID()
|
gid = GlobalID()
|
||||||
id_resolver = gid.get_resolver(lambda *_: my_id)
|
id_resolver = gid.wrap_resolve(lambda *_: my_id)
|
||||||
my_global_id = id_resolver(None, Info(User))
|
my_global_id = id_resolver(None, Info(User))
|
||||||
assert my_global_id == to_global_id(User._meta.name, my_id)
|
assert my_global_id == to_global_id(User._meta.name, my_id)
|
||||||
|
|
||||||
|
@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type():
|
||||||
def test_global_id_allows_setting_customer_parent_type():
|
def test_global_id_allows_setting_customer_parent_type():
|
||||||
my_id = "1"
|
my_id = "1"
|
||||||
gid = GlobalID(parent_type=User)
|
gid = GlobalID(parent_type=User)
|
||||||
id_resolver = gid.get_resolver(lambda *_: my_id)
|
id_resolver = gid.wrap_resolve(lambda *_: my_id)
|
||||||
my_global_id = id_resolver(None, None)
|
my_global_id = id_resolver(None, None)
|
||||||
assert my_global_id == to_global_id(User._meta.name, my_id)
|
assert my_global_id == to_global_id(User._meta.name, my_id)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from .resolver import default_resolver
|
||||||
from .structures import NonNull
|
from .structures import NonNull
|
||||||
from .unmountedtype import UnmountedType
|
from .unmountedtype import UnmountedType
|
||||||
from .utils import get_type
|
from .utils import get_type
|
||||||
|
from ..utils.deprecated import warn_deprecation
|
||||||
|
|
||||||
base_type = type
|
base_type = type
|
||||||
|
|
||||||
|
@ -114,5 +115,24 @@ class Field(MountedType):
|
||||||
def type(self):
|
def type(self):
|
||||||
return get_type(self._type)
|
return get_type(self._type)
|
||||||
|
|
||||||
def get_resolver(self, parent_resolver):
|
get_resolver = None
|
||||||
|
|
||||||
|
def wrap_resolve(self, parent_resolver):
|
||||||
|
"""
|
||||||
|
Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
|
||||||
|
(parent_resolver) if the Field definition has no resolver.
|
||||||
|
"""
|
||||||
|
if self.get_resolver is not None:
|
||||||
|
warn_deprecation(
|
||||||
|
"The get_resolver method is being deprecated, please rename it to wrap_resolve."
|
||||||
|
)
|
||||||
|
return self.get_resolver(parent_resolver)
|
||||||
|
|
||||||
return self.resolver or parent_resolver
|
return self.resolver or parent_resolver
|
||||||
|
|
||||||
|
def wrap_subscribe(self, parent_subscribe):
|
||||||
|
"""
|
||||||
|
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
|
||||||
|
(parent_subscribe) if the Field definition has no subscribe.
|
||||||
|
"""
|
||||||
|
return parent_subscribe
|
||||||
|
|
|
@ -10,8 +10,11 @@ from graphql import (
|
||||||
parse,
|
parse,
|
||||||
print_schema,
|
print_schema,
|
||||||
subscribe,
|
subscribe,
|
||||||
|
validate,
|
||||||
|
ExecutionResult,
|
||||||
GraphQLArgument,
|
GraphQLArgument,
|
||||||
GraphQLBoolean,
|
GraphQLBoolean,
|
||||||
|
GraphQLError,
|
||||||
GraphQLEnumValue,
|
GraphQLEnumValue,
|
||||||
GraphQLField,
|
GraphQLField,
|
||||||
GraphQLFloat,
|
GraphQLFloat,
|
||||||
|
@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
|
||||||
return isinstance(root, possible_types)
|
return isinstance(root, possible_types)
|
||||||
|
|
||||||
|
|
||||||
|
# We use this resolver for subscriptions
|
||||||
|
def identity_resolve(root, info):
|
||||||
|
return root
|
||||||
|
|
||||||
|
|
||||||
class TypeMap(dict):
|
class TypeMap(dict):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -307,22 +315,39 @@ class TypeMap(dict):
|
||||||
if isinstance(arg.type, NonNull)
|
if isinstance(arg.type, NonNull)
|
||||||
else arg.default_value,
|
else arg.default_value,
|
||||||
)
|
)
|
||||||
|
subscribe = field.wrap_subscribe(
|
||||||
|
self.get_function_for_type(
|
||||||
|
graphene_type, f"subscribe_{name}", name, field.default_value,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we are in a subscription, we use (by default) an
|
||||||
|
# identity-based resolver for the root, rather than the
|
||||||
|
# default resolver for objects/dicts.
|
||||||
|
if subscribe:
|
||||||
|
field_default_resolver = identity_resolve
|
||||||
|
elif issubclass(graphene_type, ObjectType):
|
||||||
|
default_resolver = (
|
||||||
|
graphene_type._meta.default_resolver or get_default_resolver()
|
||||||
|
)
|
||||||
|
field_default_resolver = partial(
|
||||||
|
default_resolver, name, field.default_value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
field_default_resolver = None
|
||||||
|
|
||||||
|
resolve = field.wrap_resolve(
|
||||||
|
self.get_function_for_type(
|
||||||
|
graphene_type, f"resolve_{name}", name, field.default_value
|
||||||
|
)
|
||||||
|
or field_default_resolver
|
||||||
|
)
|
||||||
|
|
||||||
_field = GraphQLField(
|
_field = GraphQLField(
|
||||||
field_type,
|
field_type,
|
||||||
args=args,
|
args=args,
|
||||||
resolve=field.get_resolver(
|
resolve=resolve,
|
||||||
self.get_resolver_for_type(
|
subscribe=subscribe,
|
||||||
graphene_type, f"resolve_{name}", name, field.default_value
|
|
||||||
)
|
|
||||||
),
|
|
||||||
subscribe=field.get_resolver(
|
|
||||||
self.get_resolver_for_type(
|
|
||||||
graphene_type,
|
|
||||||
f"subscribe_{name}",
|
|
||||||
name,
|
|
||||||
field.default_value,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
deprecation_reason=field.deprecation_reason,
|
deprecation_reason=field.deprecation_reason,
|
||||||
description=field.description,
|
description=field.description,
|
||||||
)
|
)
|
||||||
|
@ -330,7 +355,8 @@ class TypeMap(dict):
|
||||||
fields[field_name] = _field
|
fields[field_name] = _field
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
|
def get_function_for_type(self, graphene_type, func_name, name, default_value):
|
||||||
|
"""Gets a resolve or subscribe function for a given ObjectType"""
|
||||||
if not issubclass(graphene_type, ObjectType):
|
if not issubclass(graphene_type, ObjectType):
|
||||||
return
|
return
|
||||||
resolver = getattr(graphene_type, func_name, None)
|
resolver = getattr(graphene_type, func_name, None)
|
||||||
|
@ -350,11 +376,6 @@ class TypeMap(dict):
|
||||||
if resolver:
|
if resolver:
|
||||||
return get_unbound_function(resolver)
|
return get_unbound_function(resolver)
|
||||||
|
|
||||||
default_resolver = (
|
|
||||||
graphene_type._meta.default_resolver or get_default_resolver()
|
|
||||||
)
|
|
||||||
return partial(default_resolver, name, default_value)
|
|
||||||
|
|
||||||
def resolve_type(self, resolve_type_func, type_name, root, info, _type):
|
def resolve_type(self, resolve_type_func, type_name, root, info, _type):
|
||||||
type_ = resolve_type_func(root, info)
|
type_ = resolve_type_func(root, info)
|
||||||
|
|
||||||
|
@ -476,7 +497,19 @@ class Schema:
|
||||||
return await graphql(self.graphql_schema, *args, **kwargs)
|
return await graphql(self.graphql_schema, *args, **kwargs)
|
||||||
|
|
||||||
async def subscribe(self, query, *args, **kwargs):
|
async def subscribe(self, query, *args, **kwargs):
|
||||||
document = parse(query)
|
"""Execute a GraphQL subscription on the schema asynchronously."""
|
||||||
|
# Do parsing
|
||||||
|
try:
|
||||||
|
document = parse(query)
|
||||||
|
except GraphQLError as error:
|
||||||
|
return ExecutionResult(data=None, errors=[error])
|
||||||
|
|
||||||
|
# Do validation
|
||||||
|
validation_errors = validate(self.graphql_schema, document)
|
||||||
|
if validation_errors:
|
||||||
|
return ExecutionResult(data=None, errors=validation_errors)
|
||||||
|
|
||||||
|
# Execute the query
|
||||||
kwargs = normalize_execute_kwargs(kwargs)
|
kwargs = normalize_execute_kwargs(kwargs)
|
||||||
return await subscribe(self.graphql_schema, document, *args, **kwargs)
|
return await subscribe(self.graphql_schema, document, *args, **kwargs)
|
||||||
|
|
||||||
|
|
56
graphene/types/tests/test_subscribe_async.py
Normal file
56
graphene/types/tests/test_subscribe_async.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
from pytest import mark
|
||||||
|
|
||||||
|
from graphene import ObjectType, Int, String, Schema, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Query(ObjectType):
|
||||||
|
hello = String()
|
||||||
|
|
||||||
|
def resolve_hello(root, info):
|
||||||
|
return "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(ObjectType):
|
||||||
|
count_to_ten = Field(Int)
|
||||||
|
|
||||||
|
async def subscribe_count_to_ten(root, info):
|
||||||
|
count = 0
|
||||||
|
while count < 10:
|
||||||
|
count += 1
|
||||||
|
yield count
|
||||||
|
|
||||||
|
|
||||||
|
schema = Schema(query=Query, subscription=Subscription)
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_subscription():
|
||||||
|
subscription = "subscription { countToTen }"
|
||||||
|
result = await schema.subscribe(subscription)
|
||||||
|
count = 0
|
||||||
|
async for item in result:
|
||||||
|
count = item.data["countToTen"]
|
||||||
|
assert count == 10
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_subscription_fails_with_invalid_query():
|
||||||
|
# It fails if the provided query is invalid
|
||||||
|
subscription = "subscription { "
|
||||||
|
result = await schema.subscribe(subscription)
|
||||||
|
assert not result.data
|
||||||
|
assert result.errors
|
||||||
|
assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
|
async def test_subscription_fails_when_query_is_not_valid():
|
||||||
|
# It can't subscribe to two fields at the same time, triggering a
|
||||||
|
# validation error.
|
||||||
|
subscription = "subscription { countToTen, b: countToTen }"
|
||||||
|
result = await schema.subscribe(subscription)
|
||||||
|
assert not result.data
|
||||||
|
assert result.errors
|
||||||
|
assert "Anonymous Subscription must select only one top level field." in str(
|
||||||
|
result.errors[0]
|
||||||
|
)
|
|
@ -1,33 +0,0 @@
|
||||||
from pytest import mark
|
|
||||||
|
|
||||||
from graphene import ObjectType, Int, String, Schema, Field
|
|
||||||
|
|
||||||
|
|
||||||
class Query(ObjectType):
|
|
||||||
hello = String()
|
|
||||||
|
|
||||||
def resolve_hello(root, info):
|
|
||||||
return "Hello, world!"
|
|
||||||
|
|
||||||
|
|
||||||
class Subscription(ObjectType):
|
|
||||||
count_to_ten = Field(Int)
|
|
||||||
|
|
||||||
async def subscribe_count_to_ten(root, info):
|
|
||||||
count = 0
|
|
||||||
while count < 10:
|
|
||||||
count += 1
|
|
||||||
yield {"count_to_ten": count}
|
|
||||||
|
|
||||||
|
|
||||||
schema = Schema(query=Query, subscription=Subscription)
|
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
|
||||||
async def test_subscription():
|
|
||||||
subscription = "subscription { countToTen }"
|
|
||||||
result = await schema.subscribe(subscription)
|
|
||||||
count = 0
|
|
||||||
async for item in result:
|
|
||||||
count = item.data["countToTen"]
|
|
||||||
assert count == 10
|
|
2
tox.ini
2
tox.ini
|
@ -8,7 +8,7 @@ deps =
|
||||||
setenv =
|
setenv =
|
||||||
PYTHONPATH = .:{envdir}
|
PYTHONPATH = .:{envdir}
|
||||||
commands =
|
commands =
|
||||||
py{36,37}: pytest --cov=graphene graphene examples tests_asyncio {posargs}
|
py{36,37}: pytest --cov=graphene graphene examples {posargs}
|
||||||
|
|
||||||
[testenv:pre-commit]
|
[testenv:pre-commit]
|
||||||
basepython=python3.7
|
basepython=python3.7
|
||||||
|
|
Loading…
Reference in New Issue
Block a user