diff --git a/Makefile b/Makefile index df3b4118..c78e2b4f 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ install-dev: pip install -e ".[dev]" test: - py.test graphene examples tests_asyncio + py.test graphene examples .PHONY: docs ## Generate docs docs: install-dev @@ -20,8 +20,8 @@ docs-live: install-dev .PHONY: format format: - black graphene examples setup.py tests_asyncio + black graphene examples setup.py .PHONY: lint lint: - flake8 graphene examples setup.py tests_asyncio + flake8 graphene examples setup.py diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index cfb6cb63..1a4684e5 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -171,8 +171,8 @@ class IterableConnectionField(Field): on_resolve = partial(cls.resolve_connection, connection_type, args) return maybe_thenable(resolved, on_resolve) - def get_resolver(self, parent_resolver): - resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) + def wrap_resolve(self, parent_resolver): + resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver) return partial(self.connection_resolver, resolver, self.type) diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 13fb8cea..b189bc97 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -37,7 +37,7 @@ class GlobalID(Field): parent_type_name = parent_type_name or info.parent_type.name return node.to_global_id(parent_type_name, type_id) # root._meta.name - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial( self.id_resolver, parent_resolver, @@ -60,7 +60,7 @@ class NodeField(Field): **kwargs, ) - def get_resolver(self, parent_resolver): + def wrap_resolve(self, parent_resolver): return partial(self.node_type.node_resolver, get_type(self.field_type)) diff --git a/tests_asyncio/test_relay_connection.py b/graphene/relay/tests/test_connection_async.py similarity index 100% rename from tests_asyncio/test_relay_connection.py rename to graphene/relay/tests/test_connection_async.py diff --git a/graphene/relay/tests/test_global_id.py b/graphene/relay/tests/test_global_id.py index 2fe81300..81860d9d 100644 --- a/graphene/relay/tests/test_global_id.py +++ b/graphene/relay/tests/test_global_id.py @@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required(): def test_global_id_defaults_to_info_parent_type(): my_id = "1" gid = GlobalID() - id_resolver = gid.get_resolver(lambda *_: my_id) + id_resolver = gid.wrap_resolve(lambda *_: my_id) my_global_id = id_resolver(None, Info(User)) assert my_global_id == to_global_id(User._meta.name, my_id) @@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type(): def test_global_id_allows_setting_customer_parent_type(): my_id = "1" gid = GlobalID(parent_type=User) - id_resolver = gid.get_resolver(lambda *_: my_id) + id_resolver = gid.wrap_resolve(lambda *_: my_id) my_global_id = id_resolver(None, None) assert my_global_id == to_global_id(User._meta.name, my_id) diff --git a/tests_asyncio/test_relay_mutation.py b/graphene/relay/tests/test_mutation_async.py similarity index 100% rename from tests_asyncio/test_relay_mutation.py rename to graphene/relay/tests/test_mutation_async.py diff --git a/graphene/types/field.py b/graphene/types/field.py index 1a1ccf93..dafb04b5 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -8,6 +8,7 @@ from .resolver import default_resolver from .structures import NonNull from .unmountedtype import UnmountedType from .utils import get_type +from ..utils.deprecated import warn_deprecation base_type = type @@ -114,5 +115,24 @@ class Field(MountedType): def type(self): return get_type(self._type) - def get_resolver(self, parent_resolver): + get_resolver = None + + def wrap_resolve(self, parent_resolver): + """ + Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME} + (parent_resolver) if the Field definition has no resolver. + """ + if self.get_resolver is not None: + warn_deprecation( + "The get_resolver method is being deprecated, please rename it to wrap_resolve." + ) + return self.get_resolver(parent_resolver) + return self.resolver or parent_resolver + + def wrap_subscribe(self, parent_subscribe): + """ + Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME} + (parent_subscribe) if the Field definition has no subscribe. + """ + return parent_subscribe diff --git a/graphene/types/schema.py b/graphene/types/schema.py index ce0c7439..5eb59e66 100644 --- a/graphene/types/schema.py +++ b/graphene/types/schema.py @@ -10,8 +10,11 @@ from graphql import ( parse, print_schema, subscribe, + validate, + ExecutionResult, GraphQLArgument, GraphQLBoolean, + GraphQLError, GraphQLEnumValue, GraphQLField, GraphQLFloat, @@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info): return isinstance(root, possible_types) +# We use this resolver for subscriptions +def identity_resolve(root, info): + return root + + class TypeMap(dict): def __init__( self, @@ -307,22 +315,39 @@ class TypeMap(dict): if isinstance(arg.type, NonNull) else arg.default_value, ) + subscribe = field.wrap_subscribe( + self.get_function_for_type( + graphene_type, f"subscribe_{name}", name, field.default_value, + ) + ) + + # If we are in a subscription, we use (by default) an + # identity-based resolver for the root, rather than the + # default resolver for objects/dicts. + if subscribe: + field_default_resolver = identity_resolve + elif issubclass(graphene_type, ObjectType): + default_resolver = ( + graphene_type._meta.default_resolver or get_default_resolver() + ) + field_default_resolver = partial( + default_resolver, name, field.default_value + ) + else: + field_default_resolver = None + + resolve = field.wrap_resolve( + self.get_function_for_type( + graphene_type, f"resolve_{name}", name, field.default_value + ) + or field_default_resolver + ) + _field = GraphQLField( field_type, args=args, - resolve=field.get_resolver( - self.get_resolver_for_type( - graphene_type, f"resolve_{name}", name, field.default_value - ) - ), - subscribe=field.get_resolver( - self.get_resolver_for_type( - graphene_type, - f"subscribe_{name}", - name, - field.default_value, - ) - ), + resolve=resolve, + subscribe=subscribe, deprecation_reason=field.deprecation_reason, description=field.description, ) @@ -330,7 +355,8 @@ class TypeMap(dict): fields[field_name] = _field return fields - def get_resolver_for_type(self, graphene_type, func_name, name, default_value): + def get_function_for_type(self, graphene_type, func_name, name, default_value): + """Gets a resolve or subscribe function for a given ObjectType""" if not issubclass(graphene_type, ObjectType): return resolver = getattr(graphene_type, func_name, None) @@ -350,11 +376,6 @@ class TypeMap(dict): if resolver: return get_unbound_function(resolver) - default_resolver = ( - graphene_type._meta.default_resolver or get_default_resolver() - ) - return partial(default_resolver, name, default_value) - def resolve_type(self, resolve_type_func, type_name, root, info, _type): type_ = resolve_type_func(root, info) @@ -476,7 +497,19 @@ class Schema: return await graphql(self.graphql_schema, *args, **kwargs) async def subscribe(self, query, *args, **kwargs): - document = parse(query) + """Execute a GraphQL subscription on the schema asynchronously.""" + # Do parsing + try: + document = parse(query) + except GraphQLError as error: + return ExecutionResult(data=None, errors=[error]) + + # Do validation + validation_errors = validate(self.graphql_schema, document) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + # Execute the query kwargs = normalize_execute_kwargs(kwargs) return await subscribe(self.graphql_schema, document, *args, **kwargs) diff --git a/graphene/types/tests/test_subscribe_async.py b/graphene/types/tests/test_subscribe_async.py new file mode 100644 index 00000000..6f7ce4c6 --- /dev/null +++ b/graphene/types/tests/test_subscribe_async.py @@ -0,0 +1,56 @@ +from pytest import mark + +from graphene import ObjectType, Int, String, Schema, Field + + +class Query(ObjectType): + hello = String() + + def resolve_hello(root, info): + return "Hello, world!" + + +class Subscription(ObjectType): + count_to_ten = Field(Int) + + async def subscribe_count_to_ten(root, info): + count = 0 + while count < 10: + count += 1 + yield count + + +schema = Schema(query=Query, subscription=Subscription) + + +@mark.asyncio +async def test_subscription(): + subscription = "subscription { countToTen }" + result = await schema.subscribe(subscription) + count = 0 + async for item in result: + count = item.data["countToTen"] + assert count == 10 + + +@mark.asyncio +async def test_subscription_fails_with_invalid_query(): + # It fails if the provided query is invalid + subscription = "subscription { " + result = await schema.subscribe(subscription) + assert not result.data + assert result.errors + assert "Syntax Error: Expected Name, found " in str(result.errors[0]) + + +@mark.asyncio +async def test_subscription_fails_when_query_is_not_valid(): + # It can't subscribe to two fields at the same time, triggering a + # validation error. + subscription = "subscription { countToTen, b: countToTen }" + result = await schema.subscribe(subscription) + assert not result.data + assert result.errors + assert "Anonymous Subscription must select only one top level field." in str( + result.errors[0] + ) diff --git a/tests_asyncio/test_subscribe.py b/tests_asyncio/test_subscribe.py deleted file mode 100644 index bf985d58..00000000 --- a/tests_asyncio/test_subscribe.py +++ /dev/null @@ -1,33 +0,0 @@ -from pytest import mark - -from graphene import ObjectType, Int, String, Schema, Field - - -class Query(ObjectType): - hello = String() - - def resolve_hello(root, info): - return "Hello, world!" - - -class Subscription(ObjectType): - count_to_ten = Field(Int) - - async def subscribe_count_to_ten(root, info): - count = 0 - while count < 10: - count += 1 - yield {"count_to_ten": count} - - -schema = Schema(query=Query, subscription=Subscription) - - -@mark.asyncio -async def test_subscription(): - subscription = "subscription { countToTen }" - result = await schema.subscribe(subscription) - count = 0 - async for item in result: - count = item.data["countToTen"] - assert count == 10 diff --git a/tox.ini b/tox.ini index 468f5fbc..b0298fea 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,7 @@ deps = setenv = PYTHONPATH = .:{envdir} commands = - py{36,37}: pytest --cov=graphene graphene examples tests_asyncio {posargs} + py{36,37}: pytest --cov=graphene graphene examples {posargs} [testenv:pre-commit] basepython=python3.7