saving info about user after /start

This commit is contained in:
kiriharu 2021-01-30 13:01:05 +03:00
parent b8162edff0
commit 36d8020dd2
8 changed files with 75 additions and 11 deletions

View File

@ -6,14 +6,14 @@ authors = ["kiriharu <kiriharu@yandex.ru>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8" python = "^3.8"
core = {path = "../core"}
aiogram = "^2.11.2" aiogram = "^2.11.2"
httpx = "^0.16.1" httpx = "^0.16.1"
python-whois = "^0.7.3" python-whois = "^0.7.3"
core = {path = "../core"}
aioinflux = "^0.9.0" aioinflux = "^0.9.0"
loguru = "^0.5.3" loguru = "^0.5.3"
whois-vu = "^0.3.0" whois-vu = "^0.3.0"
tortoise-orm = "^0.16.19" tortoise-orm = "^0.16.20"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]

View File

@ -1,6 +1,8 @@
from aiogram import Bot, Dispatcher, executor from aiogram import Bot, Dispatcher, executor
from aiogram.contrib.fsm_storage.memory import MemoryStorage from aiogram.contrib.fsm_storage.memory import MemoryStorage
from tgbot.middlewares import WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware from tgbot.middlewares import UserMiddleware, WriteCommandMetric, LoggingMiddleware, ThrottlingMiddleware
from tortoise import Tortoise
from loguru import logger
import config import config
import handlers import handlers
@ -9,13 +11,29 @@ telegram_bot = Bot(token=config.TELEGRAM_BOT_TOKEN)
dp = Dispatcher(telegram_bot, storage=storage) dp = Dispatcher(telegram_bot, storage=storage)
def on_startup(): async def database_init():
handlers.default.setup(dp) if config.MYSQL_HOST is not None:
dp.middleware.setup(WriteCommandMetric()) db_url = f"mysql://{config.MYSQL_USER}:{config.MYSQL_PASSWORD}@" \
dp.middleware.setup(LoggingMiddleware()) f"{config.MYSQL_HOST}:{config.MYSQL_PORT}/{config.MYSQL_DATABASE}"
dp.middleware.setup(ThrottlingMiddleware()) else:
db_url = "sqlite://db.sqlite3"
await Tortoise.init(
db_url=db_url,
modules={
'default': ['tgbot.models']
}
)
await Tortoise.generate_schemas()
logger.info("Tortoise inited!")
async def on_startup(disp: Dispatcher):
await database_init()
handlers.default.setup(disp)
disp.middleware.setup(WriteCommandMetric())
disp.middleware.setup(LoggingMiddleware())
disp.middleware.setup(ThrottlingMiddleware())
disp.middleware.setup(UserMiddleware())
if __name__ == '__main__': if __name__ == '__main__':
on_startup() executor.start_polling(dp, skip_updates=True, on_startup=on_startup)
executor.start_polling(dp, skip_updates=True)

View File

@ -13,3 +13,10 @@ INFLUX_DB = os.getenv("INFLUX_DB", None)
# Notifications # Notifications
NOTIFICATION_BOT_TOKEN = os.getenv("NOTIFICATION_BOT_TOKEN") NOTIFICATION_BOT_TOKEN = os.getenv("NOTIFICATION_BOT_TOKEN")
NOTIFICATION_USERS = os.getenv("NOTIFICATION_USERS", "").split(",") NOTIFICATION_USERS = os.getenv("NOTIFICATION_USERS", "").split(",")
# Mysql params
MYSQL_HOST = os.getenv("MYSQL_HOST", None) # if none, use sqlite db
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_PORT = os.getenv("MYSQL_PORT", 3306)
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "unicheckbot")

View File

@ -1,5 +1,7 @@
from aiogram.types import Message from aiogram.types import Message
from tgbot.models.user import User
from tgbot.middlewares.throttling import rate_limit from tgbot.middlewares.throttling import rate_limit
from tgbot.middlewares.userdata import userdata_required
start_message = """ start_message = """
@ -30,6 +32,7 @@ start_message = """
""" """
@userdata_required
@rate_limit @rate_limit
async def start_cmd(msg: Message): async def start_cmd(msg: Message, user: User):
await msg.answer(start_message.replace("%name%", msg.from_user.full_name), parse_mode='markdown', disable_web_page_preview=True) await msg.answer(start_message.replace("%name%", msg.from_user.full_name), parse_mode='markdown', disable_web_page_preview=True)

View File

@ -1,3 +1,4 @@
from tgbot.middlewares.write_command_metric import WriteCommandMetric from tgbot.middlewares.write_command_metric import WriteCommandMetric
from tgbot.middlewares.logging import LoggingMiddleware from tgbot.middlewares.logging import LoggingMiddleware
from tgbot.middlewares.throttling import ThrottlingMiddleware from tgbot.middlewares.throttling import ThrottlingMiddleware
from tgbot.middlewares.userdata import UserMiddleware

View File

@ -0,0 +1,33 @@
from aiogram.dispatcher.handler import current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.types import Message, CallbackQuery
from tgbot.models import User
def userdata_required(func):
"""Setting login_required to function"""
setattr(func, 'userdata_required', True)
return func
class UserMiddleware(BaseMiddleware):
def __init__(self):
super(UserMiddleware, self).__init__()
@staticmethod
async def get_userdata(telegram_id: int) -> User:
handler = current_handler.get()
if handler:
attr = getattr(handler, 'userdata_required', False)
if attr:
# Setting user
user, _ = await User.get_or_create(telegram_id=telegram_id)
return user
async def on_process_message(self, message: Message, data: dict):
data['user'] = await self.get_userdata(message.from_user.id)
async def on_process_callback_query(self, callback_query: CallbackQuery, data: dict):
data['user'] = await self.get_userdata(callback_query.from_user.id)

View File

@ -0,0 +1 @@
from .user import *

View File

@ -5,3 +5,4 @@ class User(Model):
telegram_id = fields.IntField(pk=True) telegram_id = fields.IntField(pk=True)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)