mirror of
https://github.com/graphql-python/graphene.git
synced 2025-02-02 20:54:16 +03:00
Merge pull request #176 from graphql-python/features/middlewares
Added Middleware
This commit is contained in:
commit
3a1093af24
|
@ -17,6 +17,7 @@ ga = "UA-12613282-7"
|
||||||
"/docs/basic-types/",
|
"/docs/basic-types/",
|
||||||
"/docs/enums/",
|
"/docs/enums/",
|
||||||
"/docs/relay/",
|
"/docs/relay/",
|
||||||
|
"/docs/middleware/",
|
||||||
]
|
]
|
||||||
|
|
||||||
[docs.django]
|
[docs.django]
|
||||||
|
|
|
@ -13,14 +13,19 @@ For that, you will need to add the plugin in your graphene schema.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
For use the Django Debug plugin in Graphene, just import `DjangoDebugPlugin` and add it to the `plugins` argument when you initiate the `Schema`.
|
For use the Django Debug plugin in Graphene:
|
||||||
|
* Import `DjangoDebugMiddleware` and add it to the `middleware` argument when you initiate the `Schema`.
|
||||||
|
* Add the `debug` field into the schema root `Query` with the value `graphene.Field(DjangoDebug, name='__debug')`.
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from graphene.contrib.django.debug import DjangoDebugPlugin
|
from graphene.contrib.django.debug import DjangoDebugMiddleware, DjangoDebug
|
||||||
|
|
||||||
# ...
|
class Query(graphene.ObjectType):
|
||||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
# ...
|
||||||
|
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||||
|
|
||||||
|
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
|
||||||
```
|
```
|
||||||
|
|
||||||
This plugin, will add another field in the `Query` named `__debug`.
|
This plugin, will add another field in the `Query` named `__debug`.
|
||||||
|
|
43
docs/pages/docs/middleware.md
Normal file
43
docs/pages/docs/middleware.md
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
---
|
||||||
|
title: Middleware
|
||||||
|
description: Walkthrough Middleware
|
||||||
|
---
|
||||||
|
|
||||||
|
# Middleware
|
||||||
|
|
||||||
|
You can use _middleware_ to affect the evaluation of fields in your schema.
|
||||||
|
|
||||||
|
A middleware is any object that responds to `resolve(*args, next_middleware)`. Inside that method, it should either:
|
||||||
|
|
||||||
|
* Send `resolve` to the next middleware to continue the evaluation; or
|
||||||
|
* Return a value to end the evaluation early.
|
||||||
|
|
||||||
|
Middlewares' `resolve` is invoked with several arguments:
|
||||||
|
|
||||||
|
* `next` represents the execution chain. Call `next` to continue evalution.
|
||||||
|
* `root` is the root value object passed throughout the query
|
||||||
|
* `args` is the hash of arguments passed to the field
|
||||||
|
* `context` is the context object passed throughout the query
|
||||||
|
* `info` is the resolver info
|
||||||
|
|
||||||
|
Add a middleware to a schema by adding to the `middlewares` list.
|
||||||
|
|
||||||
|
|
||||||
|
### Example: Authorization
|
||||||
|
|
||||||
|
This middleware only continues evaluation if the `field_name` is not `'user'`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AuthorizationMiddleware(object):
|
||||||
|
|
||||||
|
def resolve(self, next, root, args, context, info):
|
||||||
|
if info.field_name == 'user':
|
||||||
|
return None
|
||||||
|
return next(root, args, context, info)
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, add the middleware to your schema:
|
||||||
|
|
||||||
|
```python
|
||||||
|
schema = Schema(middlewares=[AuthorizationMiddleware])
|
||||||
|
```
|
|
@ -1,4 +1,4 @@
|
||||||
from .plugin import DjangoDebugPlugin
|
from .middleware import DjangoDebugMiddleware
|
||||||
from .types import DjangoDebug
|
from .types import DjangoDebug
|
||||||
|
|
||||||
__all__ = ['DjangoDebugPlugin', 'DjangoDebug']
|
__all__ = ['DjangoDebugMiddleware', 'DjangoDebug']
|
||||||
|
|
56
graphene/contrib/django/debug/middleware.py
Normal file
56
graphene/contrib/django/debug/middleware.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
from promise import Promise
|
||||||
|
from django.db import connections
|
||||||
|
|
||||||
|
from .sql.tracking import unwrap_cursor, wrap_cursor
|
||||||
|
from .types import DjangoDebug
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugContext(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.debug_promise = None
|
||||||
|
self.promises = []
|
||||||
|
self.enable_instrumentation()
|
||||||
|
self.object = DjangoDebug(sql=[])
|
||||||
|
|
||||||
|
def get_debug_promise(self):
|
||||||
|
if not self.debug_promise:
|
||||||
|
self.debug_promise = Promise.all(self.promises)
|
||||||
|
return self.debug_promise.then(self.on_resolve_all_promises)
|
||||||
|
|
||||||
|
def on_resolve_all_promises(self, values):
|
||||||
|
self.disable_instrumentation()
|
||||||
|
return self.object
|
||||||
|
|
||||||
|
def add_promise(self, promise):
|
||||||
|
if self.debug_promise and not self.debug_promise.is_fulfilled:
|
||||||
|
self.promises.append(promise)
|
||||||
|
|
||||||
|
def enable_instrumentation(self):
|
||||||
|
# This is thread-safe because database connections are thread-local.
|
||||||
|
for connection in connections.all():
|
||||||
|
wrap_cursor(connection, self)
|
||||||
|
|
||||||
|
def disable_instrumentation(self):
|
||||||
|
for connection in connections.all():
|
||||||
|
unwrap_cursor(connection)
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugMiddleware(object):
|
||||||
|
|
||||||
|
def resolve(self, next, root, args, context, info):
|
||||||
|
django_debug = getattr(context, 'django_debug', None)
|
||||||
|
if not django_debug:
|
||||||
|
if context is None:
|
||||||
|
raise Exception('DjangoDebug cannot be executed in None contexts')
|
||||||
|
try:
|
||||||
|
context.django_debug = DjangoDebugContext()
|
||||||
|
except Exception:
|
||||||
|
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
|
||||||
|
context.__class__.__name__
|
||||||
|
))
|
||||||
|
if info.schema.graphene_schema.T(DjangoDebug) == info.return_type:
|
||||||
|
return context.django_debug.get_debug_promise()
|
||||||
|
promise = next(root, args, context, info)
|
||||||
|
context.django_debug.add_promise(promise)
|
||||||
|
return promise
|
|
@ -1,79 +0,0 @@
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from django.db import connections
|
|
||||||
|
|
||||||
from ....core.schema import GraphQLSchema
|
|
||||||
from ....core.types import Field
|
|
||||||
from ....plugins import Plugin
|
|
||||||
from .sql.tracking import unwrap_cursor, wrap_cursor
|
|
||||||
from .sql.types import DjangoDebugSQL
|
|
||||||
from .types import DjangoDebug
|
|
||||||
|
|
||||||
|
|
||||||
class WrappedRoot(object):
|
|
||||||
|
|
||||||
def __init__(self, root):
|
|
||||||
self._recorded = []
|
|
||||||
self._root = root
|
|
||||||
|
|
||||||
def record(self, **log):
|
|
||||||
self._recorded.append(DjangoDebugSQL(**log))
|
|
||||||
|
|
||||||
def debug(self):
|
|
||||||
return DjangoDebug(sql=self._recorded)
|
|
||||||
|
|
||||||
|
|
||||||
class WrapRoot(object):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _root(self):
|
|
||||||
return self._wrapped_root.root
|
|
||||||
|
|
||||||
@_root.setter
|
|
||||||
def _root(self, value):
|
|
||||||
self._wrapped_root = value
|
|
||||||
|
|
||||||
def resolve_debug(self, args, info):
|
|
||||||
return self._wrapped_root.debug()
|
|
||||||
|
|
||||||
|
|
||||||
def debug_objecttype(objecttype):
|
|
||||||
return type(
|
|
||||||
'Debug{}'.format(objecttype._meta.type_name),
|
|
||||||
(WrapRoot, objecttype),
|
|
||||||
{'debug': Field(DjangoDebug, name='__debug')})
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoDebugPlugin(Plugin):
|
|
||||||
|
|
||||||
def enable_instrumentation(self, wrapped_root):
|
|
||||||
# This is thread-safe because database connections are thread-local.
|
|
||||||
for connection in connections.all():
|
|
||||||
wrap_cursor(connection, wrapped_root)
|
|
||||||
|
|
||||||
def disable_instrumentation(self):
|
|
||||||
for connection in connections.all():
|
|
||||||
unwrap_cursor(connection)
|
|
||||||
|
|
||||||
def wrap_schema(self, schema_type):
|
|
||||||
query = schema_type._query
|
|
||||||
if query:
|
|
||||||
class_type = self.schema.objecttype(schema_type.get_query_type())
|
|
||||||
assert class_type, 'The query in schema is not constructed with graphene'
|
|
||||||
_type = debug_objecttype(class_type)
|
|
||||||
self.schema.register(_type, force=True)
|
|
||||||
return GraphQLSchema(
|
|
||||||
self.schema,
|
|
||||||
self.schema.T(_type),
|
|
||||||
schema_type.get_mutation_type(),
|
|
||||||
schema_type.get_subscription_type()
|
|
||||||
)
|
|
||||||
return schema_type
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def context_execution(self, executor):
|
|
||||||
executor['root_value'] = WrappedRoot(root=executor.get('root_value'))
|
|
||||||
executor['schema'] = self.wrap_schema(executor['schema'])
|
|
||||||
self.enable_instrumentation(executor['root_value'])
|
|
||||||
yield executor
|
|
||||||
self.disable_instrumentation()
|
|
|
@ -8,6 +8,8 @@ from time import time
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
|
|
||||||
|
from .types import DjangoDebugSQL, DjangoDebugPostgreSQL
|
||||||
|
|
||||||
|
|
||||||
class SQLQueryTriggered(Exception):
|
class SQLQueryTriggered(Exception):
|
||||||
"""Thrown when template panel triggers a query"""
|
"""Thrown when template panel triggers a query"""
|
||||||
|
@ -139,9 +141,11 @@ class NormalCursorWrapper(object):
|
||||||
'iso_level': iso_level,
|
'iso_level': iso_level,
|
||||||
'encoding': conn.encoding,
|
'encoding': conn.encoding,
|
||||||
})
|
})
|
||||||
|
_sql = DjangoDebugPostgreSQL(**params)
|
||||||
|
else:
|
||||||
|
_sql = DjangoDebugSQL(**params)
|
||||||
# We keep `sql` to maintain backwards compatibility
|
# We keep `sql` to maintain backwards compatibility
|
||||||
self.logger.record(**params)
|
self.logger.object.sql.append(_sql)
|
||||||
|
|
||||||
def callproc(self, procname, params=()):
|
def callproc(self, procname, params=()):
|
||||||
return self._record(self.cursor.callproc, procname, params)
|
return self._record(self.cursor.callproc, procname, params)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from .....core import Boolean, Float, ObjectType, String
|
from .....core import Boolean, Float, ObjectType, String
|
||||||
|
|
||||||
|
|
||||||
class DjangoDebugSQL(ObjectType):
|
class DjangoDebugBaseSQL(ObjectType):
|
||||||
vendor = String()
|
vendor = String()
|
||||||
alias = String()
|
alias = String()
|
||||||
sql = String()
|
sql = String()
|
||||||
|
@ -13,6 +13,12 @@ class DjangoDebugSQL(ObjectType):
|
||||||
is_slow = Boolean()
|
is_slow = Boolean()
|
||||||
is_select = Boolean()
|
is_select = Boolean()
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugSQL(DjangoDebugBaseSQL):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoDebugPostgreSQL(DjangoDebugBaseSQL):
|
||||||
trans_id = String()
|
trans_id = String()
|
||||||
trans_status = String()
|
trans_status = String()
|
||||||
iso_level = String()
|
iso_level = String()
|
||||||
|
|
|
@ -5,7 +5,12 @@ from graphene.contrib.django import DjangoConnectionField, DjangoNode
|
||||||
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
|
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
|
||||||
|
|
||||||
from ...tests.models import Reporter
|
from ...tests.models import Reporter
|
||||||
from ..plugin import DjangoDebugPlugin
|
from ..middleware import DjangoDebugMiddleware
|
||||||
|
from ..types import DjangoDebug
|
||||||
|
|
||||||
|
|
||||||
|
class context(object):
|
||||||
|
pass
|
||||||
|
|
||||||
# from examples.starwars_django.models import Character
|
# from examples.starwars_django.models import Character
|
||||||
|
|
||||||
|
@ -25,6 +30,7 @@ def test_should_query_field():
|
||||||
|
|
||||||
class Query(graphene.ObjectType):
|
class Query(graphene.ObjectType):
|
||||||
reporter = graphene.Field(ReporterType)
|
reporter = graphene.Field(ReporterType)
|
||||||
|
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||||
|
|
||||||
def resolve_reporter(self, *args, **kwargs):
|
def resolve_reporter(self, *args, **kwargs):
|
||||||
return Reporter.objects.first()
|
return Reporter.objects.first()
|
||||||
|
@ -51,8 +57,8 @@ def test_should_query_field():
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
|
||||||
result = schema.execute(query)
|
result = schema.execute(query, context_value=context())
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
|
||||||
|
@ -70,6 +76,7 @@ def test_should_query_list():
|
||||||
|
|
||||||
class Query(graphene.ObjectType):
|
class Query(graphene.ObjectType):
|
||||||
all_reporters = ReporterType.List()
|
all_reporters = ReporterType.List()
|
||||||
|
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||||
|
|
||||||
def resolve_all_reporters(self, *args, **kwargs):
|
def resolve_all_reporters(self, *args, **kwargs):
|
||||||
return Reporter.objects.all()
|
return Reporter.objects.all()
|
||||||
|
@ -98,8 +105,8 @@ def test_should_query_list():
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
|
||||||
result = schema.execute(query)
|
result = schema.execute(query, context_value=context())
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
|
||||||
|
@ -117,6 +124,7 @@ def test_should_query_connection():
|
||||||
|
|
||||||
class Query(graphene.ObjectType):
|
class Query(graphene.ObjectType):
|
||||||
all_reporters = DjangoConnectionField(ReporterType)
|
all_reporters = DjangoConnectionField(ReporterType)
|
||||||
|
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||||
|
|
||||||
def resolve_all_reporters(self, *args, **kwargs):
|
def resolve_all_reporters(self, *args, **kwargs):
|
||||||
return Reporter.objects.all()
|
return Reporter.objects.all()
|
||||||
|
@ -146,8 +154,8 @@ def test_should_query_connection():
|
||||||
}]
|
}]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
|
||||||
result = schema.execute(query)
|
result = schema.execute(query, context_value=context())
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data['allReporters'] == expected['allReporters']
|
assert result.data['allReporters'] == expected['allReporters']
|
||||||
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
||||||
|
@ -172,6 +180,7 @@ def test_should_query_connectionfilter():
|
||||||
|
|
||||||
class Query(graphene.ObjectType):
|
class Query(graphene.ObjectType):
|
||||||
all_reporters = DjangoFilterConnectionField(ReporterType)
|
all_reporters = DjangoFilterConnectionField(ReporterType)
|
||||||
|
debug = graphene.Field(DjangoDebug, name='__debug')
|
||||||
|
|
||||||
def resolve_all_reporters(self, *args, **kwargs):
|
def resolve_all_reporters(self, *args, **kwargs):
|
||||||
return Reporter.objects.all()
|
return Reporter.objects.all()
|
||||||
|
@ -201,8 +210,8 @@ def test_should_query_connectionfilter():
|
||||||
}]
|
}]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()])
|
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
|
||||||
result = schema.execute(query)
|
result = schema.execute(query, context_value=context())
|
||||||
assert not result.errors
|
assert not result.errors
|
||||||
assert result.data['allReporters'] == expected['allReporters']
|
assert result.data['allReporters'] == expected['allReporters']
|
||||||
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from ....core.classtypes.objecttype import ObjectType
|
from ....core.classtypes.objecttype import ObjectType
|
||||||
from ....core.types import Field
|
from ....core.types import Field
|
||||||
from .sql.types import DjangoDebugSQL
|
from .sql.types import DjangoDebugBaseSQL
|
||||||
|
|
||||||
|
|
||||||
class DjangoDebug(ObjectType):
|
class DjangoDebug(ObjectType):
|
||||||
sql = Field(DjangoDebugSQL.List())
|
sql = Field(DjangoDebugBaseSQL.List())
|
||||||
|
|
|
@ -7,7 +7,7 @@ from graphql.utils.schema_printer import print_schema
|
||||||
|
|
||||||
from graphene import signals
|
from graphene import signals
|
||||||
|
|
||||||
from ..plugins import CamelCase, PluginManager
|
from ..middlewares import MiddlewareManager, CamelCaseArgsMiddleware
|
||||||
from .classtypes.base import ClassType
|
from .classtypes.base import ClassType
|
||||||
from .types.base import InstanceType
|
from .types.base import InstanceType
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class Schema(object):
|
||||||
_executor = None
|
_executor = None
|
||||||
|
|
||||||
def __init__(self, query=None, mutation=None, subscription=None,
|
def __init__(self, query=None, mutation=None, subscription=None,
|
||||||
name='Schema', executor=None, plugins=None, auto_camelcase=True, **options):
|
name='Schema', executor=None, middlewares=None, auto_camelcase=True, **options):
|
||||||
self._types_names = {}
|
self._types_names = {}
|
||||||
self._types = {}
|
self._types = {}
|
||||||
self.mutation = mutation
|
self.mutation = mutation
|
||||||
|
@ -31,21 +31,19 @@ class Schema(object):
|
||||||
self.subscription = subscription
|
self.subscription = subscription
|
||||||
self.name = name
|
self.name = name
|
||||||
self.executor = executor
|
self.executor = executor
|
||||||
plugins = plugins or []
|
if 'plugins' in options:
|
||||||
|
raise Exception('Plugins are deprecated, please use middlewares.')
|
||||||
|
middlewares = middlewares or []
|
||||||
if auto_camelcase:
|
if auto_camelcase:
|
||||||
plugins.append(CamelCase())
|
middlewares.append(CamelCaseArgsMiddleware())
|
||||||
self.plugins = PluginManager(self, plugins)
|
self.auto_camelcase = auto_camelcase
|
||||||
|
self.middleware_manager = MiddlewareManager(self, middlewares)
|
||||||
self.options = options
|
self.options = options
|
||||||
signals.init_schema.send(self)
|
signals.init_schema.send(self)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<Schema: %s (%s)>' % (str(self.name), hash(self))
|
return '<Schema: %s (%s)>' % (str(self.name), hash(self))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
if name in self.plugins:
|
|
||||||
return getattr(self.plugins, name)
|
|
||||||
return super(Schema, self).__getattr__(name)
|
|
||||||
|
|
||||||
def T(self, _type):
|
def T(self, _type):
|
||||||
if not _type:
|
if not _type:
|
||||||
return
|
return
|
||||||
|
@ -111,13 +109,16 @@ class Schema(object):
|
||||||
raise KeyError('Type %r not found in %r' % (type_name, self))
|
raise KeyError('Type %r not found in %r' % (type_name, self))
|
||||||
return self._types_names[type_name]
|
return self._types_names[type_name]
|
||||||
|
|
||||||
|
def resolver_with_middleware(self, resolver):
|
||||||
|
return self.middleware_manager.wrap(resolver)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def types(self):
|
def types(self):
|
||||||
return self._types_names
|
return self._types_names
|
||||||
|
|
||||||
def execute(self, request_string='', root_value=None, variable_values=None,
|
def execute(self, request_string='', root_value=None, variable_values=None,
|
||||||
context_value=None, operation_name=None, executor=None):
|
context_value=None, operation_name=None, executor=None):
|
||||||
kwargs = dict(
|
return graphql(
|
||||||
schema=self.schema,
|
schema=self.schema,
|
||||||
request_string=request_string,
|
request_string=request_string,
|
||||||
root_value=root_value,
|
root_value=root_value,
|
||||||
|
@ -126,8 +127,6 @@ class Schema(object):
|
||||||
operation_name=operation_name,
|
operation_name=operation_name,
|
||||||
executor=executor or self._executor
|
executor=executor or self._executor
|
||||||
)
|
)
|
||||||
with self.plugins.context_execution(**kwargs) as execute_kwargs:
|
|
||||||
return graphql(**execute_kwargs)
|
|
||||||
|
|
||||||
def introspect(self):
|
def introspect(self):
|
||||||
return graphql(self.schema, introspection_query).data
|
return graphql(self.schema, introspection_query).data
|
||||||
|
|
|
@ -93,7 +93,7 @@ def test_field_resolve():
|
||||||
f = StringField(required=True, resolve=lambda *args: 'RESOLVED').as_field()
|
f = StringField(required=True, resolve=lambda *args: 'RESOLVED').as_field()
|
||||||
f.contribute_to_class(MyOt, 'field_name')
|
f.contribute_to_class(MyOt, 'field_name')
|
||||||
field_type = schema.T(f)
|
field_type = schema.T(f)
|
||||||
assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None)
|
assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None).value
|
||||||
|
|
||||||
|
|
||||||
def test_field_resolve_type_custom():
|
def test_field_resolve_type_custom():
|
||||||
|
|
|
@ -154,6 +154,12 @@ def test_lazytype():
|
||||||
assert schema.T(t) == schema.T(MyType)
|
assert schema.T(t) == schema.T(MyType)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecated_plugins_throws_exception():
|
||||||
|
with raises(Exception) as excinfo:
|
||||||
|
Schema(plugins=[])
|
||||||
|
assert 'Plugins are deprecated, please use middlewares' in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_schema_str():
|
def test_schema_str():
|
||||||
expected = """
|
expected = """
|
||||||
schema {
|
schema {
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
from functools import wraps
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from graphql.type import GraphQLArgument
|
from graphql.type import GraphQLArgument
|
||||||
|
|
||||||
from ...utils import ProxySnakeDict
|
|
||||||
from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
|
from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,11 +51,3 @@ def to_arguments(*args, **kwargs):
|
||||||
arguments[name] = argument
|
arguments[name] = argument
|
||||||
|
|
||||||
return sorted(arguments.values())
|
return sorted(arguments.values())
|
||||||
|
|
||||||
|
|
||||||
def snake_case_args(resolver):
|
|
||||||
@wraps(resolver)
|
|
||||||
def wrapped_resolver(instance, args, context, info):
|
|
||||||
return resolver(instance, ProxySnakeDict(args), context, info)
|
|
||||||
|
|
||||||
return wrapped_resolver
|
|
||||||
|
|
|
@ -3,6 +3,8 @@ from functools import partial, total_ordering
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from ...utils import to_camel_case
|
||||||
|
|
||||||
|
|
||||||
class InstanceType(object):
|
class InstanceType(object):
|
||||||
|
|
||||||
|
@ -142,7 +144,9 @@ class GroupNamedType(InstanceType):
|
||||||
self.types = types
|
self.types = types
|
||||||
|
|
||||||
def get_named_type(self, schema, type):
|
def get_named_type(self, schema, type):
|
||||||
name = type.name or schema.get_default_namedtype_name(type.default_name)
|
name = type.name
|
||||||
|
if not name and schema.auto_camelcase:
|
||||||
|
name = to_camel_case(type.default_name)
|
||||||
return name, schema.T(type)
|
return name, schema.T(type)
|
||||||
|
|
||||||
def iter_types(self, schema):
|
def iter_types(self, schema):
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..classtypes.base import FieldsClassType
|
||||||
from ..classtypes.inputobjecttype import InputObjectType
|
from ..classtypes.inputobjecttype import InputObjectType
|
||||||
from ..classtypes.mutation import Mutation
|
from ..classtypes.mutation import Mutation
|
||||||
from ..exceptions import SkipField
|
from ..exceptions import SkipField
|
||||||
from .argument import Argument, ArgumentsGroup, snake_case_args
|
from .argument import Argument, ArgumentsGroup
|
||||||
from .base import (ArgumentType, GroupNamedType, LazyType, MountType,
|
from .base import (ArgumentType, GroupNamedType, LazyType, MountType,
|
||||||
NamedType, OrderedType)
|
NamedType, OrderedType)
|
||||||
from .definitions import NonNull
|
from .definitions import NonNull
|
||||||
|
@ -89,9 +89,6 @@ class Field(NamedType, OrderedType):
|
||||||
return NonNull(self.type)
|
return NonNull(self.type)
|
||||||
return self.type
|
return self.type
|
||||||
|
|
||||||
def decorate_resolver(self, resolver):
|
|
||||||
return snake_case_args(resolver)
|
|
||||||
|
|
||||||
def internal_type(self, schema):
|
def internal_type(self, schema):
|
||||||
if not self.object_type:
|
if not self.object_type:
|
||||||
raise Exception('The field is not mounted in any ClassType')
|
raise Exception('The field is not mounted in any ClassType')
|
||||||
|
@ -118,10 +115,13 @@ class Field(NamedType, OrderedType):
|
||||||
resolver = wrapped_func
|
resolver = wrapped_func
|
||||||
|
|
||||||
assert type, 'Internal type for field %s is None' % str(self)
|
assert type, 'Internal type for field %s is None' % str(self)
|
||||||
return GraphQLField(type, args=schema.T(arguments),
|
return GraphQLField(
|
||||||
resolver=self.decorate_resolver(resolver),
|
type,
|
||||||
deprecation_reason=self.deprecation_reason,
|
args=schema.T(arguments),
|
||||||
description=description,)
|
resolver=schema.resolver_with_middleware(resolver),
|
||||||
|
deprecation_reason=self.deprecation_reason,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""
|
"""
|
||||||
|
@ -175,7 +175,8 @@ class InputField(NamedType, OrderedType):
|
||||||
def internal_type(self, schema):
|
def internal_type(self, schema):
|
||||||
return GraphQLInputObjectField(
|
return GraphQLInputObjectField(
|
||||||
schema.T(self.type),
|
schema.T(self.type),
|
||||||
default_value=self.default, description=self.description)
|
default_value=self.default, description=self.description
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FieldsGroupType(GroupNamedType):
|
class FieldsGroupType(GroupNamedType):
|
||||||
|
|
|
@ -4,7 +4,7 @@ from pytest import raises
|
||||||
from graphene.core.schema import Schema
|
from graphene.core.schema import Schema
|
||||||
from graphene.core.types import ObjectType
|
from graphene.core.types import ObjectType
|
||||||
|
|
||||||
from ..argument import Argument, snake_case_args, to_arguments
|
from ..argument import Argument, to_arguments
|
||||||
from ..scalars import String
|
from ..scalars import String
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,10 +45,3 @@ def test_to_arguments_wrong_type():
|
||||||
p=3
|
p=3
|
||||||
)
|
)
|
||||||
assert 'Unknown argument p=3' == str(excinfo.value)
|
assert 'Unknown argument p=3' == str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
def test_snake_case_args():
|
|
||||||
def resolver(instance, args, context, info):
|
|
||||||
return args['my_arg']['inner_arg']
|
|
||||||
r = snake_case_args(resolver)
|
|
||||||
assert r(None, {'myArg': {'innerArg': 3}}, None, None) == 3
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ def test_field_internal_type():
|
||||||
assert field.attname == 'my_field'
|
assert field.attname == 'my_field'
|
||||||
assert isinstance(type, GraphQLField)
|
assert isinstance(type, GraphQLField)
|
||||||
assert type.description == 'My argument'
|
assert type.description == 'My argument'
|
||||||
assert type.resolver(None, {}, None, None) == 'RESOLVED'
|
assert type.resolver(None, {}, None, None).value == 'RESOLVED'
|
||||||
assert type.type == GraphQLString
|
assert type.type == GraphQLString
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ def test_field_objectype_resolver():
|
||||||
type = schema.T(field)
|
type = schema.T(field)
|
||||||
assert isinstance(type, GraphQLField)
|
assert isinstance(type, GraphQLField)
|
||||||
assert type.description == 'Custom description'
|
assert type.description == 'Custom description'
|
||||||
assert type.resolver(Query(), {}, None, None) == 'RESOLVED'
|
assert type.resolver(Query(), {}, None, None).value == 'RESOLVED'
|
||||||
|
|
||||||
|
|
||||||
def test_field_custom_name():
|
def test_field_custom_name():
|
||||||
|
@ -161,7 +161,7 @@ def test_field_resolve_argument():
|
||||||
schema = Schema(query=Query)
|
schema = Schema(query=Query)
|
||||||
|
|
||||||
type = schema.T(field)
|
type = schema.T(field)
|
||||||
assert type.resolver(None, {'firstName': 'Peter'}, None, None) == 'Peter'
|
assert type.resolver(None, {'firstName': 'Peter'}, None, None).value == 'Peter'
|
||||||
|
|
||||||
|
|
||||||
def test_field_resolve_vars():
|
def test_field_resolve_vars():
|
||||||
|
@ -216,7 +216,6 @@ def test_field_resolve_object():
|
||||||
att_func = field_func
|
att_func = field_func
|
||||||
|
|
||||||
assert field.resolver(Root, {}, None) is True
|
assert field.resolver(Root, {}, None) is True
|
||||||
assert field.resolver(Root, {}, None) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_field_resolve_source_object():
|
def test_field_resolve_source_object():
|
||||||
|
@ -235,4 +234,3 @@ def test_field_resolve_source_object():
|
||||||
att_func = field_func
|
att_func = field_func
|
||||||
|
|
||||||
assert field.resolver(Root, {}, None) is True
|
assert field.resolver(Root, {}, None) is True
|
||||||
assert field.resolver(Root, {}, None) is True
|
|
||||||
|
|
6
graphene/middlewares/__init__.py
Normal file
6
graphene/middlewares/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from .base import MiddlewareManager
|
||||||
|
from .camel_case import CamelCaseArgsMiddleware
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'MiddlewareManager', 'CamelCaseArgsMiddleware'
|
||||||
|
]
|
23
graphene/middlewares/base.py
Normal file
23
graphene/middlewares/base.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
from ..utils import promise_middleware
|
||||||
|
|
||||||
|
MIDDLEWARE_RESOLVER_FUNCTION = 'resolve'
|
||||||
|
|
||||||
|
|
||||||
|
class MiddlewareManager(object):
|
||||||
|
|
||||||
|
def __init__(self, schema, middlewares=None):
|
||||||
|
self.schema = schema
|
||||||
|
self.middlewares = middlewares or []
|
||||||
|
|
||||||
|
def add_middleware(self, middleware):
|
||||||
|
self.middlewares.append(middleware)
|
||||||
|
|
||||||
|
def get_middleware_resolvers(self):
|
||||||
|
for middleware in self.middlewares:
|
||||||
|
if not hasattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION):
|
||||||
|
continue
|
||||||
|
yield getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION)
|
||||||
|
|
||||||
|
def wrap(self, resolver):
|
||||||
|
middleware_resolvers = self.get_middleware_resolvers()
|
||||||
|
return promise_middleware(resolver, middleware_resolvers)
|
8
graphene/middlewares/camel_case.py
Normal file
8
graphene/middlewares/camel_case.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from ..utils import ProxySnakeDict
|
||||||
|
|
||||||
|
|
||||||
|
class CamelCaseArgsMiddleware(object):
|
||||||
|
|
||||||
|
def resolve(self, next, root, args, context, info):
|
||||||
|
args = ProxySnakeDict(args)
|
||||||
|
return next(root, args, context, info)
|
|
@ -1,6 +0,0 @@
|
||||||
from .base import Plugin, PluginManager
|
|
||||||
from .camel_case import CamelCase
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'Plugin', 'PluginManager', 'CamelCase'
|
|
||||||
]
|
|
|
@ -1,53 +0,0 @@
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import partial, reduce
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(object):
|
|
||||||
|
|
||||||
def contribute_to_schema(self, schema):
|
|
||||||
self.schema = schema
|
|
||||||
|
|
||||||
|
|
||||||
def apply_function(a, b):
|
|
||||||
return b(a)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginManager(object):
|
|
||||||
|
|
||||||
PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
|
|
||||||
|
|
||||||
def __init__(self, schema, plugins=[]):
|
|
||||||
self.schema = schema
|
|
||||||
self.plugins = []
|
|
||||||
for plugin in plugins:
|
|
||||||
self.add_plugin(plugin)
|
|
||||||
|
|
||||||
def add_plugin(self, plugin):
|
|
||||||
if hasattr(plugin, 'contribute_to_schema'):
|
|
||||||
plugin.contribute_to_schema(self.schema)
|
|
||||||
self.plugins.append(plugin)
|
|
||||||
|
|
||||||
def get_plugin_functions(self, function):
|
|
||||||
for plugin in self.plugins:
|
|
||||||
if not hasattr(plugin, function):
|
|
||||||
continue
|
|
||||||
yield getattr(plugin, function)
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
functions = self.get_plugin_functions(name)
|
|
||||||
return partial(reduce, apply_function, functions)
|
|
||||||
|
|
||||||
def __contains__(self, name):
|
|
||||||
return name in self.PLUGIN_FUNCTIONS
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def context_execution(self, **executor):
|
|
||||||
contexts = []
|
|
||||||
functions = self.get_plugin_functions('context_execution')
|
|
||||||
for f in functions:
|
|
||||||
context = f(executor)
|
|
||||||
executor = context.__enter__()
|
|
||||||
contexts.append((context, executor))
|
|
||||||
yield executor
|
|
||||||
for context, value in contexts[::-1]:
|
|
||||||
context.__exit__(None, None, None)
|
|
|
@ -1,7 +0,0 @@
|
||||||
from ..utils import to_camel_case
|
|
||||||
|
|
||||||
|
|
||||||
class CamelCase(object):
|
|
||||||
|
|
||||||
def get_default_namedtype_name(self, value):
|
|
||||||
return to_camel_case(value)
|
|
|
@ -3,6 +3,7 @@ from .proxy_snake_dict import ProxySnakeDict
|
||||||
from .caching import cached_property, memoize
|
from .caching import cached_property, memoize
|
||||||
from .maybe_func import maybe_func
|
from .maybe_func import maybe_func
|
||||||
from .misc import enum_to_graphql_enum
|
from .misc import enum_to_graphql_enum
|
||||||
|
from .promise_middleware import promise_middleware
|
||||||
from .resolve_only_args import resolve_only_args
|
from .resolve_only_args import resolve_only_args
|
||||||
from .lazylist import LazyList
|
from .lazylist import LazyList
|
||||||
from .wrap_resolver_function import with_context, wrap_resolver_function
|
from .wrap_resolver_function import with_context, wrap_resolver_function
|
||||||
|
@ -10,5 +11,5 @@ from .wrap_resolver_function import with_context, wrap_resolver_function
|
||||||
|
|
||||||
__all__ = ['to_camel_case', 'to_snake_case', 'to_const', 'ProxySnakeDict',
|
__all__ = ['to_camel_case', 'to_snake_case', 'to_const', 'ProxySnakeDict',
|
||||||
'cached_property', 'memoize', 'maybe_func', 'enum_to_graphql_enum',
|
'cached_property', 'memoize', 'maybe_func', 'enum_to_graphql_enum',
|
||||||
'resolve_only_args', 'LazyList', 'with_context',
|
'promise_middleware', 'resolve_only_args', 'LazyList', 'with_context',
|
||||||
'wrap_resolver_function']
|
'wrap_resolver_function']
|
||||||
|
|
17
graphene/utils/promise_middleware.py
Normal file
17
graphene/utils/promise_middleware.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from functools import partial
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
from promise import Promise
|
||||||
|
|
||||||
|
|
||||||
|
def promise_middleware(func, middlewares):
|
||||||
|
middlewares = chain((func, make_it_promise), middlewares)
|
||||||
|
past = None
|
||||||
|
for m in middlewares:
|
||||||
|
past = partial(m, past) if past else m
|
||||||
|
|
||||||
|
return past
|
||||||
|
|
||||||
|
|
||||||
|
def make_it_promise(next, *a, **b):
|
||||||
|
return Promise.resolve(next(*a, **b))
|
Loading…
Reference in New Issue
Block a user