mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Merge pull request #176 from graphql-python/features/middlewares
Added Middleware
This commit is contained in:
		
						commit
						3a1093af24
					
				| 
						 | 
				
			
			@ -17,6 +17,7 @@ ga = "UA-12613282-7"
 | 
			
		|||
    "/docs/basic-types/",
 | 
			
		||||
    "/docs/enums/",
 | 
			
		||||
    "/docs/relay/",
 | 
			
		||||
    "/docs/middleware/",
 | 
			
		||||
  ]
 | 
			
		||||
 | 
			
		||||
[docs.django]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,14 +13,19 @@ For that, you will need to add the plugin in your graphene schema.
 | 
			
		|||
 | 
			
		||||
## Installation
 | 
			
		||||
 | 
			
		||||
For use the Django Debug plugin in Graphene, just import `DjangoDebugPlugin` and add it to the `plugins` argument when you initiate the `Schema`.
 | 
			
		||||
For use the Django Debug plugin in Graphene:
 | 
			
		||||
* Import `DjangoDebugMiddleware` and add it to the `middleware` argument when you initiate the `Schema`.
 | 
			
		||||
* Add the `debug` field into the schema root `Query` with the value `graphene.Field(DjangoDebug, name='__debug')`.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from graphene.contrib.django.debug import DjangoDebugPlugin
 | 
			
		||||
from graphene.contrib.django.debug import DjangoDebugMiddleware, DjangoDebug
 | 
			
		||||
 | 
			
		||||
# ...
 | 
			
		||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
 | 
			
		||||
class Query(graphene.ObjectType):
 | 
			
		||||
    # ...
 | 
			
		||||
    debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
 | 
			
		||||
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This plugin, will add another field in the `Query` named `__debug`.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										43
									
								
								docs/pages/docs/middleware.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								docs/pages/docs/middleware.md
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,43 @@
 | 
			
		|||
---
 | 
			
		||||
title: Middleware
 | 
			
		||||
description: Walkthrough Middleware
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Middleware
 | 
			
		||||
 | 
			
		||||
You can use _middleware_ to affect the evaluation of fields in your schema.
 | 
			
		||||
 | 
			
		||||
A middleware is any object that responds to `resolve(*args, next_middleware)`. Inside that method, it should either:
 | 
			
		||||
 | 
			
		||||
* Send `resolve` to the next middleware to continue the evaluation; or
 | 
			
		||||
* Return a value to end the evaluation early.
 | 
			
		||||
 | 
			
		||||
Middlewares' `resolve` is invoked with several arguments:
 | 
			
		||||
 | 
			
		||||
* `next` represents the execution chain. Call `next` to continue evalution.
 | 
			
		||||
* `root` is the root value object passed throughout the query
 | 
			
		||||
* `args` is the hash of arguments passed to the field
 | 
			
		||||
* `context` is the context object passed throughout the query
 | 
			
		||||
* `info` is the resolver info
 | 
			
		||||
 | 
			
		||||
Add a middleware to a schema by adding to the `middlewares` list.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### Example: Authorization
 | 
			
		||||
 | 
			
		||||
This middleware only continues evaluation if the `field_name` is not `'user'`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
class AuthorizationMiddleware(object):
 | 
			
		||||
 | 
			
		||||
    def resolve(self, next, root, args, context, info):
 | 
			
		||||
        if info.field_name == 'user':
 | 
			
		||||
            return None
 | 
			
		||||
        return next(root, args, context, info)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then, add the middleware to your schema:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
schema = Schema(middlewares=[AuthorizationMiddleware])
 | 
			
		||||
