mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Added support for subscription (#1107)
* Added support for subscription * Added pre-commit hooks for black and formatted changed files * Checked with flake8 * Integrated changes from master. Co-authored-by: Rob Blackbourn <rblackbourn@bhdgsystematic.com> Co-authored-by: Rob Blackbourn <rtb@beast.jetblack.net>
This commit is contained in:
		
							parent
							
								
									88f79b2850
								
							
						
					
					
						commit
						1cf303a27b
					
				| 
						 | 
				
			
			@ -17,6 +17,43 @@ For executing a query against a schema, you can directly call the ``execute`` me
 | 
			
		|||
``result`` represents the result of execution. ``result.data`` is the result of executing the query, ``result.errors`` is ``None`` if no errors occurred, and is a non-empty list if an error occurred.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
For executing a subscription, you can directly call the ``subscribe`` method on it.
 | 
			
		||||
This method is async and must be awaited.
 | 
			
		||||
 | 
			
		||||
.. code:: python
 | 
			
		||||
 | 
			
		||||
    import asyncio
 | 
			
		||||
    from datetime import datetime
 | 
			
		||||
    from graphene import ObjectType, String, Schema, Field
 | 
			
		||||
 | 
			
		||||
    # All schema require a query.
 | 
			
		||||
    class Query(ObjectType):
 | 
			
		||||
        hello = String()
 | 
			
		||||
 | 
			
		||||
        def resolve_hello(root, info):
 | 
			
		||||
            return 'Hello, world!'
 | 
			
		||||
 | 
			
		||||
    class Subscription(ObjectType):
 | 
			
		||||
        time_of_day = Field(String)
 | 
			
		||||
 | 
			
		||||
        async def subscribe_time_of_day(root, info):
 | 
			
		||||
            while True:
 | 
			
		||||
                yield { 'time_of_day': datetime.now().isoformat()}
 | 
			
		||||
                await asyncio.sleep(1)
 | 
			
		||||
 | 
			
		||||
    SCHEMA = Schema(query=Query, subscription=Subscription)
 | 
			
		||||
 | 
			
		||||
    async def main(schema):
 | 
			
		||||
 | 
			
		||||
        subscription = 'subscription { timeOfDay }'
 | 
			
		||||
        result = await schema.subscribe(subscription)
 | 
			
		||||
        async for item in result:
 | 
			
		||||
            print(item.data['timeOfDay'])
 | 
			
		||||
 | 
			
		||||
    asyncio.run(main(SCHEMA))
 | 
			
		||||
 | 
			
		||||
The ``result`` is an async iterator which yields items in the same manner as a query.
 | 
			
		||||
 | 
			
		||||
.. _SchemaExecuteContext:
 | 
			
		||||
 | 
			
		||||
Context
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,9 @@ from graphql import (
 | 
			
		|||
    graphql,
 | 
			
		||||
    graphql_sync,
 | 
			
		||||
    introspection_types,
 | 
			
		||||
    parse,
 | 
			
		||||
    print_schema,
 | 
			
		||||
    subscribe,
 | 
			
		||||
    GraphQLArgument,
 | 
			
		||||
    GraphQLBoolean,
 | 
			
		||||
    GraphQLEnumValue,
 | 
			
		||||
| 
						 | 
				
			
			@ -309,13 +311,19 @@ class TypeMap(dict):
 | 
			
		|||
                        if isinstance(arg.type, NonNull)
 | 
			
		||||
                        else arg.default_value,
 | 
			
		||||
                    )
 | 
			
		||||
                resolve = field.get_resolver(
 | 
			
		||||
                    self.get_resolver(graphene_type, name, field.default_value)
 | 
			
		||||
                )
 | 
			
		||||
                _field = GraphQLField(
 | 
			
		||||
                    field_type,
 | 
			
		||||
                    args=args,
 | 
			
		||||
                    resolve=resolve,
 | 
			
		||||
                    resolve=field.get_resolver(
 | 
			
		||||
                        self.get_resolver_for_type(
 | 
			
		||||
                            graphene_type, "resolve_{}", name, field.default_value
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    subscribe=field.get_resolver(
 | 
			
		||||
                        self.get_resolver_for_type(
 | 
			
		||||
                            graphene_type, "subscribe_{}", name, field.default_value
 | 
			
		||||
                        )
 | 
			
		||||
                    ),
 | 
			
		||||
                    deprecation_reason=field.deprecation_reason,
 | 
			
		||||
                    description=field.description,
 | 
			
		||||
                )
 | 
			
		||||
| 
						 | 
				
			
			@ -323,6 +331,32 @@ class TypeMap(dict):
 | 
			
		|||
            fields[field_name] = _field
 | 
			
		||||
        return fields
 | 
			
		||||
 | 
			
		||||
    def get_resolver_for_type(self, graphene_type, pattern, name, default_value):
 | 
			
		||||
        if not issubclass(graphene_type, ObjectType):
 | 
			
		||||
            return
 | 
			
		||||
        func_name = pattern.format(name)
 | 
			
		||||
        resolver = getattr(graphene_type, func_name, None)
 | 
			
		||||
        if not resolver:
 | 
			
		||||
            # If we don't find the resolver in the ObjectType class, then try to
 | 
			
		||||
            # find it in each of the interfaces
 | 
			
		||||
            interface_resolver = None
 | 
			
		||||
            for interface in graphene_type._meta.interfaces:
 | 
			
		||||
                if name not in interface._meta.fields:
 | 
			
		||||
                    continue
 | 
			
		||||
                interface_resolver = getattr(interface, func_name, None)
 | 
			
		||||
                if interface_resolver:
 | 
			
		||||
                    break
 | 
			
		||||
            resolver = interface_resolver
 | 
			
		||||
 | 
			
		||||
        # Only if is not decorated with classmethod
 | 
			
		||||
        if resolver:
 | 
			
		||||
            return get_unbound_function(resolver)
 | 
			
		||||
 | 
			
		||||
        default_resolver = (
 | 
			
		||||
            graphene_type._meta.default_resolver or get_default_resolver()
 | 
			
		||||
        )
 | 
			
		||||
        return partial(default_resolver, name, default_value)
 | 
			
		||||
 | 
			
		||||
    def resolve_type(self, resolve_type_func, type_name, root, info, _type):
 | 
			
		||||
        type_ = resolve_type_func(root, info)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -468,6 +502,11 @@ class Schema:
 | 
			
		|||
        kwargs = normalize_execute_kwargs(kwargs)
 | 
			
		||||
        return await graphql(self.graphql_schema, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    async def subscribe(self, query, *args, **kwargs):
 | 
			
		||||
        document = parse(query)
 | 
			
		||||
        kwargs = normalize_execute_kwargs(kwargs)
 | 
			
		||||
        return await subscribe(self.graphql_schema, document, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def introspect(self):
 | 
			
		||||
        introspection = self.execute(introspection_query)
 | 
			
		||||
        if introspection.errors:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										33
									
								
								tests_asyncio/test_subscribe.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								tests_asyncio/test_subscribe.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,33 @@
 | 
			
		|||
from pytest import mark
 | 
			
		||||
 | 
			
		||||
from graphene import ObjectType, Int, String, Schema, Field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Query(ObjectType):
 | 
			
		||||
    hello = String()
 | 
			
		||||
 | 
			
		||||
    def resolve_hello(root, info):
 | 
			
		||||
        return "Hello, world!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Subscription(ObjectType):
 | 
			
		||||
    count_to_ten = Field(Int)
 | 
			
		||||
 | 
			
		||||
    async def subscribe_count_to_ten(root, info):
 | 
			
		||||
        count = 0
 | 
			
		||||
        while count < 10:
 | 
			
		||||
            count += 1
 | 
			
		||||
            yield {"count_to_ten": count}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
schema = Schema(query=Query, subscription=Subscription)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mark.asyncio
 | 
			
		||||
async def test_subscription():
 | 
			
		||||
    subscription = "subscription { countToTen }"
 | 
			
		||||
    result = await schema.subscribe(subscription)
 | 
			
		||||
    count = 0
 | 
			
		||||
    async for item in result:
 | 
			
		||||
        count = item.data["countToTen"]
 | 
			
		||||
    assert count == 10
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user