mirror of
https://github.com/graphql-python/graphene-django.git
synced 2025-02-22 14:30:33 +03:00
If DjangoDebugMiddleware is installed, calling `cursor.execute(b)` where `b` is a `bytes` object causes the recording (and thus the entire database call) to raise a `TypeError` because of line 126 of `graphene_django/debug/sql/tracking.py` (at commit 775644b536):
```
"is_select": sql.lower().strip().startswith("select"),
```
Calling `execute` with a `bytes` argument is, to my knowledge, not currently done within the high-level abstractions of the Django ORM, but it is fully supported by psycopg2, as shown by psycopg2's own `execute_values` (https://github.com/psycopg/psycopg2/blob/2_9_3/lib/extras.py#L1270):
```
cur.execute(b''.join(parts))
```
This fix ensures that the `sql` argument is safely decoded before checking whether it begins with `SELECT`; since this is the only place it is used that way, the change is trivial.
The only workaround for code that calls `execute_values` is to disable the DjangoDebugMiddleware altogether, which is far from ideal.
170 lines · 4.9 KiB · Python
# Code obtained from django-debug-toolbar sql panel tracking
|
|
|
|
import json
|
|
from threading import local
|
|
from time import time
|
|
|
|
from django.utils.encoding import force_str
|
|
|
|
from .types import DjangoDebugSQL
|
|
|
|
|
|
class SQLQueryTriggered(Exception):
    """Signal that a cursor was used while SQL recording is switched off.

    ``ExceptionCursorWrapper`` raises this on any attribute access. Upstream
    (django-debug-toolbar) it flagged queries triggered from the Templates panel.
    """
|
|
|
|
|
|
class ThreadLocalState(local):
    """Per-thread switch choosing the cursor wrapper class used by ``wrap_cursor``.

    While ``enabled`` is truthy, ``Wrapper`` is :class:`NormalCursorWrapper`,
    which logs queries. Otherwise it is :class:`ExceptionCursorWrapper`, which
    raises on any use.
    """

    def __init__(self):
        # ``threading.local`` runs __init__ again the first time each thread
        # touches the instance, so every thread starts with recording enabled.
        self.enabled = True

    @property
    def Wrapper(self):
        """Cursor wrapper class matching this thread's ``enabled`` flag."""
        return NormalCursorWrapper if self.enabled else ExceptionCursorWrapper

    def recording(self, v):
        """Store *v* as this thread's ``enabled`` flag (truthy = record SQL)."""
        self.enabled = v


# Shared per-thread state; its bound ``recording`` method is re-exported
# as a module-level function.
state = ThreadLocalState()
recording = state.recording  # export function
|
|
|
|
|
|
def wrap_cursor(connection, panel):
    """Swap ``connection.cursor`` for a factory that returns wrapped cursors.

    The first call stashes the original factory on the connection as
    ``_graphene_cursor``. Later calls reuse that stash, so the wrapped object is
    always a cursor from the original factory. *panel* is passed to the
    wrapper class as its ``logger`` argument.

    Returns the newly installed factory.
    """
    already_wrapped = hasattr(connection, "_graphene_cursor")
    if not already_wrapped:
        connection._graphene_cursor = connection.cursor

    def cursor():
        # Look the wrapper class up on every call: it depends on the calling
        # thread's recording flag.
        wrapper_cls = state.Wrapper
        return wrapper_cls(connection._graphene_cursor(), connection, panel)

    connection.cursor = cursor
    return cursor
|
|
|
|
|
|
def unwrap_cursor(connection):
    """Undo ``wrap_cursor``: restore the stashed cursor factory and drop the stash.

    Connections that were never wrapped are left untouched.
    """
    if not hasattr(connection, "_graphene_cursor"):
        return
    connection.cursor = connection._graphene_cursor
    del connection._graphene_cursor
|
|
|
|
|
|
class ExceptionCursorWrapper:
    """
    Cursor stand-in that raises on any use.

    ``ThreadLocalState.Wrapper`` hands out this class while recording is
    disabled. The constructor drops the objects it is given. Any attribute not
    found through normal lookup (``execute``, ``fetchone``, ...) raises
    :class:`SQLQueryTriggered`. Upstream it served the Templates panel.
    """

    def __init__(self, cursor, db, logger):
        # Nothing is stored; the wrapped cursor, db and logger are dropped.
        return None

    def __getattr__(self, attr):
        # Only reached for attributes missing from the instance/class,
        # i.e. every cursor API a caller might try.
        raise SQLQueryTriggered()
