diff --git a/graphene/contrib/django/debug/__init__.py b/graphene/contrib/django/debug/__init__.py new file mode 100644 index 00000000..4c76aeca --- /dev/null +++ b/graphene/contrib/django/debug/__init__.py @@ -0,0 +1,4 @@ +from .plugin import DjangoDebugPlugin +from .types import DjangoDebug + +__all__ = ['DjangoDebugPlugin', 'DjangoDebug'] diff --git a/graphene/contrib/django/debug/plugin.py b/graphene/contrib/django/debug/plugin.py new file mode 100644 index 00000000..4b6b1b5d --- /dev/null +++ b/graphene/contrib/django/debug/plugin.py @@ -0,0 +1,77 @@ +from contextlib import contextmanager + +from django.db import connections + +from ....core.types import Field +from ....plugins import Plugin +from .sql.tracking import unwrap_cursor, wrap_cursor +from .sql.types import DjangoDebugSQL +from .types import DjangoDebug + + +class WrappedRoot(object): + + def __init__(self, root): + self._recorded = [] + self._root = root + + def record(self, **log): + self._recorded.append(DjangoDebugSQL(**log)) + + def debug(self): + return DjangoDebug(sql=self._recorded) + + +class WrapRoot(object): + + @property + def _root(self): + return self._wrapped_root.root + + @_root.setter + def _root(self, value): + self._wrapped_root = value + + def resolve_debug(self, args, info): + return self._wrapped_root.debug() + + +def debug_objecttype(objecttype): + return type( + 'Debug{}'.format(objecttype._meta.type_name), + (WrapRoot, objecttype), + {'debug': Field(DjangoDebug, name='__debug')}) + + +class DjangoDebugPlugin(Plugin): + + def transform_type(self, _type): + if _type == self.schema.query: + return + return _type + + def enable_instrumentation(self, wrapped_root): + # This is thread-safe because database connections are thread-local. + for connection in connections.all(): + wrap_cursor(connection, wrapped_root) + + def disable_instrumentation(self): + for connection in connections.all(): + unwrap_cursor(connection) + + def wrap_schema(self, schema_type): + query = schema_type._query + if query: + class_type = self.schema.objecttype(schema_type._query) + assert class_type, 'The query in schema is not constructed with graphene' + _type = debug_objecttype(class_type) + schema_type._query = self.schema.T(_type) + return schema_type + + @contextmanager + def context_execution(self, executor): + executor['root'] = WrappedRoot(root=executor['root']) + executor['schema'] = self.wrap_schema(executor['schema']) + self.enable_instrumentation(executor['root']) + yield executor + self.disable_instrumentation() diff --git a/graphene/contrib/django/debug/sql/__init__.py b/graphene/contrib/django/debug/sql/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/debug/sql/tracking.py b/graphene/contrib/django/debug/sql/tracking.py new file mode 100644 index 00000000..8ed40492 --- /dev/null +++ b/graphene/contrib/django/debug/sql/tracking.py @@ -0,0 +1,165 @@ +# Code obtained from django-debug-toolbar sql panel tracking +from __future__ import absolute_import, unicode_literals + +import json +from threading import local +from time import time + +from django.utils import six +from django.utils.encoding import force_text + + +class SQLQueryTriggered(Exception): + """Thrown when template panel triggers a query""" + + +class ThreadLocalState(local): + + def __init__(self): + self.enabled = True + + @property + def Wrapper(self): + if self.enabled: + return NormalCursorWrapper + return ExceptionCursorWrapper + + def recording(self, v): + self.enabled = v + + +state = ThreadLocalState() +recording = state.recording # export function + + +def wrap_cursor(connection, panel): + if not hasattr(connection, '_djdt_cursor'): + connection._djdt_cursor = connection.cursor + + def cursor(): + return state.Wrapper(connection._djdt_cursor(), connection, panel) + + connection.cursor = cursor + return cursor + + +def unwrap_cursor(connection): + if hasattr(connection, '_djdt_cursor'): + del connection._djdt_cursor + del connection.cursor + + +class ExceptionCursorWrapper(object): + """ + Wraps a cursor and raises an exception on any operation. + Used in Templates panel. + """ + + def __init__(self, cursor, db, logger): + pass + + def __getattr__(self, attr): + raise SQLQueryTriggered() + + +class NormalCursorWrapper(object): + """ + Wraps a cursor and logs queries. + """ + + def __init__(self, cursor, db, logger): + self.cursor = cursor + # Instance of a BaseDatabaseWrapper subclass + self.db = db + # logger must implement a ``record`` method + self.logger = logger + + def _quote_expr(self, element): + if isinstance(element, six.string_types): + return "'%s'" % force_text(element).replace("'", "''") + else: + return repr(element) + + def _quote_params(self, params): + if not params: + return params + if isinstance(params, dict): + return dict((key, self._quote_expr(value)) + for key, value in params.items()) + return list(map(self._quote_expr, params)) + + def _decode(self, param): + try: + return force_text(param, strings_only=True) + except UnicodeDecodeError: + return '(encoded string)' + + def _record(self, method, sql, params): + start_time = time() + try: + return method(sql, params) + finally: + stop_time = time() + duration = (stop_time - start_time) + _params = '' + try: + _params = json.dumps(list(map(self._decode, params))) + except Exception: + pass # object not JSON serializable + + alias = getattr(self.db, 'alias', 'default') + conn = self.db.connection + vendor = getattr(conn, 'vendor', 'unknown') + + params = { + 'vendor': vendor, + 'alias': alias, + 'sql': self.db.ops.last_executed_query( + self.cursor, sql, self._quote_params(params)), + 'duration': duration, + 'raw_sql': sql, + 'params': _params, + 'start_time': start_time, + 'stop_time': stop_time, + 'is_slow': duration > 10, + 'is_select': sql.lower().strip().startswith('select'), + } + + if vendor == 'postgresql': + # If an erroneous query was ran on the connection, it might + # be in a state where checking isolation_level raises an + # exception. + try: + iso_level = conn.isolation_level + except conn.InternalError: + iso_level = 'unknown' + params.update({ + 'trans_id': self.logger.get_transaction_id(alias), + 'trans_status': conn.get_transaction_status(), + 'iso_level': iso_level, + 'encoding': conn.encoding, + }) + + # We keep `sql` to maintain backwards compatibility + self.logger.record(**params) + + def callproc(self, procname, params=()): + return self._record(self.cursor.callproc, procname, params) + + def execute(self, sql, params=()): + return self._record(self.cursor.execute, sql, params) + + def executemany(self, sql, param_list): + return self._record(self.cursor.executemany, sql, param_list) + + def __getattr__(self, attr): + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() diff --git a/graphene/contrib/django/debug/sql/types.py b/graphene/contrib/django/debug/sql/types.py new file mode 100644 index 00000000..7ee1894f --- /dev/null +++ b/graphene/contrib/django/debug/sql/types.py @@ -0,0 +1,19 @@ +from .....core import Float, ObjectType, String + + +class DjangoDebugSQL(ObjectType): + vendor = String() + alias = String() + sql = String() + duration = Float() + raw_sql = String() + params = String() + start_time = Float() + stop_time = Float() + is_slow = String() + is_select = String() + + trans_id = String() + trans_status = String() + iso_level = String() + encoding = String() diff --git a/graphene/contrib/django/debug/tests/__init__.py b/graphene/contrib/django/debug/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene/contrib/django/debug/tests/test_query.py b/graphene/contrib/django/debug/tests/test_query.py new file mode 100644 index 00000000..4df26e4f --- /dev/null +++ b/graphene/contrib/django/debug/tests/test_query.py @@ -0,0 +1,70 @@ +import pytest + +import graphene +from graphene.contrib.django import DjangoObjectType + +from ...tests.models import Reporter +from ..plugin import DjangoDebugPlugin + +# from examples.starwars_django.models import Character + +pytestmark = pytest.mark.django_db + + +def test_should_query_well(): + r1 = Reporter(last_name='ABA') + r1.save() + r2 = Reporter(last_name='Griffin') + r2.save() + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + all_reporters = ReporterType.List() + + def resolve_all_reporters(self, *args, **kwargs): + return Reporter.objects.all() + + def resolve_reporter(self, *args, **kwargs): + return Reporter.objects.first() + + query = ''' + query ReporterQuery { + reporter { + lastName + } + allReporters { + lastName + } + __debug { + sql { + rawSql + } + } + } + ''' + expected = { + 'reporter': { + 'lastName': 'ABA', + }, + 'allReporters': [{ + 'lastName': 'ABA', + }, { + 'lastName': 'Griffin', + }], + '__debug': { + 'sql': [{ + 'rawSql': str(Reporter.objects.order_by('pk')[:1].query) + }, { + 'rawSql': str(Reporter.objects.all().query) + }] + } + } + schema = graphene.Schema(query=Query, plugins=[DjangoDebugPlugin()]) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/django/debug/types.py b/graphene/contrib/django/debug/types.py new file mode 100644 index 00000000..bceb54b0 --- /dev/null +++ b/graphene/contrib/django/debug/types.py @@ -0,0 +1,7 @@ +from ....core.classtypes.objecttype import ObjectType +from ....core.types import Field +from .sql.types import DjangoDebugSQL + + +class DjangoDebug(ObjectType): + sql = Field(DjangoDebugSQL.List()) diff --git a/graphene/core/schema.py b/graphene/core/schema.py index 6c98371f..f8f26dc1 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -118,17 +118,10 @@ class Schema(object): def types(self): return self._types_names - def execute(self, request='', root=None, vars=None, - operation_name=None, **kwargs): - root = root or object() - return self.executor.execute( - self.schema, - request, - root=root, - args=vars, - operation_name=operation_name, - **kwargs - ) + def execute(self, request='', root=None, args=None, **kwargs): + kwargs = dict(kwargs, request=request, root=root, args=args, schema=self.schema) + with self.plugins.context_execution(**kwargs) as execute_kwargs: + return self.executor.execute(**execute_kwargs) def introspect(self): return self.execute(introspection_query).data diff --git a/graphene/core/types/base.py b/graphene/core/types/base.py index c9d26cf5..ec8c7b3b 100644 --- a/graphene/core/types/base.py +++ b/graphene/core/types/base.py @@ -129,6 +129,7 @@ class MountedType(FieldType, ArgumentType): class NamedType(InstanceType): + def __init__(self, name=None, default_name=None, *args, **kwargs): self.name = name self.default_name = None diff --git a/graphene/plugins/base.py b/graphene/plugins/base.py index 8ab5668a..2347beba 100644 --- a/graphene/plugins/base.py +++ b/graphene/plugins/base.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from functools import partial, reduce @@ -38,3 +39,15 @@ class PluginManager(object): def __contains__(self, name): return name in self.PLUGIN_FUNCTIONS + + @contextmanager + def context_execution(self, **executor): + contexts = [] + functions = self.get_plugin_functions('context_execution') + for f in functions: + context = f(executor) + executor = context.__enter__() + contexts.append((context, executor)) + yield executor + for context, value in contexts[::-1]: + context.__exit__(None, None, None)