Merge pull request #176 from graphql-python/features/middlewares

Added Middleware
This commit is contained in:
Syrus Akbary 2016-05-21 00:15:42 -07:00
commit 3a1093af24
26 changed files with 238 additions and 213 deletions

View File

@ -17,6 +17,7 @@ ga = "UA-12613282-7"
"/docs/basic-types/", "/docs/basic-types/",
"/docs/enums/", "/docs/enums/",
"/docs/relay/", "/docs/relay/",
"/docs/middleware/",
] ]
[docs.django] [docs.django]

View File

@ -13,14 +13,19 @@ For that, you will need to add the plugin in your graphene schema.
## Installation ## Installation
For use the Django Debug plugin in Graphene, just import `DjangoDebugPlugin` and add it to the `plugins` argument when you initiate the `Schema`. For use the Django Debug plugin in Graphene:
* Import `DjangoDebugMiddleware` and add it to the `middleware` argument when you initiate the `Schema`.
* Add the `debug` field into the schema root `Query` with the value `graphene.Field(DjangoDebug, name='__debug')`.
```python ```python
from graphene.contrib.django.debug import DjangoDebugPlugin from graphene.contrib.django.debug import DjangoDebugMiddleware, DjangoDebug
# ... class Query(graphene.ObjectType):
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) # ...
debug = graphene.Field(DjangoDebug, name='__debug')
schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
``` ```
This plugin, will add another field in the `Query` named `__debug`. This plugin, will add another field in the `Query` named `__debug`.

View File

@ -0,0 +1,43 @@
---
title: Middleware
description: Walkthrough Middleware
---
# Middleware
You can use _middleware_ to affect the evaluation of fields in your schema.
A middleware is any object that responds to `resolve(*args, next_middleware)`. Inside that method, it should either:
* Send `resolve` to the next middleware to continue the evaluation; or
* Return a value to end the evaluation early.
Middlewares' `resolve` is invoked with several arguments:
* `next` represents the execution chain. Call `next` to continue evalution.
* `root` is the root value object passed throughout the query
* `args` is the hash of arguments passed to the field
* `context` is the context object passed throughout the query
* `info` is the resolver info
Add a middleware to a schema by adding to the `middlewares` list.
### Example: Authorization
This middleware only continues evaluation if the `field_name` is not `'user'`:
```python
class AuthorizationMiddleware(object):
def resolve(self, next, root, args, context, info):
if info.field_name == 'user':
return None
return next(root, args, context, info)
```
Then, add the middleware to your schema:
```python
schema = Schema(middlewares=[AuthorizationMiddleware])
```

View File

@ -1,4 +1,4 @@
from .plugin import DjangoDebugPlugin from .middleware import DjangoDebugMiddleware
from .types import DjangoDebug from .types import DjangoDebug
__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] __all__ = ['DjangoDebugMiddleware', 'DjangoDebug']

View File

@ -0,0 +1,56 @@
from promise import Promise
from django.db import connections
from .sql.tracking import unwrap_cursor, wrap_cursor
from .types import DjangoDebug
class DjangoDebugContext(object):
def __init__(self):
self.debug_promise = None
self.promises = []
self.enable_instrumentation()
self.object = DjangoDebug(sql=[])
def get_debug_promise(self):
if not self.debug_promise:
self.debug_promise = Promise.all(self.promises)
return self.debug_promise.then(self.on_resolve_all_promises)
def on_resolve_all_promises(self, values):
self.disable_instrumentation()
return self.object
def add_promise(self, promise):
if self.debug_promise and not self.debug_promise.is_fulfilled:
self.promises.append(promise)
def enable_instrumentation(self):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, self)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
class DjangoDebugMiddleware(object):
def resolve(self, next, root, args, context, info):
django_debug = getattr(context, 'django_debug', None)
if not django_debug:
if context is None:
raise Exception('DjangoDebug cannot be executed in None contexts')
try:
context.django_debug = DjangoDebugContext()
except Exception:
raise Exception('DjangoDebug need the context to be writable, context received: {}.'.format(
context.__class__.__name__
))
if info.schema.graphene_schema.T(DjangoDebug) == info.return_type:
return context.django_debug.get_debug_promise()
promise = next(root, args, context, info)
context.django_debug.add_promise(promise)
return promise