```
 | 
			
		||||
| 
						 | 
				
			
			@ -1,4 +1,4 @@
 | 
			
		|||
from .plugin import DjangoDebugPlugin
 | 
			
		||||
from .middleware import DjangoDebugMiddleware
 | 
			
		||||
from .types import DjangoDebug
 | 
			
		||||
 | 
			
		||||
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']
 | 
			
		||||
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										56
									
								
								graphene/contrib/django/debug/middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								graphene/contrib/django/debug/middleware.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,56 @@
 | 
			
		|||
from promise import Promise
 | 
			
		||||
from django.db import connections
 | 
			
		||||
 | 
			
		||||
from .sql.tracking import unwrap_cursor, wrap_cursor
 | 
			
		||||
from .types import DjangoDebug
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugContext(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.debug_promise = None
 | 
			
		||||
        self.promises = []
 | 
			
		||||
        self.enable_instrumentation()
 | 
			
		||||
        self.object = DjangoDebug(sql=[])
 | 
			
		||||
 | 
			
		||||
    def get_debug_promise(self):
 | 
			
		||||
        if not self.debug_promise:
 | 
			
		||||
            self.debug_promise = Promise.all(self.promises)
 | 
			
		||||
        return self.debug_promise.then(self.on_resolve_all_promises)
 | 
			
		||||
 | 
			
		||||
    def on_resolve_all_promises(self, values):
 | 
			
		||||
        self.disable_instrumentation()
 | 
			
		||||
        return self.object
 | 
			
		||||
 | 
			
		||||
    def add_promise(self, promise):
 | 
			
		||||
        if self.debug_promise and not self.debug_promise.is_fulfilled:
 | 
			
		||||
            self.promises.append(promise)
 | 
			
		||||
 | 
			
		||||
    def enable_instrumentation(self):
 | 
			
		||||
        # This is thread-safe because database connections are thread-local.
 | 
			
		||||
        for connection in connections.all():
 | 
			
		||||
            wrap_cursor(connection, self)
 | 
			
		||||
 | 
			
		||||
    def disable_instrumentation(self):
 | 
			
		||||
        for connection in connections.all():
 | 
			
		||||
            unwrap_cursor(connection)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugMiddleware(object):
 | 
			
		||||
 | 
			
		||||
    def resolve(self, next, root, args, context, info):
 | 
			
		||||
        django_debug = getattr(context, 'django_debug', None)
 | 
			
		||||
        if not django_debug:
 | 
			
		||||
            if context is None:
 | 
			
		||||
                raise Exception('DjangoDebug cannot be executed in None contexts')
 | 
			
		||||
            try:
 | 
			
		||||
                context.django_debug = DjangoDebugContext()
 | 
			
		||||
            except Exception:
 | 
			
		||||
                raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
 | 
			
		||||
                    context.__class__.__name__
 | 
			
		||||
                ))
 | 
			
		||||
        if info.schema.graphene_schema.T(DjangoDebug) == info.return_type:
 | 
			
		||||
            return context.django_debug.get_debug_promise()
 | 
			
		||||
        promise = next(root, args, context, info)
 | 
			
		||||
        context.django_debug.add_promise(promise)
 | 
			
		||||
        return promise
 | 
			
		||||
| 
						 | 
				
			
			@ -1,79 +0,0 @@
 | 
			
		|||
from contextlib import contextmanager
 | 
			
		||||
 | 
			
		||||
from django.db import connections
 | 
			
		||||
 | 
			
		||||
from ....core.schema import GraphQLSchema
 | 
			
		||||
from ....core.types import Field
 | 
			
		||||
from ....plugins import Plugin
 | 
			
		||||
from .sql.tracking import unwrap_cursor, wrap_cursor
 | 
			
		||||
from .sql.types import DjangoDebugSQL
 | 
			
		||||
from .types import DjangoDebug
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WrappedRoot(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, root):
 | 
			
		||||
        self._recorded = []
 | 
			
		||||
        self._root = root
 | 
			
		||||
 | 
			
		||||
    def record(self, **log):
 | 
			
		||||
        self._recorded.append(DjangoDebugSQL(**log))
 | 
			
		||||
 | 
			
		||||
    def debug(self):
 | 
			
		||||
        return DjangoDebug(sql=self._recorded)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WrapRoot(object):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def _root(self):
 | 
			
		||||
        return self._wrapped_root.root
 | 
			
		||||
 | 
			
		||||
    @_root.setter
 | 
			
		||||
    def _root(self, value):
 | 
			
		||||
        self._wrapped_root = value
 | 
			
		||||
 | 
			
		||||
    def resolve_debug(self, args, info):
 | 
			
		||||
        return self._wrapped_root.debug()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def debug_objecttype(objecttype):
 | 
			
		||||
    return type(
 | 
			
		||||
        'Debug{}'.format(objecttype._meta.type_name),
 | 
			
		||||
        (WrapRoot, objecttype),
 | 
			
		||||
        {'debug': Field(DjangoDebug, name='__debug')})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugPlugin(Plugin):
 | 
			
		||||
 | 
			
		||||
    def enable_instrumentation(self, wrapped_root):
 | 
			
		||||
        # This is thread-safe because database connections are thread-local.
 | 
			
		||||
        for connection in connections.all():
 | 
			
		||||
            wrap_cursor(connection, wrapped_root)
 | 
			
		||||
 | 
			
		||||
    def disable_instrumentation(self):
 | 
			
		||||
        for connection in connections.all():
 | 
			
		||||
            unwrap_cursor(connection)
 | 
			
		||||
 | 
			
		||||
    def wrap_schema(self, schema_type):
 | 
			
		||||
        query = schema_type._query
 | 
			
		||||
        if query:
 | 
			
		||||
            class_type = self.schema.objecttype(schema_type.get_query_type())
 | 
			
		||||
            assert class_type, 'The query in schema is not constructed with graphene'
 | 
			
		||||
            _type = debug_objecttype(class_type)
 | 
			
		||||
            self.schema.register(_type, force=True)
 | 
			
		||||
            return GraphQLSchema(
 | 
			
		||||
                self.schema,
 | 
			
		||||
                self.schema.T(_type),
 | 
			
		||||
                schema_type.get_mutation_type(),
 | 
			
		||||
                schema_type.get_subscription_type()
 | 
			
		||||
            )
 | 
			
		||||
        return schema_type
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def context_execution(self, executor):
 | 
			
		||||
        executor['root_value'] = WrappedRoot(root=executor.get('root_value'))
 | 
			
		||||
        executor['schema'] = self.wrap_schema(executor['schema'])
 | 
			
		||||
        self.enable_instrumentation(executor['root_value'])
 | 
			
		||||
        yield executor
 | 
			
		||||
        self.disable_instrumentation()
 | 
			
		||||
| 
						 | 
				
			
			@ -8,6 +8,8 @@ from time import time
 | 
			
		|||
from django.utils import six
 | 
			
		||||
from django.utils.encoding import force_text
 | 
			
		||||
 | 
			
		||||
from .types import DjangoDebugSQL, DjangoDebugPostgreSQL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SQLQueryTriggered(Exception):
 | 
			
		||||
    """Thrown when template panel triggers a query"""
 | 
			
		||||
| 
						 | 
				
			
			@ -139,9 +141,11 @@ class NormalCursorWrapper(object):
 | 
			
		|||
                    'iso_level': iso_level,
 | 
			
		||||
                    'encoding': conn.encoding,
 | 
			
		||||
                })
 | 
			
		||||
 | 
			
		||||
                _sql = DjangoDebugPostgreSQL(**params)
 | 
			
		||||
            else:
 | 
			
		||||
                _sql = DjangoDebugSQL(**params)
 | 
			
		||||
            # We keep `sql` to maintain backwards compatibility
 | 
			
		||||
            self.logger.record(**params)
 | 
			
		||||
            self.logger.object.sql.append(_sql)
 | 
			
		||||
 | 
			
		||||
    def callproc(self, procname, params=()):
 | 
			
		||||
        return self._record(self.cursor.callproc, procname, params)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
from .....core import Boolean, Float, ObjectType, String
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugSQL(ObjectType):
 | 
			
		||||
class DjangoDebugBaseSQL(ObjectType):
 | 
			
		||||
    vendor = String()
 | 
			
		||||
    alias = String()
 | 
			
		||||
    sql = String()
 | 
			
		||||
| 
						 | 
				
			
			@ -13,6 +13,12 @@ class DjangoDebugSQL(ObjectType):
 | 
			
		|||
    is_slow = Boolean()
 | 
			
		||||
    is_select = Boolean()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugSQL(DjangoDebugBaseSQL):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebugPostgreSQL(DjangoDebugBaseSQL):
 | 
			
		||||
    trans_id = String()
 | 
			
		||||
    trans_status = String()
 | 
			
		||||
    iso_level = String()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,12 @@ from graphene.contrib.django import DjangoConnectionField, DjangoNode
 | 
			
		|||
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
 | 
			
		||||
 | 
			
		||||
from ...tests.models import Reporter
 | 
			
		||||
from ..plugin import DjangoDebugPlugin
 | 
			
		||||
from ..middleware import DjangoDebugMiddleware
 | 
			
		||||
from ..types import DjangoDebug
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class context(object):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
# from examples.starwars_django.models import Character
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -25,6 +30,7 @@ def test_should_query_field():
 | 
			
		|||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        reporter = graphene.Field(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
 | 
			
		||||
        def resolve_reporter(self, *args, **kwargs):
 | 
			
		||||
            return Reporter.objects.first()
 | 
			
		||||
| 
						 | 
				
			
			@ -51,8 +57,8 @@ def test_should_query_field():
 | 
			
		|||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(query, context_value=context())
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -70,6 +76,7 @@ def test_should_query_list():
 | 
			
		|||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = ReporterType.List()
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, *args, **kwargs):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
| 
						 | 
				
			
			@ -98,8 +105,8 @@ def test_should_query_list():
 | 
			
		|||
            }]
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(query, context_value=context())
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -117,6 +124,7 @@ def test_should_query_connection():
 | 
			
		|||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoConnectionField(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, *args, **kwargs):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
| 
						 | 
				
			
			@ -146,8 +154,8 @@ def test_should_query_connection():
 | 
			
		|||
            }]
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(query, context_value=context())
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data['allReporters'] == expected['allReporters']
 | 
			
		||||
    assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
 | 
			
		||||
| 
						 | 
				
			
			@ -172,6 +180,7 @@ def test_should_query_connectionfilter():
 | 
			
		|||
 | 
			
		||||
    class Query(graphene.ObjectType):
 | 
			
		||||
        all_reporters = DjangoFilterConnectionField(ReporterType)
 | 
			
		||||
        debug = graphene.Field(DjangoDebug, name='__debug')
 | 
			
		||||
 | 
			
		||||
        def resolve_all_reporters(self, *args, **kwargs):
 | 
			
		||||
            return Reporter.objects.all()
 | 
			
		||||
| 
						 | 
				
			
			@ -201,8 +210,8 @@ def test_should_query_connectionfilter():
 | 
			
		|||
            }]
 | 
			
		||||
        },
 | 
			
		||||
    }
 | 
			
		||||
    schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
 | 
			
		||||
    result = schema.execute(query, context_value=context())
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data['allReporters'] == expected['allReporters']
 | 
			
		||||
    assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
from ....core.classtypes.objecttype import ObjectType
 | 
			
		||||
from ....core.types import Field
 | 
			
		||||
from .sql.types import DjangoDebugSQL
 | 
			
		||||
from .sql.types import DjangoDebugBaseSQL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DjangoDebug(ObjectType):
 | 
			
		||||
    sql = Field(DjangoDebugSQL.List())
 | 
			
		||||
    sql = Field(DjangoDebugBaseSQL.List())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,7 +7,7 @@ from graphql.utils.schema_printer import print_schema
 | 
			
		|||
 | 
			
		||||
from graphene import signals
 | 
			
		||||
 | 
			
		||||
from ..plugins import CamelCase, PluginManager
 | 
			
		||||
from ..middlewares import MiddlewareManager, CamelCaseArgsMiddleware
 | 
			
		||||
from .classtypes.base import ClassType
 | 
			
		||||
from .types.base import InstanceType
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -23,7 +23,7 @@ class Schema(object):
 | 
			
		|||
    _executor = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, query=None, mutation=None, subscription=None,
 | 
			
		||||
                 name='Schema', executor=None, plugins=None, auto_camelcase=True, **options):
 | 
			
		||||
                 name='Schema', executor=None, middlewares=None, auto_camelcase=True, **options):
 | 
			
		||||
        self._types_names = {}
 | 
			
		||||
        self._types = {}
 | 
			
		||||
        self.mutation = mutation
 | 
			
		||||
| 
						 | 
				
			
			@ -31,21 +31,19 @@ class Schema(object):
 | 
			
		|||
        self.subscription = subscription
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.executor = executor
 | 
			
		||||
        plugins = plugins or []
 | 
			
		||||
        if 'plugins' in options:
 | 
			
		||||
            raise Exception('Plugins are deprecated, please use middlewares.')
 | 
			
		||||
        middlewares = middlewares or []
 | 
			
		||||
        if auto_camelcase:
 | 
			
		||||
            plugins.append(CamelCase())
 | 
			
		||||
        self.plugins = PluginManager(self, plugins)
 | 
			
		||||
            middlewares.append(CamelCaseArgsMiddleware())
 | 
			
		||||
        self.auto_camelcase = auto_camelcase
 | 
			
		||||
        self.middleware_manager = MiddlewareManager(self, middlewares)
 | 
			
		||||
        self.options = options
 | 
			
		||||
        signals.init_schema.send(self)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return '<Schema: %s (%s)>' % (str(self.name), hash(self))
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, name):
 | 
			
		||||
        if name in self.plugins:
 | 
			
		||||
            return getattr(self.plugins, name)
 | 
			
		||||
        return super(Schema, self).__getattr__(name)
 | 
			
		||||
 | 
			
		||||
    def T(self, _type):
 | 
			
		||||
        if not _type:
 | 
			
		||||
            return
 | 
			
		||||
| 
						 | 
				
			
			@ -111,13 +109,16 @@ class Schema(object):
 | 
			
		|||
            raise KeyError('Type %r not found in %r' % (type_name, self))
 | 
			
		||||
        return self._types_names[type_name]
 | 
			
		||||
 | 
			
		||||
    def resolver_with_middleware(self, resolver):
 | 
			
		||||
        return self.middleware_manager.wrap(resolver)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def types(self):
 | 
			
		||||
        return self._types_names
 | 
			
		||||
 | 
			
		||||
    def execute(self, request_string='', root_value=None, variable_values=None,
 | 
			
		||||
                context_value=None, operation_name=None, executor=None):
 | 
			
		||||
        kwargs = dict(
 | 
			
		||||
        return graphql(
 | 
			
		||||
            schema=self.schema,
 | 
			
		||||
            request_string=request_string,
 | 
			
		||||
            root_value=root_value,
 | 
			
		||||
| 
						 | 
				
			
			@ -126,8 +127,6 @@ class Schema(object):
 | 
			
		|||
            operation_name=operation_name,
 | 
			
		||||
            executor=executor or self._executor
 | 
			
		||||
        )
 | 
			
		||||
        with self.plugins.context_execution(**kwargs) as execute_kwargs:
 | 
			
		||||
            return graphql(**execute_kwargs)
 | 
			
		||||
 | 
			
		||||
    def introspect(self):
 | 
			
		||||
        return graphql(self.schema, introspection_query).data
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -93,7 +93,7 @@ def test_field_resolve():
 | 
			
		|||
    f = StringField(required=True, resolve=lambda *args: 'RESOLVED').as_field()
 | 
			
		||||
    f.contribute_to_class(MyOt, 'field_name')
 | 
			
		||||
    field_type = schema.T(f)
 | 
			
		||||
    assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None)
 | 
			
		||||
    assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None).value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_resolve_type_custom():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -154,6 +154,12 @@ def test_lazytype():
 | 
			
		|||
    assert schema.T(t) == schema.T(MyType)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_deprecated_plugins_throws_exception():
 | 
			
		||||
    with raises(Exception) as excinfo:
 | 
			
		||||
        Schema(plugins=[])
 | 
			
		||||
    assert 'Plugins are deprecated, please use middlewares' in str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_schema_str():
 | 
			
		||||
    expected = """
 | 
			
		||||
schema {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,9 +1,7 @@
 | 
			
		|||
from functools import wraps
 | 
			
		||||
from itertools import chain
 | 
			
		||||
 | 
			
		||||
from graphql.type import GraphQLArgument
 | 
			
		||||
 | 
			
		||||
from ...utils import ProxySnakeDict
 | 
			
		||||
from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -53,11 +51,3 @@ def to_arguments(*args, **kwargs):
 | 
			
		|||
        arguments[name] = argument
 | 
			
		||||
 | 
			
		||||
    return sorted(arguments.values())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def snake_case_args(resolver):
 | 
			
		||||
    @wraps(resolver)
 | 
			
		||||
    def wrapped_resolver(instance, args, context, info):
 | 
			
		||||
        return resolver(instance, ProxySnakeDict(args), context, info)
 | 
			
		||||
 | 
			
		||||
    return wrapped_resolver
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,8 @@ from functools import partial, total_ordering
 | 
			
		|||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from ...utils import to_camel_case
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InstanceType(object):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -142,7 +144,9 @@ class GroupNamedType(InstanceType):
 | 
			
		|||
        self.types = types
 | 
			
		||||
 | 
			
		||||
    def get_named_type(self, schema, type):
 | 
			
		||||
        name = type.name or schema.get_default_namedtype_name(type.default_name)
 | 
			
		||||
        name = type.name
 | 
			
		||||
        if not name and schema.auto_camelcase:
 | 
			
		||||
            name = to_camel_case(type.default_name)
 | 
			
		||||
        return name, schema.T(type)
 | 
			
		||||
 | 
			
		||||
    def iter_types(self, schema):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,7 @@ from ..classtypes.base import FieldsClassType
 | 
			
		|||
from ..classtypes.inputobjecttype import InputObjectType
 | 
			
		||||
from ..classtypes.mutation import Mutation
 | 
			
		||||
from ..exceptions import SkipField
 | 
			
		||||
from .argument import Argument, ArgumentsGroup, snake_case_args
 | 
			
		||||
from .argument import Argument, ArgumentsGroup
 | 
			
		||||
from .base import (ArgumentType, GroupNamedType, LazyType, MountType,
 | 
			
		||||
                   NamedType, OrderedType)
 | 
			
		||||
from .definitions import NonNull
 | 
			
		||||
| 
						 | 
				
			
			@ -89,9 +89,6 @@ class Field(NamedType, OrderedType):
 | 
			
		|||
            return NonNull(self.type)
 | 
			
		||||
        return self.type
 | 
			
		||||
 | 
			
		||||
    def decorate_resolver(self, resolver):
 | 
			
		||||
        return snake_case_args(resolver)
 | 
			
		||||
 | 
			
		||||
    def internal_type(self, schema):
 | 
			
		||||
        if not self.object_type:
 | 
			
		||||
            raise Exception('The field is not mounted in any ClassType')
 | 
			
		||||
| 
						 | 
				
			
			@ -118,10 +115,13 @@ class Field(NamedType, OrderedType):
 | 
			
		|||
            resolver = wrapped_func
 | 
			
		||||
 | 
			
		||||
        assert type, 'Internal type for field %s is None' % str(self)
 | 
			
		||||
        return GraphQLField(type, args=schema.T(arguments),
 | 
			
		||||
                            resolver=self.decorate_resolver(resolver),
 | 
			
		||||
        return GraphQLField(
 | 
			
		||||
            type,
 | 
			
		||||
            args=schema.T(arguments),
 | 
			
		||||
            resolver=schema.resolver_with_middleware(resolver),
 | 
			
		||||
            deprecation_reason=self.deprecation_reason,
 | 
			
		||||
                            description=description,)
 | 
			
		||||
            description=description,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -175,7 +175,8 @@ class InputField(NamedType, OrderedType):
 | 
			
		|||
    def internal_type(self, schema):
 | 
			
		||||
        return GraphQLInputObjectField(
 | 
			
		||||
            schema.T(self.type),
 | 
			
		||||
            default_value=self.default, description=self.description)
 | 
			
		||||
            default_value=self.default, description=self.description
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FieldsGroupType(GroupNamedType):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,7 +4,7 @@ from pytest import raises
 | 
			
		|||
from graphene.core.schema import Schema
 | 
			
		||||
from graphene.core.types import ObjectType
 | 
			
		||||
 | 
			
		||||
from ..argument import Argument, snake_case_args, to_arguments
 | 
			
		||||
from ..argument import Argument, to_arguments
 | 
			
		||||
from ..scalars import String
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -45,10 +45,3 @@ def test_to_arguments_wrong_type():
 | 
			
		|||
            p=3
 | 
			
		||||
        )
 | 
			
		||||
    assert 'Unknown argument p=3' == str(excinfo.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_snake_case_args():
 | 
			
		||||
    def resolver(instance, args, context, info):
 | 
			
		||||
        return args['my_arg']['inner_arg']
 | 
			
		||||
    r = snake_case_args(resolver)
 | 
			
		||||
    assert r(None, {'myArg': {'innerArg': 3}}, None, None) == 3
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,7 +24,7 @@ def test_field_internal_type():
 | 
			
		|||
    assert field.attname == 'my_field'
 | 
			
		||||
    assert isinstance(type, GraphQLField)
 | 
			
		||||
    assert type.description == 'My argument'
 | 
			
		||||
    assert type.resolver(None, {}, None, None) == 'RESOLVED'
 | 
			
		||||
    assert type.resolver(None, {}, None, None).value == 'RESOLVED'
 | 
			
		||||
    assert type.type == GraphQLString
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -43,7 +43,7 @@ def test_field_objectype_resolver():
 | 
			
		|||
    type = schema.T(field)
 | 
			
		||||
    assert isinstance(type, GraphQLField)
 | 
			
		||||
    assert type.description == 'Custom description'
 | 
			
		||||
    assert type.resolver(Query(), {}, None, None) == 'RESOLVED'
 | 
			
		||||
    assert type.resolver(Query(), {}, None, None).value == 'RESOLVED'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_custom_name():
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +161,7 @@ def test_field_resolve_argument():
 | 
			
		|||
    schema = Schema(query=Query)
 | 
			
		||||
 | 
			
		||||
    type = schema.T(field)
 | 
			
		||||
    assert type.resolver(None, {'firstName': 'Peter'}, None, None) == 'Peter'
 | 
			
		||||
    assert type.resolver(None, {'firstName': 'Peter'}, None, None).value == 'Peter'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_resolve_vars():
 | 
			
		||||
| 
						 | 
				
			
			@ -216,7 +216,6 @@ def test_field_resolve_object():
 | 
			
		|||
        att_func = field_func
 | 
			
		||||
 | 
			
		||||
    assert field.resolver(Root, {}, None) is True
 | 
			
		||||
    assert field.resolver(Root, {}, None) is True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_field_resolve_source_object():
 | 
			
		||||
| 
						 | 
				
			
			@ -235,4 +234,3 @@ def test_field_resolve_source_object():
 | 
			
		|||
        att_func = field_func
 | 
			
		||||
 | 
			
		||||
    assert field.resolver(Root, {}, None) is True
 | 
			
		||||
    assert field.resolver(Root, {}, None) is True
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										6
									
								
								graphene/middlewares/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								graphene/middlewares/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,6 @@
 | 
			
		|||
from .base import MiddlewareManager
 | 
			
		||||
from .camel_case import CamelCaseArgsMiddleware
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'MiddlewareManager', 'CamelCaseArgsMiddleware'
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										23
									
								
								graphene/middlewares/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								graphene/middlewares/base.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
from ..utils import promise_middleware
 | 
			
		||||
 | 
			
		||||
MIDDLEWARE_RESOLVER_FUNCTION = 'resolve'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MiddlewareManager(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, schema, middlewares=None):
 | 
			
		||||
        self.schema = schema
 | 
			
		||||
        self.middlewares = middlewares or []
 | 
			
		||||
 | 
			
		||||
    def add_middleware(self, middleware):
 | 
			
		||||
        self.middlewares.append(middleware)
 | 
			
		||||
 | 
			
		||||
    def get_middleware_resolvers(self):
 | 
			
		||||
        for middleware in self.middlewares:
 | 
			
		||||
            if not hasattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION):
 | 
			
		||||
                continue
 | 
			
		||||
            yield getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION)
 | 
			
		||||
 | 
			
		||||
    def wrap(self, resolver):
 | 
			
		||||
        middleware_resolvers = self.get_middleware_resolvers()
 | 
			
		||||
        return promise_middleware(resolver, middleware_resolvers)
 | 
			
		||||
							
								
								
									
										8
									
								
								graphene/middlewares/camel_case.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								graphene/middlewares/camel_case.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
from ..utils import ProxySnakeDict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CamelCaseArgsMiddleware(object):
 | 
			
		||||
 | 
			
		||||
    def resolve(self, next, root, args, context, info):
 | 
			
		||||
        args = ProxySnakeDict(args)
 | 
			
		||||
        return next(root, args, context, info)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,6 +0,0 @@
 | 
			
		|||
from .base import Plugin, PluginManager
 | 
			
		||||
from .camel_case import CamelCase
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'Plugin', 'PluginManager', 'CamelCase'
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			@ -1,53 +0,0 @@
 | 
			
		|||
from contextlib import contextmanager
 | 
			
		||||
from functools import partial, reduce
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Plugin(object):
 | 
			
		||||
 | 
			
		||||
    def contribute_to_schema(self, schema):
 | 
			
		||||
        self.schema = schema
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_function(a, b):
 | 
			
		||||
    return b(a)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PluginManager(object):
 | 
			
		||||
 | 
			
		||||
    PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
 | 
			
		||||
 | 
			
		||||
    def __init__(self, schema, plugins=[]):
 | 
			
		||||
        self.schema = schema
 | 
			
		||||
        self.plugins = []
 | 
			
		||||
        for plugin in plugins:
 | 
			
		||||
            self.add_plugin(plugin)
 | 
			
		||||
 | 
			
		||||
    def add_plugin(self, plugin):
 | 
			
		||||
        if hasattr(plugin, 'contribute_to_schema'):
 | 
			
		||||
            plugin.contribute_to_schema(self.schema)
 | 
			
		||||
        self.plugins.append(plugin)
 | 
			
		||||
 | 
			
		||||
    def get_plugin_functions(self, function):
 | 
			
		||||
        for plugin in self.plugins:
 | 
			
		||||
            if not hasattr(plugin, function):
 | 
			
		||||
                continue
 | 
			
		||||
            yield getattr(plugin, function)
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, name):
 | 
			
		||||
        functions = self.get_plugin_functions(name)
 | 
			
		||||
        return partial(reduce, apply_function, functions)
 | 
			
		||||
 | 
			
		||||
    def __contains__(self, name):
 | 
			
		||||
        return name in self.PLUGIN_FUNCTIONS
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def context_execution(self, **executor):
 | 
			
		||||
        contexts = []
 | 
			
		||||
        functions = self.get_plugin_functions('context_execution')
 | 
			
		||||
        for f in functions:
 | 
			
		||||
            context = f(executor)
 | 
			
		||||
            executor = context.__enter__()
 | 
			
		||||
            contexts.append((context, executor))
 | 
			
		||||
        yield executor
 | 
			
		||||
        for context, value in contexts[::-1]:
 | 
			
		||||
            context.__exit__(None, None, None)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,7 +0,0 @@
 | 
			
		|||
from ..utils import to_camel_case
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CamelCase(object):
 | 
			
		||||
 | 
			
		||||
    def get_default_namedtype_name(self, value):
 | 
			
		||||
        return to_camel_case(value)
 | 
			
		||||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from .proxy_snake_dict import ProxySnakeDict
 | 
			
		|||
from .caching import cached_property, memoize
 | 
			
		||||
from .maybe_func import maybe_func
 | 
			
		||||
from .misc import enum_to_graphql_enum
 | 
			
		||||
from .promise_middleware import promise_middleware
 | 
			
		||||
from .resolve_only_args import resolve_only_args
 | 
			
		||||
from .lazylist import LazyList
 | 
			
		||||
from .wrap_resolver_function import with_context, wrap_resolver_function
 | 
			
		||||
| 
						 | 
				
			
			@ -10,5 +11,5 @@ from .wrap_resolver_function import with_context, wrap_resolver_function
 | 
			
		|||
 | 
			
		||||
__all__ = ['to_camel_case', 'to_snake_case', 'to_const', 'ProxySnakeDict',
 | 
			
		||||
           'cached_property', 'memoize', 'maybe_func', 'enum_to_graphql_enum',
 | 
			
		||||
           'resolve_only_args', 'LazyList', 'with_context',
 | 
			
		||||
           'promise_middleware', 'resolve_only_args', 'LazyList', 'with_context',
 | 
			
		||||
           'wrap_resolver_function']
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										17
									
								
								graphene/utils/promise_middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								graphene/utils/promise_middleware.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,17 @@
 | 
			
		|||
from functools import partial
 | 
			
		||||
from itertools import chain
 | 
			
		||||
 | 
			
		||||
from promise import Promise
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def promise_middleware(func, middlewares):
 | 
			
		||||
    middlewares = chain((func, make_it_promise), middlewares)
 | 
			
		||||
    past = None
 | 
			
		||||
    for m in middlewares:
 | 
			
		||||
        past = partial(m, past) if past else m
 | 
			
		||||
 | 
			
		||||
    return past
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_it_promise(next, *a, **b):
 | 
			
		||||
    return Promise.resolve(next(*a, **b))
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user