Add session_id support for clients

This commit is contained in:
sw 2022-05-29 17:50:01 +08:00
parent f1c9562260
commit 7a58546669
2 changed files with 51 additions and 5 deletions

View File

@ -13,6 +13,7 @@ import requests
from .models import ModelBase, MODEL from .models import ModelBase, MODEL
from .utils import parse_tsv, import_submodules from .utils import parse_tsv, import_submodules
from .query import Q from .query import Q
from .session import ctx_session_id, ctx_session_timeout
logger = logging.getLogger('clickhouse_orm') logger = logging.getLogger('clickhouse_orm')
@ -114,11 +115,13 @@ class Database(object):
self.request_session.auth = (username, password or '') self.request_session.auth = (username, password or '')
self.log_statements = log_statements self.log_statements = log_statements
self.settings = {} self.settings = {}
self.db_exists = False # this is required before running _is_existing_database self.db_exists = False # this is required before running _is_existing_database
self.db_exists = self._is_existing_database() self.db_exists = self._is_existing_database()
if readonly: if readonly:
if not self.db_exists: if not self.db_exists:
raise DatabaseException('Database does not exist, and cannot be created under readonly connection') raise DatabaseException(
'Database does not exist, and cannot be created under readonly connection'
)
self.connection_readonly = self._is_connection_readonly() self.connection_readonly = self._is_connection_readonly()
self.readonly = True self.readonly = True
elif autocreate and not self.db_exists: elif autocreate and not self.db_exists:
@ -379,6 +382,19 @@ class Database(object):
if int(name[:4]) >= up_to: if int(name[:4]) >= up_to:
break break
@property
def session_id(self):
"""return current client session_id"""
return ctx_session_id.get(None)
@property
def _context_params(self):
"""return context params"""
params = {}
if ctx_session_id.get(None):
params.update(session_id=self.session_id, session_timeout=ctx_session_timeout.get(60))
return params
def _get_applied_migrations(self, migrations_package_name): def _get_applied_migrations(self, migrations_package_name):
from .migrations import MigrationHistory from .migrations import MigrationHistory
self.create_table(MigrationHistory) self.create_table(MigrationHistory)
@ -392,7 +408,9 @@ class Database(object):
if self.log_statements: if self.log_statements:
logger.info(data) logger.info(data)
params = self._build_params(settings) params = self._build_params(settings)
r = self.request_session.post(self.db_url, params=params, data=data, stream=stream, timeout=self.timeout) r = self.request_session.post(
self.db_url, params=params, data=data, stream=stream, timeout=self.timeout
)
if r.status_code != 200: if r.status_code != 200:
raise ServerError(r.text) raise ServerError(r.text)
return r return r
@ -400,6 +418,7 @@ class Database(object):
def _build_params(self, settings): def _build_params(self, settings):
params = dict(settings or {}) params = dict(settings or {})
params.update(self.settings) params.update(self.settings)
params.update(self._context_params)
if self.db_exists: if self.db_exists:
params['database'] = self.db_name params['database'] = self.db_name
# Send the readonly flag, unless the connection is already readonly (to prevent db error) # Send the readonly flag, unless the connection is already readonly (to prevent db error)
@ -408,9 +427,9 @@ class Database(object):
return params return params
def _substitute(self, query, model_class=None): def _substitute(self, query, model_class=None):
''' """
Replaces $db and $table placeholders in the query. Replaces $db and $table placeholders in the query.
''' """
if '$' in query: if '$' in query:
mapping = dict(db="`%s`" % self.db_name) mapping = dict(db="`%s`" % self.db_name)
if model_class: if model_class:

View File

@ -0,0 +1,27 @@
import uuid
from typing import Optional
from contextvars import ContextVar
ctx_session_id: ContextVar[str] = ContextVar('ck.session_id')
ctx_session_timeout: ContextVar[int] = ContextVar('ck.session_timeout')
class SessionContext:
def __init__(self, session: str, timeout: int):
self.session = session
self.timeout = timeout
self.token1 = None
self.token2 = None
def __enter__(self):
self.token1 = ctx_session_id.set(self.session)
self.token2 = ctx_session_timeout.set(self.timeout)
def __exit__(self, exc_type, exc_val, exc_tb):
ctx_session_id.reset(self.token1)
ctx_session_timeout.reset(self.token2)
def in_session(session: Optional[str] = None, timeout: int = 60):
session = session or str(uuid.uuid4())
return SessionContext(session, timeout)