初始化
This commit is contained in:
2
core/__init__.py
Normal file
2
core/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""核心模块"""
|
||||
|
||||
298
core/database.py
Normal file
298
core/database.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""SQLite数据库操作模块 - 使用标准库sqlite3"""
|
||||
import sqlite3
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pathlib import Path
|
||||
from config import DATABASE_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
"""数据库管理类"""
|
||||
|
||||
def __init__(self, db_path: str = DATABASE_PATH):
|
||||
"""初始化数据库连接
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._conn: Optional[sqlite3.Connection] = None
|
||||
self._ensure_db_exists()
|
||||
self.init_tables()
|
||||
|
||||
def _ensure_db_exists(self):
|
||||
"""确保数据库目录存在"""
|
||||
db_dir = Path(self.db_path).parent
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
|
||||
def conn(self) -> sqlite3.Connection:
|
||||
"""获取数据库连接(懒加载)"""
|
||||
if self._conn is None:
|
||||
self._conn = sqlite3.connect(
|
||||
self.db_path,
|
||||
check_same_thread=False, # 允许多线程访问
|
||||
isolation_level=None # 自动提交
|
||||
)
|
||||
self._conn.row_factory = sqlite3.Row # 支持字典式访问
|
||||
return self._conn
|
||||
|
||||
def init_tables(self):
|
||||
"""初始化数据库表"""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
# 用户表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
username TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# 游戏状态表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS game_states (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
chat_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
game_type TEXT NOT NULL,
|
||||
state_data TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL,
|
||||
UNIQUE(chat_id, user_id, game_type)
|
||||
)
|
||||
""")
|
||||
|
||||
# 创建索引
|
||||
cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_user
|
||||
ON game_states(chat_id, user_id)
|
||||
""")
|
||||
|
||||
# 游戏统计表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS game_stats (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
game_type TEXT NOT NULL,
|
||||
wins INTEGER DEFAULT 0,
|
||||
losses INTEGER DEFAULT 0,
|
||||
draws INTEGER DEFAULT 0,
|
||||
total_plays INTEGER DEFAULT 0,
|
||||
UNIQUE(user_id, game_type)
|
||||
)
|
||||
""")
|
||||
|
||||
logger.info("数据库表初始化完成")
|
||||
|
||||
# ===== 用户相关操作 =====
|
||||
|
||||
def get_or_create_user(self, user_id: int, username: str = None) -> Dict:
|
||||
"""获取或创建用户
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
用户信息字典
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
current_time = int(time.time())
|
||||
|
||||
# 尝试获取用户
|
||||
cursor.execute(
|
||||
"SELECT * FROM users WHERE user_id = ?",
|
||||
(user_id,)
|
||||
)
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user:
|
||||
# 更新最后活跃时间
|
||||
cursor.execute(
|
||||
"UPDATE users SET last_active = ? WHERE user_id = ?",
|
||||
(current_time, user_id)
|
||||
)
|
||||
return dict(user)
|
||||
else:
|
||||
# 创建新用户
|
||||
cursor.execute(
|
||||
"INSERT INTO users (user_id, username, created_at, last_active) VALUES (?, ?, ?, ?)",
|
||||
(user_id, username, current_time, current_time)
|
||||
)
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'created_at': current_time,
|
||||
'last_active': current_time
|
||||
}
|
||||
|
||||
# ===== 游戏状态相关操作 =====
|
||||
|
||||
def get_game_state(self, chat_id: int, user_id: int, game_type: str) -> Optional[Dict]:
|
||||
"""获取游戏状态
|
||||
|
||||
Args:
|
||||
chat_id: 会话ID
|
||||
user_id: 用户ID
|
||||
game_type: 游戏类型
|
||||
|
||||
Returns:
|
||||
游戏状态字典,如果不存在返回None
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM game_states WHERE chat_id = ? AND user_id = ? AND game_type = ?",
|
||||
(chat_id, user_id, game_type)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
state = dict(row)
|
||||
# 解析JSON数据
|
||||
if state.get('state_data'):
|
||||
state['state_data'] = json.loads(state['state_data'])
|
||||
return state
|
||||
return None
|
||||
|
||||
def save_game_state(self, chat_id: int, user_id: int, game_type: str, state_data: Dict):
|
||||
"""保存游戏状态
|
||||
|
||||
Args:
|
||||
chat_id: 会话ID
|
||||
user_id: 用户ID
|
||||
game_type: 游戏类型
|
||||
state_data: 状态数据字典
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
current_time = int(time.time())
|
||||
state_json = json.dumps(state_data, ensure_ascii=False)
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO game_states (chat_id, user_id, game_type, state_data, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(chat_id, user_id, game_type)
|
||||
DO UPDATE SET state_data = ?, updated_at = ?
|
||||
""", (chat_id, user_id, game_type, state_json, current_time, current_time,
|
||||
state_json, current_time))
|
||||
|
||||
logger.debug(f"保存游戏状态: chat_id={chat_id}, user_id={user_id}, game_type={game_type}")
|
||||
|
||||
def delete_game_state(self, chat_id: int, user_id: int, game_type: str):
|
||||
"""删除游戏状态
|
||||
|
||||
Args:
|
||||
chat_id: 会话ID
|
||||
user_id: 用户ID
|
||||
game_type: 游戏类型
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM game_states WHERE chat_id = ? AND user_id = ? AND game_type = ?",
|
||||
(chat_id, user_id, game_type)
|
||||
)
|
||||
logger.debug(f"删除游戏状态: chat_id={chat_id}, user_id={user_id}, game_type={game_type}")
|
||||
|
||||
def cleanup_old_sessions(self, timeout: int = 1800):
|
||||
"""清理过期的游戏会话
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
cutoff_time = int(time.time()) - timeout
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM game_states WHERE updated_at < ?",
|
||||
(cutoff_time,)
|
||||
)
|
||||
deleted = cursor.rowcount
|
||||
|
||||
if deleted > 0:
|
||||
logger.info(f"清理了 {deleted} 个过期游戏会话")
|
||||
|
||||
# ===== 游戏统计相关操作 =====
|
||||
|
||||
def get_game_stats(self, user_id: int, game_type: str) -> Dict:
|
||||
"""获取游戏统计
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
game_type: 游戏类型
|
||||
|
||||
Returns:
|
||||
统计数据字典
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT * FROM game_stats WHERE user_id = ? AND game_type = ?",
|
||||
(user_id, game_type)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return dict(row)
|
||||
else:
|
||||
# 返回默认值
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'game_type': game_type,
|
||||
'wins': 0,
|
||||
'losses': 0,
|
||||
'draws': 0,
|
||||
'total_plays': 0
|
||||
}
|
||||
|
||||
def update_game_stats(self, user_id: int, game_type: str,
|
||||
win: bool = False, loss: bool = False, draw: bool = False):
|
||||
"""更新游戏统计
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
game_type: 游戏类型
|
||||
win: 是否获胜
|
||||
loss: 是否失败
|
||||
draw: 是否平局
|
||||
"""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
# 使用UPSERT语法
|
||||
cursor.execute("""
|
||||
INSERT INTO game_stats (user_id, game_type, wins, losses, draws, total_plays)
|
||||
VALUES (?, ?, ?, ?, ?, 1)
|
||||
ON CONFLICT(user_id, game_type)
|
||||
DO UPDATE SET
|
||||
wins = wins + ?,
|
||||
losses = losses + ?,
|
||||
draws = draws + ?,
|
||||
total_plays = total_plays + 1
|
||||
""", (user_id, game_type, int(win), int(loss), int(draw),
|
||||
int(win), int(loss), int(draw)))
|
||||
|
||||
logger.debug(f"更新游戏统计: user_id={user_id}, game_type={game_type}")
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
logger.info("数据库连接已关闭")
|
||||
|
||||
|
||||
# 全局数据库实例
|
||||
_db_instance: Optional[Database] = None
|
||||
|
||||
|
||||
def get_db() -> Database:
|
||||
"""获取全局数据库实例(单例模式)"""
|
||||
global _db_instance
|
||||
if _db_instance is None:
|
||||
_db_instance = Database()
|
||||
return _db_instance
|
||||
|
||||
34
core/middleware.py
Normal file
34
core/middleware.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""中间件模块"""
|
||||
import asyncio
|
||||
import logging
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from config import MAX_CONCURRENT_REQUESTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""并发限制中间件 - 防止内存爆炸"""
|
||||
|
||||
def __init__(self, app, max_concurrent: int = MAX_CONCURRENT_REQUESTS):
|
||||
super().__init__(app)
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.max_concurrent = max_concurrent
|
||||
logger.info(f"并发限制中间件已启用,最大并发数:{max_concurrent}")
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""处理请求"""
|
||||
async with self.semaphore:
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"请求处理错误: {e}", exc_info=True)
|
||||
return Response(
|
||||
content='{"error": "Internal Server Error"}',
|
||||
status_code=500,
|
||||
media_type="application/json"
|
||||
)
|
||||
|
||||
78
core/models.py
Normal file
78
core/models.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""数据模型定义"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class CallbackRequest(BaseModel):
|
||||
"""WPS Callback请求模型"""
|
||||
chatid: int = Field(..., description="会话ID")
|
||||
creator: int = Field(..., description="发送者ID")
|
||||
content: str = Field(..., description="消息内容")
|
||||
reply: Optional[Dict[str, Any]] = Field(None, description="回复内容")
|
||||
robot_key: str = Field(..., description="机器人key")
|
||||
url: str = Field(..., description="callback地址")
|
||||
ctime: int = Field(..., description="发送时间")
|
||||
|
||||
|
||||
class TextMessage(BaseModel):
|
||||
"""文本消息"""
|
||||
msgtype: str = "text"
|
||||
text: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def create(cls, content: str):
|
||||
"""创建文本消息"""
|
||||
return cls(text={"content": content})
|
||||
|
||||
|
||||
class MarkdownMessage(BaseModel):
|
||||
"""Markdown消息"""
|
||||
msgtype: str = "markdown"
|
||||
markdown: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def create(cls, text: str):
|
||||
"""创建Markdown消息"""
|
||||
return cls(markdown={"text": text})
|
||||
|
||||
|
||||
class LinkMessage(BaseModel):
|
||||
"""链接消息"""
|
||||
msgtype: str = "link"
|
||||
link: Dict[str, str]
|
||||
|
||||
@classmethod
|
||||
def create(cls, title: str, text: str, message_url: str = "", btn_title: str = "查看详情"):
|
||||
"""创建链接消息"""
|
||||
return cls(link={
|
||||
"title": title,
|
||||
"text": text,
|
||||
"messageUrl": message_url,
|
||||
"btnTitle": btn_title
|
||||
})
|
||||
|
||||
|
||||
class GameState(BaseModel):
|
||||
"""游戏状态基类"""
|
||||
game_type: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class GuessGameState(GameState):
|
||||
"""猜数字游戏状态"""
|
||||
game_type: str = "guess"
|
||||
target: int = Field(..., description="目标数字")
|
||||
attempts: int = Field(0, description="尝试次数")
|
||||
guesses: list[int] = Field(default_factory=list, description="历史猜测")
|
||||
max_attempts: int = Field(10, description="最大尝试次数")
|
||||
|
||||
|
||||
class QuizGameState(GameState):
|
||||
"""问答游戏状态"""
|
||||
game_type: str = "quiz"
|
||||
question_id: int = Field(..., description="问题ID")
|
||||
question: str = Field(..., description="问题内容")
|
||||
attempts: int = Field(0, description="尝试次数")
|
||||
max_attempts: int = Field(3, description="最大尝试次数")
|
||||
|
||||
Reference in New Issue
Block a user