diff --git a/graphene/contrib/django/debug/__init__.py b/graphene/contrib/django/debug/__init__.py index 0636c3fa..4c76aeca 100644 --- a/graphene/contrib/django/debug/__init__.py +++ b/graphene/contrib/django/debug/__init__.py @@ -1,4 +1,4 @@ -from .schema import DebugSchema +from .plugin import DjangoDebugPlugin from .types import DjangoDebug -__all__ = ['DebugSchema', 'DjangoDebug'] +__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] diff --git a/graphene/contrib/django/debug/schema.py b/graphene/contrib/django/debug/plugin.py similarity index 74% rename from graphene/contrib/django/debug/schema.py rename to graphene/contrib/django/debug/plugin.py index e5e5f30b..5c21863f 100644 --- a/graphene/contrib/django/debug/schema.py +++ b/graphene/contrib/django/debug/plugin.py @@ -1,5 +1,7 @@ +from contextlib import contextmanager from django.db import connections +from ....plugins import Plugin from ....core.schema import Schema from ....core.types import Field from .sql.tracking import unwrap_cursor, wrap_cursor @@ -41,15 +43,11 @@ def debug_objecttype(objecttype): {'debug': Field(DjangoDebug, name='__debug')}) -class DebugSchema(Schema): - - @property - def query(self): - return self._query - - @query.setter - def query(self, value): - self._query = value and debug_objecttype(value) +class DjangoDebugPlugin(Plugin): + def transform_type(self, _type): + if _type == self.schema.query: + return debug_objecttype(_type) + return _type def enable_instrumentation(self, wrapped_root): # This is thread-safe because database connections are thread-local. @@ -60,9 +58,9 @@ class DebugSchema(Schema): for connection in connections.all(): unwrap_cursor(connection) - def execute(self, query, root=None, *args, **kwargs): - wrapped_root = WrappedRoot(root=root) - self.enable_instrumentation(wrapped_root) - result = super(DebugSchema, self).execute(query, wrapped_root, *args, **kwargs) + @contextmanager + def context_execution(self, executor): + executor['root'] = WrappedRoot(root=executor['root']) + self.enable_instrumentation(executor['root']) + yield executor self.disable_instrumentation() - return result diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py index 3b4de477..4df26e4f 100644 --- a/graphene/contrib/django/debug/tests/test_query.py +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -4,7 +4,7 @@ import graphene from graphene.contrib.django import DjangoObjectType from ...tests.models import Reporter -from ..schema import DebugSchema +from ..plugin import DjangoDebugPlugin # from examples.starwars_django.models import Character @@ -24,7 +24,7 @@ def test_should_query_well(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - all_reporters = ReporterType.List + all_reporters = ReporterType.List() def resolve_all_reporters(self, *args, **kwargs): return Reporter.objects.all() @@ -64,7 +64,7 @@ def test_should_query_well(): }] } } - schema = DebugSchema(query=Query) + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) result = schema.execute(query) assert not result.errors assert result.data == expected diff --git a/graphene/contrib/django/debug/types.py b/graphene/contrib/django/debug/types.py index d84cbb98..bceb54b0 100644 --- a/graphene/contrib/django/debug/types.py +++ b/graphene/contrib/django/debug/types.py @@ -4,4 +4,4 @@ from .sql.types import DjangoDebugSQL class DjangoDebug(ObjectType): - sql = Field(DjangoDebugSQL.List) + sql = Field(DjangoDebugSQL.List()) diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 4da067e4..e8995426 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -127,17 +127,25 @@ class Schema(object): def types(self): return self._types_names - def execute(self, request='', root=None, vars=None, - operation_name=None, **kwargs): - root = root or object() - return self.executor.execute( + def execute(self, request='', root=None, args=None, **kwargs): + executor = kwargs + executor['root'] = root + executor['args'] = args + contexts = [] + for plugin in self.plugins: + if not hasattr(plugin, 'context_execution'): + continue + context = plugin.context_execution(executor) + executor = context.__enter__() + contexts.append((context, executor)) + result = self.executor.execute( self.schema, request, - root=root, - args=vars, - operation_name=operation_name, - **kwargs + **executor ) + for context, value in contexts[::-1]: + context.__exit__(None, None, None) + return result def introspect(self): return self.execute(introspection_query).data