Move non-persistent stuff to base Session class

This commit is contained in:
Tulir Asokan 2018-03-02 11:10:11 +02:00
parent c5e6f7e265
commit 4c64d53e71
2 changed files with 87 additions and 119 deletions

View File

@ -1,10 +1,32 @@
from abc import ABC, abstractmethod
import time
import platform
class Session(ABC):
@abstractmethod
def __init__(self):
self._sequence = 0
self._last_msg_id = 0
self._time_offset = 0
system = platform.uname()
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0'
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self._flood_sleep_threshold = 60
def clone(self):
raise NotImplementedError
cloned = self.__class__()
cloned._device_model = self.device_model
cloned._system_version = self.system_version
cloned._app_version = self.app_version
cloned._lang_code = self.lang_code
cloned._system_lang_code = self.system_lang_code
cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold
@abstractmethod
def set_dc(self, dc_id, server_address, port):
@ -31,14 +53,12 @@ class Session(ABC):
raise NotImplementedError
@property
@abstractmethod
def time_offset(self):
raise NotImplementedError
return self._time_offset
@time_offset.setter
@abstractmethod
def time_offset(self, value):
raise NotImplementedError
self._time_offset = value
@property
@abstractmethod
@ -50,46 +70,6 @@ class Session(ABC):
def salt(self, value):
raise NotImplementedError
@property
@abstractmethod
def device_model(self):
raise NotImplementedError
@property
@abstractmethod
def system_version(self):
raise NotImplementedError
@property
@abstractmethod
def app_version(self):
raise NotImplementedError
@property
@abstractmethod
def lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def system_lang_code(self):
raise NotImplementedError
@property
@abstractmethod
def report_errors(self):
raise NotImplementedError
@property
@abstractmethod
def sequence(self):
raise NotImplementedError
@property
@abstractmethod
def flood_sleep_threshold(self):
raise NotImplementedError
@abstractmethod
def close(self):
raise NotImplementedError
@ -107,18 +87,6 @@ class Session(ABC):
def list_sessions(cls):
raise NotImplementedError
@abstractmethod
def get_new_msg_id(self):
raise NotImplementedError
@abstractmethod
def update_time_offset(self, correct_msg_id):
raise NotImplementedError
@abstractmethod
def generate_sequence(self, content_related):
raise NotImplementedError
@abstractmethod
def process_entities(self, tlo):
raise NotImplementedError
@ -134,3 +102,63 @@ class Session(ABC):
@abstractmethod
def get_file(self, md5_digest, file_size, cls):
raise NotImplementedError
@property
def device_model(self):
return self._device_model
@property
def system_version(self):
return self._system_version
@property
def app_version(self):
return self._app_version
@property
def lang_code(self):
return self._lang_code
@property
def system_lang_code(self):
return self._system_lang_code
@property
def report_errors(self):
return self._report_errors
@property
def flood_sleep_threshold(self):
return self._flood_sleep_threshold
@property
def sequence(self):
return self._sequence
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
now = time.time() + self._time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
now = int(time.time())
correct = correct_msg_id >> 32
self._time_offset = correct - now
self._last_msg_id = 0
def generate_sequence(self, content_related):
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2

View File

@ -1,6 +1,4 @@
from enum import Enum
import time
import platform
from .. import utils
from .abstract import Session
@ -29,38 +27,16 @@ class _SentFileType(Enum):
class MemorySession(Session):
def __init__(self):
super().__init__()
self._dc_id = None
self._server_address = None
self._port = None
self._salt = None
self._auth_key = None
self._sequence = 0
self._last_msg_id = 0
self._time_offset = 0
self._flood_sleep_threshold = 60
system = platform.uname()
self._device_model = system.system or 'Unknown'
self._system_version = system.release or '1.0'
self._app_version = '1.0'
self._lang_code = 'en'
self._system_lang_code = self.lang_code
self._report_errors = True
self._flood_sleep_threshold = 60
self._files = {}
self._entities = set()
def clone(self):
cloned = MemorySession()
cloned._device_model = self.device_model
cloned._system_version = self.system_version
cloned._app_version = self.app_version
cloned._lang_code = self.lang_code
cloned._system_lang_code = self.system_lang_code
cloned._report_errors = self.report_errors
cloned._flood_sleep_threshold = self.flood_sleep_threshold
def set_dc(self, dc_id, server_address, port):
self._dc_id = dc_id
self._server_address = server_address
@ -82,14 +58,6 @@ class MemorySession(Session):
def auth_key(self, value):
self._auth_key = value
@property
def time_offset(self):
return self._time_offset
@time_offset.setter
def time_offset(self, value):
self._time_offset = value
@property
def salt(self):
return self._salt
@ -143,34 +111,6 @@ class MemorySession(Session):
def list_sessions(cls):
raise NotImplementedError
def get_new_msg_id(self):
"""Generates a new unique message ID based on the current
time (in ms) since epoch"""
now = time.time() + self._time_offset
nanoseconds = int((now - int(now)) * 1e+9)
new_msg_id = (int(now) << 32) | (nanoseconds << 2)
if self._last_msg_id >= new_msg_id:
new_msg_id = self._last_msg_id + 4
self._last_msg_id = new_msg_id
return new_msg_id
def update_time_offset(self, correct_msg_id):
now = int(time.time())
correct = correct_msg_id >> 32
self._time_offset = correct - now
self._last_msg_id = 0
def generate_sequence(self, content_related):
if content_related:
result = self._sequence * 2 + 1
self._sequence += 1
return result
else:
return self._sequence * 2
@staticmethod
def _entities_to_rows(tlo):
if not isinstance(tlo, TLObject) and utils.is_list_like(tlo):