147 lines
5.2 KiB
Python
147 lines
5.2 KiB
Python
# =============================================================================
|
||
# 企微IT智能服务台 — 数据库连接与 Session 管理
|
||
# =============================================================================
|
||
# 说明:使用 SQLAlchemy 2.0 的异步引擎和会话管理,负责:
|
||
# 1. 创建异步数据库引擎(懒加载,支持 PostgreSQL 和 SQLite)
|
||
# 2. 创建异步会话工厂
|
||
# 3. 提供 get_db 依赖注入函数(FastAPI 路由中使用)
|
||
# 4. 自动建表(SQLite 本地开发时自动创建所有表)
|
||
# =============================================================================
|
||
|
||
import logging
|
||
from typing import AsyncGenerator, Optional
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.orm import DeclarativeBase
|
||
|
||
# 导入配置(读取数据库连接地址)
|
||
from app.config import settings
|
||
from app.utils.response import AppException
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 声明式基类(单独定义,不依赖引擎创建)
|
||
# ----------------------------------------------------------------------
|
||
# 所有模型类都继承自 Base,SQLAlchemy 通过它来检测所有模型定义
|
||
# Alembic 也通过 Base.metadata 来生成迁移脚本
|
||
# ----------------------------------------------------------------------
|
||
class Base(DeclarativeBase):
|
||
"""SQLAlchemy 声明式基类。
|
||
|
||
所有模型类都继承此类,SQLAlchemy 通过它管理所有表的元数据。
|
||
"""
|
||
|
||
pass
|
||
|
||
|
||
# ----------------------------------------------------------------------
|
||
# 懒加载引擎和会话工厂
|
||
# ----------------------------------------------------------------------
|
||
_engine: Optional[object] = None
|
||
_async_session_factory: Optional[async_sessionmaker] = None
|
||
_tables_created: bool = False # 标记是否已自动建表
|
||
|
||
|
||
def _is_sqlite() -> bool:
|
||
"""判断当前数据库 URL 是否为 SQLite。"""
|
||
return "sqlite" in settings.database_url.lower()
|
||
|
||
|
||
def _get_engine():
|
||
"""懒加载获取数据库引擎。
|
||
|
||
支持 PostgreSQL 和 SQLite 两种后端:
|
||
- PostgreSQL: 使用 asyncpg 异步驱动,带连接池
|
||
- SQLite: 使用 aiosqlite 异步驱动,无需连接池
|
||
"""
|
||
global _engine
|
||
if _engine is None:
|
||
db_url = settings.database_url
|
||
|
||
if _is_sqlite():
|
||
# SQLite 异步驱动:aiosqlite
|
||
# 不需要连接池,SQLite 是单文件数据库
|
||
_engine = create_async_engine(
|
||
db_url,
|
||
echo=False,
|
||
)
|
||
logger.info(f"使用 SQLite 数据库: {db_url}")
|
||
else:
|
||
# PostgreSQL 异步驱动:asyncpg
|
||
_engine = create_async_engine(
|
||
db_url.replace("postgresql://", "postgresql+asyncpg://"),
|
||
echo=False,
|
||
pool_size=5,
|
||
max_overflow=10,
|
||
pool_pre_ping=True,
|
||
)
|
||
logger.info(f"使用 PostgreSQL 数据库: {db_url.split('@')[-1]}")
|
||
return _engine
|
||
|
||
|
||
def _get_session_factory() -> async_sessionmaker:
|
||
"""懒加载获取会话工厂。"""
|
||
global _async_session_factory
|
||
if _async_session_factory is None:
|
||
_async_session_factory = async_sessionmaker(
|
||
_get_engine(),
|
||
class_=AsyncSession,
|
||
expire_on_commit=False,
|
||
)
|
||
return _async_session_factory
|
||
|
||
|
||
async def _ensure_tables():
|
||
"""自动建表(仅 SQLite 本地开发时使用)。
|
||
|
||
PostgreSQL 环境应使用 Alembic 迁移管理表结构。
|
||
SQLite 环境下直接用 metadata.create_all 创建所有表,省去迁移步骤。
|
||
"""
|
||
global _tables_created
|
||
if _tables_created:
|
||
return
|
||
_tables_created = True
|
||
|
||
if _is_sqlite():
|
||
# 导入所有模型,确保 Base.metadata 知道所有表
|
||
import app.models # noqa: F401
|
||
|
||
engine = _get_engine()
|
||
async with engine.begin() as conn:
|
||
await conn.run_sync(Base.metadata.create_all)
|
||
logger.info("SQLite 自动建表完成")
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||
"""获取数据库会话的依赖注入函数。
|
||
|
||
在 FastAPI 路由中通过 Depends(get_db) 注入数据库会话。
|
||
使用 async with 确保会话在使用后正确关闭。
|
||
使用 try/finally 确保异常时也能回滚和关闭。
|
||
|
||
Yields:
|
||
AsyncSession: 异步数据库会话
|
||
"""
|
||
# 首次调用时自动建表(SQLite)
|
||
await _ensure_tables()
|
||
|
||
# 创建一个新的数据库会话(懒加载会话工厂)
|
||
try:
|
||
session_factory = _get_session_factory()
|
||
except Exception as e:
|
||
logger.error(f"数据库连接失败(无法创建会话工厂): {e}")
|
||
raise AppException(1006, f"数据库连接失败: {str(e)}")
|
||
|
||
async with session_factory() as session:
|
||
try:
|
||
# 将会话交给路由函数使用
|
||
yield session
|
||
# 路由函数执行成功后提交事务
|
||
await session.commit()
|
||
except Exception:
|
||
# 路由函数执行失败时回滚事务
|
||
await session.rollback()
|
||
# 重新抛出异常,让 FastAPI 的异常处理器处理
|
||
raise
|