saving info about user after /start

This commit is contained in:
kiriharu 2021-01-30 13:01:05 +03:00
parent b8162edff0
commit 36d8020dd2
8 changed files with 75 additions and 11 deletions

View File

@ -6,14 +6,14 @@ authors = ["kiriharu <kiriharu@yandex.ru>"]
[tool.poetry.dependencies]
python = "^3.8"
core = {path = "../core"}
aiogram = "^2.11.2"
httpx = "^0.16.1"
python-whois = "^0.7.3"
core = {path = "../core"}
aioinflux = "^0.9.0"
loguru = "^0.5.3"
whois-vu = "^0.3.0"
tortoise-orm = "^0.16.19"
tortoise-orm = "^0.16.20"
[tool.poetry.dev-dependencies]

View File

@ -1,6 +1,8 @@
from aiogram import Bot, Dispatcher, executor
from aiogram.contrib.fsm_storage.memory import MemoryStorage
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
from tgbot.middlewares import UserMiddleware, WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
from tortoise import Tortoise
from loguru import logger
import config
import handlers
@ -9,13 +11,29 @@ telegram_bot = Bot(token=config.TELEGRAM_BOT_TOKEN)
dp = Dispatcher(telegram_bot, storage=storage)
def on_startup():
handlers.default.setup(dp)
dp.middleware.setup(WriteCommandMetric())
dp.middleware.setup(LoggingMiddleware())
dp.middleware.setup(ThrottlingMiddleware())
async def database_init():
if config.MYSQL_HOST is not None:
db_url = f"mysql://{config.MYSQL_USER}:{config.MYSQL_PASSWORD}@" \
f"{config.MYSQL_HOST}:{config.MYSQL_PORT}/{config.MYSQL_DATABASE}"
else:
db_url = "sqlite://db.sqlite3"
await Tortoise.init(
db_url=db_url,
modules={
'default': ['tgbot.models']
}
)
await Tortoise.generate_schemas()
logger.info("Tortoise inited!")
async def on_startup(disp: Dispatcher):
await database_init()
handlers.default.setup(disp)
disp.middleware.setup(WriteCommandMetric())
disp.middleware.setup(LoggingMiddleware())
disp.middleware.setup(ThrottlingMiddleware())
disp.middleware.setup(UserMiddleware())
if __name__ == '__main__':
on_startup()
executor.start_polling(dp, skip_updates=True)
executor.start_polling(dp, skip_updates=True, on_startup=on_startup)

View File

@ -13,3 +13,10 @@ INFLUX_DB = os.getenv("INFLUX_DB", None)
# Notifications
NOTIFICATION_BOT_TOKEN = os.getenv("NOTIFICATION_BOT_TOKEN")
NOTIFICATION_USERS = os.getenv("NOTIFICATION_USERS", "").split(",")
# Mysql params
MYSQL_HOST = os.getenv("MYSQL_HOST", None) # if none, use sqlite db
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_PORT = os.getenv("MYSQL_PORT", 3306)
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "unicheckbot")

View File

@ -1,5 +1,7 @@
from aiogram.types import Message
from tgbot.models.user import User
from tgbot.middlewares.throttling import rate_limit
from tgbot.middlewares.userdata import userdata_required
start_message = """
@ -30,6 +32,7 @@ start_message = """
"""
@userdata_required
@rate_limit
async def start_cmd(msg: Message):
async def start_cmd(msg: Message, user: User):
await msg.answer(start_message.replace("%name%", msg.from_user.full_name), parse_mode='markdown', disable_web_page_preview=True)

View File

@ -1,3 +1,4 @@
from tgbot.middlewares.write_command_metric import WriteCommandMetric
from tgbot.middlewares.logging import LoggingMiddleware
from tgbot.middlewares.throttling import ThrottlingMiddleware
from tgbot.middlewares.userdata import UserMiddleware

View File

@ -0,0 +1,33 @@
from aiogram.dispatcher.handler import current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.types import Message, CallbackQuery
from tgbot.models import User
def userdata_required(func):
"""Setting login_required to function"""
setattr(func, 'userdata_required', True)
return func
class UserMiddleware(BaseMiddleware):
def __init__(self):
super(UserMiddleware, self).__init__()
@staticmethod
async def get_userdata(telegram_id: int) -> User:
handler = current_handler.get()
if handler:
attr = getattr(handler, 'userdata_required', False)
if attr:
# Setting user
user, _ = await User.get_or_create(telegram_id=telegram_id)
return user
async def on_process_message(self, message: Message, data: dict):
data['user'] = await self.get_userdata(message.from_user.id)
async def on_process_callback_query(self, callback_query: CallbackQuery, data: dict):
data['user'] = await self.get_userdata(callback_query.from_user.id)

View File

@ -0,0 +1 @@
from .user import *

View File

@ -5,3 +5,4 @@ class User(Model):
telegram_id = fields.IntField(pk=True)
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)