|
|
|
|
|
|
class NormalCursorWrapper:
    """
    Wraps a cursor and logs queries.

    ``execute``, ``executemany`` and ``callproc`` are timed. Each call appends a
    ``DjangoDebugSQL`` entry to ``logger.object.sql``, even when the call raises.
    Every other attribute is delegated to the wrapped cursor.
    """

    def __init__(self, cursor, db, logger):
        self.cursor = cursor
        # Instance of a BaseDatabaseWrapper subclass
        self.db = db
        # Receives recorded queries. It must expose ``object.sql`` (a list that
        # entries are appended to) and, for PostgreSQL, ``get_transaction_id(alias)``.
        self.logger = logger

    def _quote_expr(self, element):
        """Render a single parameter as an SQL-style literal for display."""
        if isinstance(element, str):
            # SQL escaping: double any embedded single quote.
            return "'%s'" % force_str(element).replace("'", "''")
        else:
            return repr(element)

    def _quote_params(self, params):
        """Quote each parameter; dicts keep their keys and quote their values.

        Falsy *params* (``None``, empty sequence/dict) are returned unchanged.
        """
        if not params:
            return params
        if isinstance(params, dict):
            return {key: self._quote_expr(value) for key, value in params.items()}
        return list(map(self._quote_expr, params))

    def _decode(self, param):
        """Coerce *param* to text for JSON output.

        With ``strings_only=True`` Django leaves "protected" non-string types
        (``None``, numbers, dates) as they are. Undecodable bytes become a
        placeholder.
        """
        try:
            return force_str(param, strings_only=True)
        except UnicodeDecodeError:
            return "(encoded string)"

    def _serialize_params(self, params):
        """Return *params* as a JSON string, or ``""`` when that is impossible.

        Sequence parameters become a JSON array. Mapping parameters (named
        placeholders) become a JSON object of decoded values.
        """
        try:
            if isinstance(params, dict):
                decoded = {key: self._decode(value) for key, value in params.items()}
            else:
                decoded = [self._decode(value) for value in params]
            return json.dumps(decoded)
        except Exception:
            # Best effort only: this runs while recording a query and must not
            # turn odd parameters into a failure. ``None`` (not iterable) and
            # non-JSON-serializable values end up here.
            return ""

    @staticmethod
    def _sql_to_str(sql):
        """Return *sql* as ``str``.

        psycopg2 accepts ``bytes`` statements (its ``execute_values`` helper
        sends them). Those are decoded as UTF-8, with undecodable bytes replaced
        by U+FFFD. Any other value is returned unchanged.
        """
        if isinstance(sql, bytes):
            return sql.decode("utf-8", errors="replace")
        return sql

    def _record(self, method, sql, params):
        """Call ``method(sql, params)``, then log it whether or not it raised."""
        start_time = time()
        try:
            return method(sql, params)
        finally:
            self._log_query(sql, params, start_time, time())

    def _log_query(self, sql, params, start_time, stop_time):
        """Build a ``DjangoDebugSQL`` entry for one statement and store it."""
        duration = stop_time - start_time  # seconds, from time.time()
        alias = getattr(self.db, "alias", "default")
        conn = self.db.connection
        # NOTE(review): ``self.db.connection`` is the raw driver connection,
        # which may not expose ``vendor`` (Django's wrapper ``self.db`` does).
        # If so, vendor is always "unknown" and the PostgreSQL branch below
        # never runs. Confirm before changing: that branch calls
        # ``logger.get_transaction_id``.
        vendor = getattr(conn, "vendor", "unknown")
        # Normalize once so a bytes statement never reaches str-only code
        # (``.lower()``/``startswith``) or gets stored as a bytes object.
        sql_str = self._sql_to_str(sql)

        record = {
            "vendor": vendor,
            "alias": alias,
            "sql": self.db.ops.last_executed_query(
                self.cursor, sql_str, self._quote_params(params)
            ),
            "duration": duration,
            "raw_sql": sql_str,
            "params": self._serialize_params(params),
            "start_time": start_time,
            "stop_time": stop_time,
            "is_slow": duration > 10,
            "is_select": sql_str.lower().strip().startswith("select"),
        }

        if vendor == "postgresql":
            # If an erroneous query was ran on the connection, it might
            # be in a state where checking isolation_level raises an
            # exception.
            try:
                iso_level = conn.isolation_level
            except conn.InternalError:
                iso_level = "unknown"
            record.update(
                {
                    "trans_id": self.logger.get_transaction_id(alias),
                    "trans_status": conn.get_transaction_status(),
                    "iso_level": iso_level,
                    "encoding": conn.encoding,
                }
            )

        # We keep `sql` to maintain backwards compatibility
        self.logger.object.sql.append(DjangoDebugSQL(**record))

    def callproc(self, procname, params=None):
        return self._record(self.cursor.callproc, procname, params)

    def execute(self, sql, params=None):
        return self._record(self.cursor.execute, sql, params)

    def executemany(self, sql, param_list):
        return self._record(self.cursor.executemany, sql, param_list)

    def __getattr__(self, attr):
        # Delegate everything not defined here to the wrapped cursor.
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
|