Subscription revamp (#1235)

* Integrate async tests into main code

* Added full support for subscriptions

* Fixed syntax using black

* Fixed typo
This commit is contained in:
Syrus Akbary 2020-07-28 13:33:21 -07:00 committed by GitHub
parent 2130005406
commit d085c8852b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 140 additions and 64 deletions

View File

@ -8,7 +8,7 @@ install-dev:
pip install -e ".[dev]"
test:
py.test graphene examples tests_asyncio
py.test graphene examples
.PHONY: docs ## Generate docs
docs: install-dev
@ -20,8 +20,8 @@ docs-live: install-dev
.PHONY: format
format:
black graphene examples setup.py tests_asyncio
black graphene examples setup.py
.PHONY: lint
lint:
flake8 graphene examples setup.py tests_asyncio
flake8 graphene examples setup.py

View File

@ -171,8 +171,8 @@ class IterableConnectionField(Field):
on_resolve = partial(cls.resolve_connection, connection_type, args)
return maybe_thenable(resolved, on_resolve)
def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
def wrap_resolve(self, parent_resolver):
resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
return partial(self.connection_resolver, resolver, self.type)

View File

@ -37,7 +37,7 @@ class GlobalID(Field):
parent_type_name = parent_type_name or info.parent_type.name
return node.to_global_id(parent_type_name, type_id) # root._meta.name
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(
self.id_resolver,
parent_resolver,
@ -60,7 +60,7 @@ class NodeField(Field):
**kwargs,
)
def get_resolver(self, parent_resolver):
def wrap_resolve(self, parent_resolver):
return partial(self.node_type.node_resolver, get_type(self.field_type))

View File

@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required():
def test_global_id_defaults_to_info_parent_type():
my_id = "1"
gid = GlobalID()
id_resolver = gid.get_resolver(lambda *_: my_id)
id_resolver = gid.wrap_resolve(lambda *_: my_id)
my_global_id = id_resolver(None, Info(User))
assert my_global_id == to_global_id(User._meta.name, my_id)
@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type():
def test_global_id_allows_setting_customer_parent_type():
my_id = "1"
gid = GlobalID(parent_type=User)
id_resolver = gid.get_resolver(lambda *_: my_id)
id_resolver = gid.wrap_resolve(lambda *_: my_id)
my_global_id = id_resolver(None, None)
assert my_global_id == to_global_id(User._meta.name, my_id)

View File

@ -8,6 +8,7 @@ from .resolver import default_resolver
from .structures import NonNull
from .unmountedtype import UnmountedType
from .utils import get_type
from ..utils.deprecated import warn_deprecation
base_type = type
@ -114,5 +115,24 @@ class Field(MountedType):
def type(self):
return get_type(self._type)
def get_resolver(self, parent_resolver):
get_resolver = None
def wrap_resolve(self, parent_resolver):
"""
Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
(parent_resolver) if the Field definition has no resolver.
"""
if self.get_resolver is not None:
warn_deprecation(
"The get_resolver method is being deprecated, please rename it to wrap_resolve."
)
return self.get_resolver(parent_resolver)
return self.resolver or parent_resolver
def wrap_subscribe(self, parent_subscribe):
"""
Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
(parent_subscribe) if the Field definition has no subscribe.
"""
return parent_subscribe

View File

@ -10,8 +10,11 @@ from graphql import (
parse,
print_schema,
subscribe,
validate,
ExecutionResult,
GraphQLArgument,
GraphQLBoolean,
GraphQLError,
GraphQLEnumValue,
GraphQLField,
GraphQLFloat,
@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
return isinstance(root, possible_types)
# We use this resolver for subscriptions
def identity_resolve(root, info):
return root
class TypeMap(dict):
def __init__(
self,
@ -307,22 +315,39 @@ class TypeMap(dict):
if isinstance(arg.type, NonNull)
else arg.default_value,
)
subscribe = field.wrap_subscribe(
self.get_function_for_type(
graphene_type, f"subscribe_{name}", name, field.default_value,
)
)
# If we are in a subscription, we use (by default) an
# identity-based resolver for the root, rather than the
# default resolver for objects/dicts.
if subscribe:
field_default_resolver = identity_resolve
elif issubclass(graphene_type, ObjectType):
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
field_default_resolver = partial(
default_resolver, name, field.default_value
)
else:
field_default_resolver = None
resolve = field.wrap_resolve(
self.get_function_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
)
or field_default_resolver
)
_field = GraphQLField(
field_type,
args=args,
resolve=field.get_resolver(
self.get_resolver_for_type(
graphene_type, f"resolve_{name}", name, field.default_value
)
),
subscribe=field.get_resolver(
self.get_resolver_for_type(
graphene_type,
f"subscribe_{name}",
name,
field.default_value,
)
),
resolve=resolve,
subscribe=subscribe,
deprecation_reason=field.deprecation_reason,
description=field.description,
)
@ -330,7 +355,8 @@ class TypeMap(dict):
fields[field_name] = _field
return fields
def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
def get_function_for_type(self, graphene_type, func_name, name, default_value):
"""Gets a resolve or subscribe function for a given ObjectType"""
if not issubclass(graphene_type, ObjectType):
return
resolver = getattr(graphene_type, func_name, None)
@ -350,11 +376,6 @@ class TypeMap(dict):
if resolver:
return get_unbound_function(resolver)
default_resolver = (
graphene_type._meta.default_resolver or get_default_resolver()
)
return partial(default_resolver, name, default_value)
def resolve_type(self, resolve_type_func, type_name, root, info, _type):
type_ = resolve_type_func(root, info)
@ -476,7 +497,19 @@ class Schema:
return await graphql(self.graphql_schema, *args, **kwargs)
async def subscribe(self, query, *args, **kwargs):
"""Execute a GraphQL subscription on the schema asynchronously."""
# Do parsing
try:
document = parse(query)
except GraphQLError as error:
return ExecutionResult(data=None, errors=[error])
# Do validation
validation_errors = validate(self.graphql_schema, document)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)
# Execute the query
kwargs = normalize_execute_kwargs(kwargs)
return await subscribe(self.graphql_schema, document, *args, **kwargs)

