Files
wecom_it_smart_desk/backend/app/database.py
T

147 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — 数据库连接与 Session 管理
# =============================================================================
# 说明:使用 SQLAlchemy 2.0 的异步引擎和会话管理,负责:
# 1. 创建异步数据库引擎(懒加载,支持 PostgreSQL 和 SQLite
# 2. 创建异步会话工厂
# 3. 提供 get_db 依赖注入函数(FastAPI 路由中使用)
# 4. 自动建表(SQLite 本地开发时自动创建所有表)
# =============================================================================
import logging
from typing import AsyncGenerator, Optional
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
# 导入配置(读取数据库连接地址)
from app.config import settings
from app.utils.response import AppException
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# 声明式基类(单独定义,不依赖引擎创建)
# ----------------------------------------------------------------------
# 所有模型类都继承自 Base,SQLAlchemy 通过它来检测所有模型定义
# Alembic 也通过 Base.metadata 来生成迁移脚本
# ----------------------------------------------------------------------
class Base(DeclarativeBase):
"""SQLAlchemy 声明式基类。
所有模型类都继承此类,SQLAlchemy 通过它管理所有表的元数据。
"""
pass
# ----------------------------------------------------------------------
# 懒加载引擎和会话工厂
# ----------------------------------------------------------------------
_engine: Optional[object] = None
_async_session_factory: Optional[async_sessionmaker] = None
_tables_created: bool = False # 标记是否已自动建表
def _is_sqlite() -> bool:
"""判断当前数据库 URL 是否为 SQLite。"""
return "sqlite" in settings.database_url.lower()
def _get_engine():
"""懒加载获取数据库引擎。
支持 PostgreSQL 和 SQLite 两种后端:
- PostgreSQL: 使用 asyncpg 异步驱动,带连接池
- SQLite: 使用 aiosqlite 异步驱动,无需连接池
"""
global _engine
if _engine is None:
db_url = settings.database_url
if _is_sqlite():
# SQLite 异步驱动:aiosqlite
# 不需要连接池,SQLite 是单文件数据库
_engine = create_async_engine(
db_url,
echo=False,
)
logger.info(f"使用 SQLite 数据库: {db_url}")
else:
# PostgreSQL 异步驱动:asyncpg
_engine = create_async_engine(
db_url.replace("postgresql://", "postgresql+asyncpg://"),
echo=False,
pool_size=5,
max_overflow=10,
pool_pre_ping=True,
)
logger.info(f"使用 PostgreSQL 数据库: {db_url.split('@')[-1]}")
return _engine
def _get_session_factory() -> async_sessionmaker:
"""懒加载获取会话工厂。"""
global _async_session_factory
if _async_session_factory is None:
_async_session_factory = async_sessionmaker(
_get_engine(),
class_=AsyncSession,
expire_on_commit=False,
)
return _async_session_factory
async def _ensure_tables():
"""自动建表(仅 SQLite 本地开发时使用)。
PostgreSQL 环境应使用 Alembic 迁移管理表结构。
SQLite 环境下直接用 metadata.create_all 创建所有表,省去迁移步骤。
"""
global _tables_created
if _tables_created:
return
_tables_created = True
if _is_sqlite():
# 导入所有模型,确保 Base.metadata 知道所有表
import app.models # noqa: F401
engine = _get_engine()
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("SQLite 自动建表完成")
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话的依赖注入函数。
在 FastAPI 路由中通过 Depends(get_db) 注入数据库会话。
使用 async with 确保会话在使用后正确关闭。
使用 try/finally 确保异常时也能回滚和关闭。
Yields:
AsyncSession: 异步数据库会话
"""
# 首次调用时自动建表(SQLite
await _ensure_tables()
# 创建一个新的数据库会话(懒加载会话工厂)
try:
session_factory = _get_session_factory()
except Exception as e:
logger.error(f"数据库连接失败(无法创建会话工厂): {e}")
raise AppException(1006, f"数据库连接失败: {str(e)}")
async with session_factory() as session:
try:
# 将会话交给路由函数使用
yield session
# 路由函数执行成功后提交事务
await session.commit()
except Exception:
# 路由函数执行失败时回滚事务
await session.rollback()
# 重新抛出异常,让 FastAPI 的异常处理器处理
raise