View File

@ -1,79 +0,0 @@
from contextlib import contextmanager
from django.db import connections
from ....core.schema import GraphQLSchema
from ....core.types import Field
from ....plugins import Plugin
from .sql.tracking import unwrap_cursor, wrap_cursor
from .sql.types import DjangoDebugSQL
from .types import DjangoDebug
class WrappedRoot(object):
def __init__(self, root):
self._recorded = []
self._root = root
def record(self, **log):
self._recorded.append(DjangoDebugSQL(**log))
def debug(self):
return DjangoDebug(sql=self._recorded)
class WrapRoot(object):
@property
def _root(self):
return self._wrapped_root.root
@_root.setter
def _root(self, value):
self._wrapped_root = value
def resolve_debug(self, args, info):
return self._wrapped_root.debug()
def debug_objecttype(objecttype):
return type(
'Debug{}'.format(objecttype._meta.type_name),
(WrapRoot, objecttype),
{'debug': Field(DjangoDebug, name='__debug')})
class DjangoDebugPlugin(Plugin):
def enable_instrumentation(self, wrapped_root):
# This is thread-safe because database connections are thread-local.
for connection in connections.all():
wrap_cursor(connection, wrapped_root)
def disable_instrumentation(self):
for connection in connections.all():
unwrap_cursor(connection)
def wrap_schema(self, schema_type):
query = schema_type._query
if query:
class_type = self.schema.objecttype(schema_type.get_query_type())
assert class_type, 'The query in schema is not constructed with graphene'
_type = debug_objecttype(class_type)
self.schema.register(_type, force=True)
return GraphQLSchema(
self.schema,
self.schema.T(_type),
schema_type.get_mutation_type(),
schema_type.get_subscription_type()
)
return schema_type
@contextmanager
def context_execution(self, executor):
executor['root_value'] = WrappedRoot(root=executor.get('root_value'))
executor['schema'] = self.wrap_schema(executor['schema'])
self.enable_instrumentation(executor['root_value'])
yield executor
self.disable_instrumentation()

View File