View File

@ -0,0 +1,56 @@
from pytest import mark
from graphene import ObjectType, Int, String, Schema, Field
class Query(ObjectType):
hello = String()
def resolve_hello(root, info):
return "Hello, world!"
class Subscription(ObjectType):
count_to_ten = Field(Int)
async def subscribe_count_to_ten(root, info):
count = 0
while count < 10:
count += 1
yield count
schema = Schema(query=Query, subscription=Subscription)
@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToTen"]
assert count == 10
@mark.asyncio
async def test_subscription_fails_with_invalid_query():
# It fails if the provided query is invalid
subscription = "subscription { "
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])
@mark.asyncio
async def test_subscription_fails_when_query_is_not_valid():
# It can't subscribe to two fields at the same time, triggering a
# validation error.
subscription = "subscription { countToTen, b: countToTen }"
result = await schema.subscribe(subscription)
assert not result.data
assert result.errors
assert "Anonymous Subscription must select only one top level field." in str(
result.errors[0]
)

View File

@ -1,33 +0,0 @@
from pytest import mark
from graphene import ObjectType, Int, String, Schema, Field
class Query(ObjectType):
hello = String()
def resolve_hello(root, info):
return "Hello, world!"
class Subscription(ObjectType):
count_to_ten = Field(Int)
async def subscribe_count_to_ten(root, info):
count = 0
while count < 10:
count += 1
yield {"count_to_ten": count}
schema = Schema(query=Query, subscription=Subscription)
@mark.asyncio
async def test_subscription():
subscription = "subscription { countToTen }"
result = await schema.subscribe(subscription)
count = 0
async for item in result:
count = item.data["countToTen"]
assert count == 10

View File

@ -8,7 +8,7 @@ deps =
setenv =
PYTHONPATH = .:{envdir}
commands =
py{36,37}: pytest --cov=graphene graphene examples tests_asyncio {posargs}
py{36,37}: pytest --cov=graphene graphene examples {posargs}
[testenv:pre-commit]
basepython=python3.7