mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Subscription revamp (#1235)
* Integrate async tests into main code * Added full support for subscriptions * Fixed syntax using black * Fixed typo
This commit is contained in:
		
							parent
							
								
									2130005406
								
							
						
					
					
						commit
						d085c8852b
					
				
							
								
								
									
										6
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								Makefile
									
									
									
									
									
								
							| 
						 | 
				
			
			@ -8,7 +8,7 @@ install-dev:
 | 
			
		|||
	pip install -e ".[dev]"
 | 
			
		||||
 | 
			
		||||
test:
 | 
			
		||||
	py.test graphene examples tests_asyncio
 | 
			
		||||
	py.test graphene examples
 | 
			
		||||
 | 
			
		||||
.PHONY: docs ## Generate docs
 | 
			
		||||
docs: install-dev
 | 
			
		||||
| 
						 | 
				
			
			@ -20,8 +20,8 @@ docs-live: install-dev
 | 
			
		|||
 | 
			
		||||
.PHONY: format
 | 
			
		||||
format:
 | 
			
		||||
	black graphene examples setup.py tests_asyncio
 | 
			
		||||
	black graphene examples setup.py
 | 
			
		||||
 | 
			
		||||
.PHONY: lint
 | 
			
		||||
lint:
 | 
			
		||||
	flake8 graphene examples setup.py tests_asyncio
 | 
			
		||||
	flake8 graphene examples setup.py
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -171,8 +171,8 @@ class IterableConnectionField(Field):
 | 
			
		|||
        on_resolve = partial(cls.resolve_connection, connection_type, args)
 | 
			
		||||
        return maybe_thenable(resolved, on_resolve)
 | 
			
		||||
 | 
			
		||||
    def get_resolver(self, parent_resolver):
 | 
			
		||||
        resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
 | 
			
		||||
    def wrap_resolve(self, parent_resolver):
 | 
			
		||||
        resolver = super(IterableConnectionField, self).wrap_resolve(parent_resolver)
 | 
			
		||||
        return partial(self.connection_resolver, resolver, self.type)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,7 +37,7 @@ class GlobalID(Field):
 | 
			
		|||
        parent_type_name = parent_type_name or info.parent_type.name
 | 
			
		||||
        return node.to_global_id(parent_type_name, type_id)  # root._meta.name
 | 
			
		||||
 | 
			
		||||
    def get_resolver(self, parent_resolver):
 | 
			
		||||
    def wrap_resolve(self, parent_resolver):
 | 
			
		||||
        return partial(
 | 
			
		||||
            self.id_resolver,
 | 
			
		||||
            parent_resolver,
 | 
			
		||||
| 
						 | 
				
			
			@ -60,7 +60,7 @@ class NodeField(Field):
 | 
			
		|||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_resolver(self, parent_resolver):
 | 
			
		||||
    def wrap_resolve(self, parent_resolver):
 | 
			
		||||
        return partial(self.node_type.node_resolver, get_type(self.field_type))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,7 +45,7 @@ def test_global_id_allows_overriding_of_node_and_required():
 | 
			
		|||
def test_global_id_defaults_to_info_parent_type():
 | 
			
		||||
    my_id = "1"
 | 
			
		||||
    gid = GlobalID()
 | 
			
		||||
    id_resolver = gid.get_resolver(lambda *_: my_id)
 | 
			
		||||
    id_resolver = gid.wrap_resolve(lambda *_: my_id)
 | 
			
		||||
    my_global_id = id_resolver(None, Info(User))
 | 
			
		||||
    assert my_global_id == to_global_id(User._meta.name, my_id)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -53,6 +53,6 @@ def test_global_id_defaults_to_info_parent_type():
 | 
			
		|||
def test_global_id_allows_setting_customer_parent_type():
 | 
			
		||||
    my_id = "1"
 | 
			
		||||
    gid = GlobalID(parent_type=User)
 | 
			
		||||
    id_resolver = gid.get_resolver(lambda *_: my_id)
 | 
			
		||||
    id_resolver = gid.wrap_resolve(lambda *_: my_id)
 | 
			
		||||
    my_global_id = id_resolver(None, None)
 | 
			
		||||
    assert my_global_id == to_global_id(User._meta.name, my_id)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,6 +8,7 @@ from .resolver import default_resolver
 | 
			
		|||
from .structures import NonNull
 | 
			
		||||
from .unmountedtype import UnmountedType
 | 
			
		||||
from .utils import get_type
 | 
			
		||||
from ..utils.deprecated import warn_deprecation
 | 
			
		||||
 | 
			
		||||
base_type = type
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -114,5 +115,24 @@ class Field(MountedType):
 | 
			
		|||
    def type(self):
 | 
			
		||||
        return get_type(self._type)
 | 
			
		||||
 | 
			
		||||
    def get_resolver(self, parent_resolver):
 | 
			
		||||
    get_resolver = None
 | 
			
		||||
 | 
			
		||||
    def wrap_resolve(self, parent_resolver):
 | 
			
		||||
        """
 | 
			
		||||
        Wraps a function resolver, using the ObjectType resolve_{FIELD_NAME}
 | 
			
		||||
        (parent_resolver) if the Field definition has no resolver.
 | 
			
		||||
        """
 | 
			
		||||
        if self.get_resolver is not None:
 | 
			
		||||
            warn_deprecation(
 | 
			
		||||
                "The get_resolver method is being deprecated, please rename it to wrap_resolve."
 | 
			
		||||
            )
 | 
			
		||||
            return self.get_resolver(parent_resolver)
 | 
			
		||||
 | 
			
		||||
        return self.resolver or parent_resolver
 | 
			
		||||
 | 
			
		||||
    def wrap_subscribe(self, parent_subscribe):
 | 
			
		||||
        """
 | 
			
		||||
        Wraps a function subscribe, using the ObjectType subscribe_{FIELD_NAME}
 | 
			
		||||
        (parent_subscribe) if the Field definition has no subscribe.
 | 
			
		||||
        """
 | 
			
		||||
        return parent_subscribe
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,8 +10,11 @@ from graphql import (
 | 
			
		|||
    parse,
 | 
			
		||||
    print_schema,
 | 
			
		||||
    subscribe,
 | 
			
		||||
    validate,
 | 
			
		||||
    ExecutionResult,
 | 
			
		||||
    GraphQLArgument,
 | 
			
		||||
    GraphQLBoolean,
 | 
			
		||||
    GraphQLError,
 | 
			
		||||
    GraphQLEnumValue,
 | 
			
		||||
    GraphQLField,
 | 
			
		||||
    GraphQLFloat,
 | 
			
		||||
| 
						 | 
				
			
			@ -76,6 +79,11 @@ def is_type_of_from_possible_types(possible_types, root, _info):
 | 
			
		|||
    return isinstance(root, possible_types)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We use this resolver for subscriptions
 | 
			
		||||
def identity_resolve(root, info):
 | 
			
		||||
    return root
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TypeMap(dict):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			@ -307,22 +315,39 @@ class TypeMap(dict):
 | 
			
		|||
                        if isinstance(arg.type, NonNull)
 | 
			
		||||
                        else arg.default_value,
 | 
			
		||||
                    )
 | 
			
		||||
                subscribe = field.wrap_subscribe(
 | 
			
		||||
                    self.get_function_for_type(
 | 
			
		||||
                        graphene_type, f"subscribe_{name}", name, field.default_value,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # If we are in a subscription, we use (by default) an
 | 
			
		||||
                # identity-based resolver for the root, rather than the
 | 
			
		||||
                # default resolver for objects/dicts.
 | 
			
		||||
                if subscribe:
 | 
			
		||||
                    field_default_resolver = identity_resolve
 | 
			
		||||
                elif issubclass(graphene_type, ObjectType):
 | 
			
		||||
                    default_resolver = (
 | 
			
		||||
                        graphene_type._meta.default_resolver or get_default_resolver()
 | 
			
		||||
                    )
 | 
			
		||||
                    field_default_resolver = partial(
 | 
			
		||||
                        default_resolver, name, field.default_value
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    field_default_resolver = None
 | 
			
		||||
 | 
			
		||||
                resolve = field.wrap_resolve(
 | 
			
		||||
                    self.get_function_for_type(
 | 
			
		||||
                        graphene_type, f"resolve_{name}", name, field.default_value
 | 
			
		||||
                    )
 | 
			
		||||
                    or field_default_resolver
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                _field = GraphQLField(
 | 
			
		||||
                    field_type,
 | 
			
		||||
                    args=args,
 | 
			
		||||
                    resolve=field.get_resolver(
 | 
			
		||||
                        self.get_resolver_for_type(
 | 
			
		||||
                            graphene_type, f"resolve_{name}", name, field.default_value
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    subscribe=field.get_resolver(
 | 
			
		||||
                        self.get_resolver_for_type(
 | 
			
		||||
                            graphene_type,
 | 
			
		||||
                            f"subscribe_{name}",
 | 
			
		||||
                            name,
 | 
			
		||||
                            field.default_value,
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    resolve=resolve,
 | 
			
		||||
                    subscribe=subscribe,
 | 
			
		||||
                    deprecation_reason=field.deprecation_reason,
 | 
			
		||||
                    description=field.description,
 | 
			
		||||
                )
 | 
			
		||||
| 
						 | 
				
			
			@ -330,7 +355,8 @@ class TypeMap(dict):
 | 
			
		|||
            fields[field_name] = _field
 | 
			
		||||
        return fields
 | 
			
		||||
 | 
			
		||||
    def get_resolver_for_type(self, graphene_type, func_name, name, default_value):
 | 
			
		||||
    def get_function_for_type(self, graphene_type, func_name, name, default_value):
 | 
			
		||||
        """Gets a resolve or subscribe function for a given ObjectType"""
 | 
			
		||||
        if not issubclass(graphene_type, ObjectType):
 | 
			
		||||
            return
 | 
			
		||||
        resolver = getattr(graphene_type, func_name, None)
 | 
			
		||||
| 
						 | 
				
			
			@ -350,11 +376,6 @@ class TypeMap(dict):
 | 
			
		|||
        if resolver:
 | 
			
		||||
            return get_unbound_function(resolver)
 | 
			
		||||
 | 
			
		||||
        default_resolver = (
 | 
			
		||||
            graphene_type._meta.default_resolver or get_default_resolver()
 | 
			
		||||
        )
 | 
			
		||||
        return partial(default_resolver, name, default_value)
 | 
			
		||||
 | 
			
		||||
    def resolve_type(self, resolve_type_func, type_name, root, info, _type):
 | 
			
		||||
        type_ = resolve_type_func(root, info)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -476,7 +497,19 @@ class Schema:
 | 
			
		|||
        return await graphql(self.graphql_schema, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    async def subscribe(self, query, *args, **kwargs):
 | 
			
		||||
        """Execute a GraphQL subscription on the schema asynchronously."""
 | 
			
		||||
        # Do parsing
 | 
			
		||||
        try:
 | 
			
		||||
            document = parse(query)
 | 
			
		||||
        except GraphQLError as error:
 | 
			
		||||
            return ExecutionResult(data=None, errors=[error])
 | 
			
		||||
 | 
			
		||||
        # Do validation
 | 
			
		||||
        validation_errors = validate(self.graphql_schema, document)
 | 
			
		||||
        if validation_errors:
 | 
			
		||||
            return ExecutionResult(data=None, errors=validation_errors)
 | 
			
		||||
 | 
			
		||||
        # Execute the query
 | 
			
		||||
        kwargs = normalize_execute_kwargs(kwargs)
 | 
			
		||||
        return await subscribe(self.graphql_schema, document, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										56
									
								
								graphene/types/tests/test_subscribe_async.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								graphene/types/tests/test_subscribe_async.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,56 @@
 | 
			
		|||
from pytest import mark
 | 
			
		||||
 | 
			
		||||
from graphene import ObjectType, Int, String, Schema, Field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Query(ObjectType):
 | 
			
		||||
    hello = String()
 | 
			
		||||
 | 
			
		||||
    def resolve_hello(root, info):
 | 
			
		||||
        return "Hello, world!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Subscription(ObjectType):
 | 
			
		||||
    count_to_ten = Field(Int)
 | 
			
		||||
 | 
			
		||||
    async def subscribe_count_to_ten(root, info):
 | 
			
		||||
        count = 0
 | 
			
		||||
        while count < 10:
 | 
			
		||||
            count += 1
 | 
			
		||||
            yield count
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
schema = Schema(query=Query, subscription=Subscription)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.asyncio
 | 
			
		||||
async def test_subscription():
 | 
			
		||||
    subscription = "subscription { countToTen }"
 | 
			
		||||
    result = await schema.subscribe(subscription)
 | 
			
		||||
    count = 0
 | 
			
		||||
    async for item in result:
 | 
			
		||||
        count = item.data["countToTen"]
 | 
			
		||||
    assert count == 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.asyncio
 | 
			
		||||
async def test_subscription_fails_with_invalid_query():
 | 
			
		||||
    # It fails if the provided query is invalid
 | 
			
		||||
    subscription = "subscription { "
 | 
			
		||||
    result = await schema.subscribe(subscription)
 | 
			
		||||
    assert not result.data
 | 
			
		||||
    assert result.errors
 | 
			
		||||
    assert "Syntax Error: Expected Name, found <EOF>" in str(result.errors[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.asyncio
 | 
			
		||||
async def test_subscription_fails_when_query_is_not_valid():
 | 
			
		||||
    # It can't subscribe to two fields at the same time, triggering a
 | 
			
		||||
    # validation error.
 | 
			
		||||
    subscription = "subscription { countToTen, b: countToTen }"
 | 
			
		||||
    result = await schema.subscribe(subscription)
 | 
			
		||||
    assert not result.data
 | 
			
		||||
    assert result.errors
 | 
			
		||||
    assert "Anonymous Subscription must select only one top level field." in str(
 | 
			
		||||
        result.errors[0]
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -1,33 +0,0 @@
 | 
			
		|||
from pytest import mark
 | 
			
		||||
 | 
			
		||||
from graphene import ObjectType, Int, String, Schema, Field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Query(ObjectType):
 | 
			
		||||
    hello = String()
 | 
			
		||||
 | 
			
		||||
    def resolve_hello(root, info):
 | 
			
		||||
        return "Hello, world!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Subscription(ObjectType):
 | 
			
		||||
    count_to_ten = Field(Int)
 | 
			
		||||
 | 
			
		||||
    async def subscribe_count_to_ten(root, info):
 | 
			
		||||
        count = 0
 | 
			
		||||
        while count < 10:
 | 
			
		||||
            count += 1
 | 
			
		||||
            yield {"count_to_ten": count}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
schema = Schema(query=Query, subscription=Subscription)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.asyncio
 | 
			
		||||
async def test_subscription():
 | 
			
		||||
    subscription = "subscription { countToTen }"
 | 
			
		||||
    result = await schema.subscribe(subscription)
 | 
			
		||||
    count = 0
 | 
			
		||||
    async for item in result:
 | 
			
		||||
        count = item.data["countToTen"]
 | 
			
		||||
    assert count == 10
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user