@ -8,6 +8,8 @@ from time import time
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from .types import DjangoDebugSQL, DjangoDebugPostgreSQL
class SQLQueryTriggered(Exception): class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query""" """Thrown when template panel triggers a query"""
@ -139,9 +141,11 @@ class NormalCursorWrapper(object):
'iso_level': iso_level, 'iso_level': iso_level,
'encoding': conn.encoding, 'encoding': conn.encoding,
}) })
_sql = DjangoDebugPostgreSQL(**params)
else:
_sql = DjangoDebugSQL(**params)
# We keep `sql` to maintain backwards compatibility # We keep `sql` to maintain backwards compatibility
self.logger.record(**params) self.logger.object.sql.append(_sql)
def callproc(self, procname, params=()): def callproc(self, procname, params=()):
return self._record(self.cursor.callproc, procname, params) return self._record(self.cursor.callproc, procname, params)

View File

@ -1,7 +1,7 @@
from .....core import Boolean, Float, ObjectType, String from .....core import Boolean, Float, ObjectType, String
class DjangoDebugSQL(ObjectType): class DjangoDebugBaseSQL(ObjectType):
vendor = String() vendor = String()
alias = String() alias = String()
sql = String() sql = String()
@ -13,6 +13,12 @@ class DjangoDebugSQL(ObjectType):
is_slow = Boolean() is_slow = Boolean()
is_select = Boolean() is_select = Boolean()
class DjangoDebugSQL(DjangoDebugBaseSQL):
pass
class DjangoDebugPostgreSQL(DjangoDebugBaseSQL):
trans_id = String() trans_id = String()
trans_status = String() trans_status = String()
iso_level = String() iso_level = String()

View File

@ -5,7 +5,12 @@ from graphene.contrib.django import DjangoConnectionField, DjangoNode
from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED from graphene.contrib.django.utils import DJANGO_FILTER_INSTALLED
from ...tests.models import Reporter from ...tests.models import Reporter
from ..plugin import DjangoDebugPlugin from ..middleware import DjangoDebugMiddleware
from ..types import DjangoDebug
class context(object):
pass
# from examples.starwars_django.models import Character # from examples.starwars_django.models import Character
@ -25,6 +30,7 @@ def test_should_query_field():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first() return Reporter.objects.first()
@ -51,8 +57,8 @@ def test_should_query_field():
}] }]
} }
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -70,6 +76,7 @@ def test_should_query_list():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = ReporterType.List() all_reporters = ReporterType.List()
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -98,8 +105,8 @@ def test_should_query_list():
}] }]
} }
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -117,6 +124,7 @@ def test_should_query_connection():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -146,8 +154,8 @@ def test_should_query_connection():
}] }]
}, },
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']
@ -172,6 +180,7 @@ def test_should_query_connectionfilter():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoFilterConnectionField(ReporterType) all_reporters = DjangoFilterConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, *args, **kwargs):
return Reporter.objects.all() return Reporter.objects.all()
@ -201,8 +210,8 @@ def test_should_query_connectionfilter():
}] }]
}, },
} }
schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) schema = graphene.Schema(query=Query, middlewares=[DjangoDebugMiddleware()])
result = schema.execute(query) result = schema.execute(query, context_value=context())
assert not result.errors assert not result.errors
assert result.data['allReporters'] == expected['allReporters'] assert result.data['allReporters'] == expected['allReporters']
assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql'] assert 'COUNT' in result.data['__debug']['sql'][0]['rawSql']

View File

@ -1,7 +1,7 @@
from ....core.classtypes.objecttype import ObjectType from ....core.classtypes.objecttype import ObjectType
from ....core.types import Field from ....core.types import Field
from .sql.types import DjangoDebugSQL from .sql.types import DjangoDebugBaseSQL
class DjangoDebug(ObjectType): class DjangoDebug(ObjectType):
sql = Field(DjangoDebugSQL.List()) sql = Field(DjangoDebugBaseSQL.List())

View File

@ -7,7 +7,7 @@ from graphql.utils.schema_printer import print_schema
from graphene import signals from graphene import signals
from ..plugins import CamelCase, PluginManager from ..middlewares import MiddlewareManager, CamelCaseArgsMiddleware
from .classtypes.base import ClassType from .classtypes.base import ClassType
from .types.base import InstanceType from .types.base import InstanceType
@ -23,7 +23,7 @@ class Schema(object):
_executor = None _executor = None
def __init__(self, query=None, mutation=None, subscription=None, def __init__(self, query=None, mutation=None, subscription=None,
name='Schema', executor=None, plugins=None, auto_camelcase=True, **options): name='Schema', executor=None, middlewares=None, auto_camelcase=True, **options):
self._types_names = {} self._types_names = {}
self._types = {} self._types = {}
self.mutation = mutation self.mutation = mutation
@ -31,21 +31,19 @@ class Schema(object):
self.subscription = subscription self.subscription = subscription
self.name = name self.name = name
self.executor = executor self.executor = executor
plugins = plugins or [] if 'plugins' in options:
raise Exception('Plugins are deprecated, please use middlewares.')
middlewares = middlewares or []
if auto_camelcase: if auto_camelcase:
plugins.append(CamelCase()) middlewares.append(CamelCaseArgsMiddleware())
self.plugins = PluginManager(self, plugins) self.auto_camelcase = auto_camelcase
self.middleware_manager = MiddlewareManager(self, middlewares)
self.options = options self.options = options
signals.init_schema.send(self) signals.init_schema.send(self)
def __repr__(self): def __repr__(self):
return '<Schema: %s (%s)>' % (str(self.name), hash(self)) return '<Schema: %s (%s)>' % (str(self.name), hash(self))
def __getattr__(self, name):
if name in self.plugins:
return getattr(self.plugins, name)
return super(Schema, self).__getattr__(name)
def T(self, _type): def T(self, _type):
if not _type: if not _type:
return return
@ -111,13 +109,16 @@ class Schema(object):
raise KeyError('Type %r not found in %r' % (type_name, self)) raise KeyError('Type %r not found in %r' % (type_name, self))
return self._types_names[type_name] return self._types_names[type_name]
def resolver_with_middleware(self, resolver):
return self.middleware_manager.wrap(resolver)
@property @property
def types(self): def types(self):
return self._types_names return self._types_names
def execute(self, request_string='', root_value=None, variable_values=None, def execute(self, request_string='', root_value=None, variable_values=None,
context_value=None, operation_name=None, executor=None): context_value=None, operation_name=None, executor=None):
kwargs = dict( return graphql(
schema=self.schema, schema=self.schema,
request_string=request_string, request_string=request_string,
root_value=root_value, root_value=root_value,
@ -126,8 +127,6 @@ class Schema(object):
operation_name=operation_name, operation_name=operation_name,
executor=executor or self._executor executor=executor or self._executor
) )
with self.plugins.context_execution(**kwargs) as execute_kwargs:
return graphql(**execute_kwargs)
def introspect(self): def introspect(self):
return graphql(self.schema, introspection_query).data return graphql(self.schema, introspection_query).data

View File

@ -93,7 +93,7 @@ def test_field_resolve():
f = StringField(required=True, resolve=lambda *args: 'RESOLVED').as_field() f = StringField(required=True, resolve=lambda *args: 'RESOLVED').as_field()
f.contribute_to_class(MyOt, 'field_name') f.contribute_to_class(MyOt, 'field_name')
field_type = schema.T(f) field_type = schema.T(f)
assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None) assert 'RESOLVED' == field_type.resolver(MyOt, None, None, None).value
def test_field_resolve_type_custom(): def test_field_resolve_type_custom():

View File

@ -154,6 +154,12 @@ def test_lazytype():
assert schema.T(t) == schema.T(MyType) assert schema.T(t) == schema.T(MyType)
def test_deprecated_plugins_throws_exception():
with raises(Exception) as excinfo:
Schema(plugins=[])
assert 'Plugins are deprecated, please use middlewares' in str(excinfo.value)
def test_schema_str(): def test_schema_str():
expected = """ expected = """
schema { schema {

View File

@ -1,9 +1,7 @@
from functools import wraps
from itertools import chain from itertools import chain
from graphql.type import GraphQLArgument from graphql.type import GraphQLArgument
from ...utils import ProxySnakeDict
from .base import ArgumentType, GroupNamedType, NamedType, OrderedType from .base import ArgumentType, GroupNamedType, NamedType, OrderedType
@ -53,11 +51,3 @@ def to_arguments(*args, **kwargs):
arguments[name] = argument arguments[name] = argument
return sorted(arguments.values()) return sorted(arguments.values())
def snake_case_args(resolver):
@wraps(resolver)
def wrapped_resolver(instance, args, context, info):
return resolver(instance, ProxySnakeDict(args), context, info)
return wrapped_resolver

View File

@ -3,6 +3,8 @@ from functools import partial, total_ordering
import six import six
from ...utils import to_camel_case
class InstanceType(object): class InstanceType(object):
@ -142,7 +144,9 @@ class GroupNamedType(InstanceType):
self.types = types self.types = types
def get_named_type(self, schema, type): def get_named_type(self, schema, type):
name = type.name or schema.get_default_namedtype_name(type.default_name) name = type.name
if not name and schema.auto_camelcase:
name = to_camel_case(type.default_name)
return name, schema.T(type) return name, schema.T(type)
def iter_types(self, schema): def iter_types(self, schema):

View File

@ -10,7 +10,7 @@ from ..classtypes.base import FieldsClassType
from ..classtypes.inputobjecttype import InputObjectType from ..classtypes.inputobjecttype import InputObjectType
from ..classtypes.mutation import Mutation from ..classtypes.mutation import Mutation
from ..exceptions import SkipField from ..exceptions import SkipField
from .argument import Argument, ArgumentsGroup, snake_case_args from .argument import Argument, ArgumentsGroup
from .base import (ArgumentType, GroupNamedType, LazyType, MountType, from .base import (ArgumentType, GroupNamedType, LazyType, MountType,
NamedType, OrderedType) NamedType, OrderedType)
from .definitions import NonNull from .definitions import NonNull
@ -89,9 +89,6 @@ class Field(NamedType, OrderedType):
return NonNull(self.type) return NonNull(self.type)
return self.type return self.type
def decorate_resolver(self, resolver):
return snake_case_args(resolver)
def internal_type(self, schema): def internal_type(self, schema):
if not self.object_type: if not self.object_type:
raise Exception('The field is not mounted in any ClassType') raise Exception('The field is not mounted in any ClassType')
@ -118,10 +115,13 @@ class Field(NamedType, OrderedType):
resolver = wrapped_func resolver = wrapped_func
assert type, 'Internal type for field %s is None' % str(self) assert type, 'Internal type for field %s is None' % str(self)
return GraphQLField(type, args=schema.T(arguments), return GraphQLField(
resolver=self.decorate_resolver(resolver), type,
deprecation_reason=self.deprecation_reason, args=schema.T(arguments),
description=description,) resolver=schema.resolver_with_middleware(resolver),
deprecation_reason=self.deprecation_reason,
description=description,
)
def __repr__(self): def __repr__(self):
""" """
@ -175,7 +175,8 @@ class InputField(NamedType, OrderedType):
def internal_type(self, schema): def internal_type(self, schema):
return GraphQLInputObjectField( return GraphQLInputObjectField(
schema.T(self.type), schema.T(self.type),
default_value=self.default, description=self.description) default_value=self.default, description=self.description
)
class FieldsGroupType(GroupNamedType): class FieldsGroupType(GroupNamedType):

View File

@ -4,7 +4,7 @@ from pytest import raises
from graphene.core.schema import Schema from graphene.core.schema import Schema
from graphene.core.types import ObjectType from graphene.core.types import ObjectType
from ..argument import Argument, snake_case_args, to_arguments from ..argument import Argument, to_arguments
from ..scalars import String from ..scalars import String
@ -45,10 +45,3 @@ def test_to_arguments_wrong_type():
p=3 p=3
) )
assert 'Unknown argument p=3' == str(excinfo.value) assert 'Unknown argument p=3' == str(excinfo.value)
def test_snake_case_args():
def resolver(instance, args, context, info):
return args['my_arg']['inner_arg']
r = snake_case_args(resolver)
assert r(None, {'myArg': {'innerArg': 3}}, None, None) == 3

View File

@ -24,7 +24,7 @@ def test_field_internal_type():
assert field.attname == 'my_field' assert field.attname == 'my_field'
assert isinstance(type, GraphQLField) assert isinstance(type, GraphQLField)
assert type.description == 'My argument' assert type.description == 'My argument'
assert type.resolver(None, {}, None, None) == 'RESOLVED' assert type.resolver(None, {}, None, None).value == 'RESOLVED'
assert type.type == GraphQLString assert type.type == GraphQLString
@ -43,7 +43,7 @@ def test_field_objectype_resolver():
type = schema.T(field) type = schema.T(field)
assert isinstance(type, GraphQLField) assert isinstance(type, GraphQLField)
assert type.description == 'Custom description' assert type.description == 'Custom description'
assert type.resolver(Query(), {}, None, None) == 'RESOLVED' assert type.resolver(Query(), {}, None, None).value == 'RESOLVED'
def test_field_custom_name(): def test_field_custom_name():
@ -161,7 +161,7 @@ def test_field_resolve_argument():
schema = Schema(query=Query) schema = Schema(query=Query)
type = schema.T(field) type = schema.T(field)
assert type.resolver(None, {'firstName': 'Peter'}, None, None) == 'Peter' assert type.resolver(None, {'firstName': 'Peter'}, None, None).value == 'Peter'
def test_field_resolve_vars(): def test_field_resolve_vars():
@ -216,7 +216,6 @@ def test_field_resolve_object():
att_func = field_func att_func = field_func
assert field.resolver(Root, {}, None) is True assert field.resolver(Root, {}, None) is True
assert field.resolver(Root, {}, None) is True
def test_field_resolve_source_object(): def test_field_resolve_source_object():
@ -235,4 +234,3 @@ def test_field_resolve_source_object():
att_func = field_func att_func = field_func
assert field.resolver(Root, {}, None) is True assert field.resolver(Root, {}, None) is True
assert field.resolver(Root, {}, None) is True

View File

@ -0,0 +1,6 @@
from .base import MiddlewareManager
from .camel_case import CamelCaseArgsMiddleware
__all__ = [
'MiddlewareManager', 'CamelCaseArgsMiddleware'
]

View File

@ -0,0 +1,23 @@
from ..utils import promise_middleware
MIDDLEWARE_RESOLVER_FUNCTION = 'resolve'
class MiddlewareManager(object):
def __init__(self, schema, middlewares=None):
self.schema = schema
self.middlewares = middlewares or []
def add_middleware(self, middleware):
self.middlewares.append(middleware)
def get_middleware_resolvers(self):
for middleware in self.middlewares:
if not hasattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION):
continue
yield getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION)
def wrap(self, resolver):
middleware_resolvers = self.get_middleware_resolvers()
return promise_middleware(resolver, middleware_resolvers)

View File

@ -0,0 +1,8 @@
from ..utils import ProxySnakeDict
class CamelCaseArgsMiddleware(object):
def resolve(self, next, root, args, context, info):
args = ProxySnakeDict(args)
return next(root, args, context, info)

View File

@ -1,6 +0,0 @@
from .base import Plugin, PluginManager
from .camel_case import CamelCase
__all__ = [
'Plugin', 'PluginManager', 'CamelCase'
]

View File

@ -1,53 +0,0 @@
from contextlib import contextmanager
from functools import partial, reduce
class Plugin(object):
def contribute_to_schema(self, schema):
self.schema = schema
def apply_function(a, b):
return b(a)
class PluginManager(object):
PLUGIN_FUNCTIONS = ('get_default_namedtype_name', )
def __init__(self, schema, plugins=[]):
self.schema = schema
self.plugins = []
for plugin in plugins:
self.add_plugin(plugin)
def add_plugin(self, plugin):
if hasattr(plugin, 'contribute_to_schema'):
plugin.contribute_to_schema(self.schema)
self.plugins.append(plugin)
def get_plugin_functions(self, function):
for plugin in self.plugins:
if not hasattr(plugin, function):
continue
yield getattr(plugin, function)
def __getattr__(self, name):
functions = self.get_plugin_functions(name)
return partial(reduce, apply_function, functions)
def __contains__(self, name):
return name in self.PLUGIN_FUNCTIONS
@contextmanager
def context_execution(self, **executor):
contexts = []
functions = self.get_plugin_functions('context_execution')
for f in functions:
context = f(executor)
executor = context.__enter__()
contexts.append((context, executor))
yield executor
for context, value in contexts[::-1]:
context.__exit__(None, None, None)

View File

@ -1,7 +0,0 @@
from ..utils import to_camel_case
class CamelCase(object):
def get_default_namedtype_name(self, value):
return to_camel_case(value)

View File

@ -3,6 +3,7 @@ from .proxy_snake_dict import ProxySnakeDict
from .caching import cached_property, memoize from .caching import cached_property, memoize
from .maybe_func import maybe_func from .maybe_func import maybe_func
from .misc import enum_to_graphql_enum from .misc import enum_to_graphql_enum
from .promise_middleware import promise_middleware
from .resolve_only_args import resolve_only_args from .resolve_only_args import resolve_only_args
from .lazylist import LazyList from .lazylist import LazyList
from .wrap_resolver_function import with_context, wrap_resolver_function from .wrap_resolver_function import with_context, wrap_resolver_function
@ -10,5 +11,5 @@ from .wrap_resolver_function import with_context, wrap_resolver_function
__all__ = ['to_camel_case', 'to_snake_case', 'to_const', 'ProxySnakeDict', __all__ = ['to_camel_case', 'to_snake_case', 'to_const', 'ProxySnakeDict',
'cached_property', 'memoize', 'maybe_func', 'enum_to_graphql_enum', 'cached_property', 'memoize', 'maybe_func', 'enum_to_graphql_enum',
'resolve_only_args', 'LazyList', 'with_context', 'promise_middleware', 'resolve_only_args', 'LazyList', 'with_context',
'wrap_resolver_function'] 'wrap_resolver_function']

View File

@ -0,0 +1,17 @@
from functools import partial
from itertools import chain
from promise import Promise
def promise_middleware(func, middlewares):
middlewares = chain((func, make_it_promise), middlewares)
past = None
for m in middlewares:
past = partial(m, past) if past else m
return past
def make_it_promise(next, *a, **b):
return Promise.resolve(next(*a, **b))