643 lines
22 KiB
Python
643 lines
22 KiB
Python
|
|
# =============================================================================
|
||
|
|
# 企微IT智能服务台 — MFA 二次认证测试
|
||
|
|
# =============================================================================
|
||
|
|
# Phase 2.1 task #17: pyotp TOTP 服务 + User MFA 字段
|
||
|
|
# 覆盖:status / bind/start / bind/confirm / verify / disable / admin reset
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import io
|
||
|
|
|
||
|
|
import pyotp
|
||
|
|
import pytest
|
||
|
|
import pytest_asyncio
|
||
|
|
from sqlalchemy import select
|
||
|
|
|
||
|
|
from app.models.agent import Agent
|
||
|
|
from app.services.mfa_service import MFA_VERIFIED_TTL_SECONDS, MFAService
|
||
|
|
from app.utils.error_codes import ErrorCode
|
||
|
|
from tests.conftest import create_test_agent
|
||
|
|
|
||
|
|
|
||
|
|
# -----------------------------------------------------------------------------
|
||
|
|
# 辅助:获取真实 token(走 /agents/login,与生产路径一致)
|
||
|
|
# -----------------------------------------------------------------------------
|
||
|
|
async def _login_and_get_token(client, user_id: str, name: str, role: str = "agent") -> str:
    """Log in through ``/agents/login`` and return the issued token.

    Args:
        client: Async HTTP test client exposing an awaitable ``post``.
        user_id: Sent as ``user_id`` in the login payload.
        name: Sent as ``name`` in the login payload.
        role: Accepted for call-site compatibility; it is not sent with the
            login request.

    Returns:
        str: The value found at ``data.token`` in the JSON response body.

    Raises:
        AssertionError: If the HTTP status is not 200 or the business
            ``code`` field is not 0.
    """
    credentials = {"user_id": user_id, "name": name}
    login_resp = await client.post("/agents/login", json=credentials)
    assert login_resp.status_code == 200, f"登录失败: {login_resp.text}"

    payload = login_resp.json()
    assert payload.get("code") == 0, f"登录业务码非 0: {payload}"

    token = payload["data"]["token"]
    return token
|
||
|
|
|
||
|
|
|
||
|
|
def _bearer(token: str) -> dict:
    """Build the request headers carrying *token* as a Bearer credential.

    Args:
        token: Token string to place after the ``Bearer`` scheme.

    Returns:
        dict: A single-entry mapping with the ``Authorization`` header.
    """
    header_value = f"Bearer {token}"
    return dict(Authorization=header_value)
|
||
|
|
|
||
|
|
|
||
|
|
def _is_valid_png_base64(s: str) -> bool:
    """Return True when *s* is strict base64 whose payload starts with the PNG signature.

    Args:
        s: Candidate base64 text.

    Returns:
        bool: True if *s* decodes under ``validate=True`` and the first eight
        decoded bytes equal the PNG magic number; False for malformed base64,
        for input of an unsupported type, or for any other payload.
    """
    try:
        raw = base64.b64decode(s, validate=True)
    except (ValueError, TypeError):
        # binascii.Error (malformed base64) is a ValueError subclass;
        # TypeError covers input that is neither str nor bytes-like.
        return False
    # PNG magic bytes: 89 50 4E 47 0D 0A 1A 0A
    return raw[:8] == b"\x89PNG\r\n\x1a\n"
|
||
|
|
|
||
|
|
|
||
|
|
async def _seed_admin_role(db_session, employee_id: str, role_name: str = "admin") -> str:
    """Ensure *employee_id* is linked to the role named *role_name*.

    The role row is created when no role with that name exists, and the
    user/role link row is created when the pair is not linked yet. Both
    inserts are flushed, not committed. (Per the original author's note, the
    role mapping service resolves roles through the ``user_roles`` table.)

    Args:
        db_session: Async database session used for the lookups and inserts.
        employee_id: Identifier stored in ``UserRole.employee_id``.
        role_name: Role name to look up or create (admin / agent / user).

    Returns:
        str: The id of the role row.
    """
    import uuid as _uuid
    from datetime import datetime as _dt

    from app.models.role import Role
    from app.models.user_role import UserRole

    # Step 1: reuse the role with this name, creating it on first use.
    role_lookup = await db_session.execute(select(Role).where(Role.name == role_name))
    role = role_lookup.scalars().first()
    if not role:
        display_names = {"admin": "管理员", "agent": "坐席", "user": "员工"}
        role = Role(
            id=str(_uuid.uuid4()),
            name=role_name,
            display_name=display_names.get(role_name, role_name),
            is_default=(role_name == "user"),
            permissions=[],
        )
        db_session.add(role)
        await db_session.flush()

    # Step 2: link the employee to the role unless the link already exists.
    link_query = select(UserRole).where(
        UserRole.employee_id == employee_id,
        UserRole.role_id == role.id,
    )
    link_lookup = await db_session.execute(link_query)
    if link_lookup.scalars().first():
        return role.id

    db_session.add(
        UserRole(
            id=str(_uuid.uuid4()),
            employee_id=employee_id,
            role_id=role.id,
            source="manual",
            assigned_at=_dt.now(),
        )
    )
    await db_session.flush()
    return role.id
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 1. GET /mfa/status — 全新用户
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFAStatus:
    """Behavior of GET /mfa/status."""

    @pytest.mark.asyncio
    async def test_new_user_status_unbound(self, client, db_session):
        """A registered user with no MFA binding reports bound=false, enabled=false."""
        db_session.add(create_test_agent(user_id="alice_001", name="Alice"))
        await db_session.flush()

        auth_headers = _bearer(await _login_and_get_token(client, "alice_001", "Alice"))
        status_resp = await client.get("/mfa/status", headers=auth_headers)

        assert status_resp.status_code == 200
        payload = status_resp.json()
        assert payload["code"] == 0

        mfa_state = payload["data"]
        assert mfa_state["bound"] is False
        assert mfa_state["enabled"] is False
        assert mfa_state["last_verified_at"] is None
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 2. POST /mfa/bind/start — 生成 secret + 二维码
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFABindStart:
    """Behavior of POST /mfa/bind/start."""

    @pytest.mark.asyncio
    async def test_bind_start_returns_secret_and_qrcode(self, client, db_session):
        """bind/start returns a secret, an otpauth URL and a base64-encoded PNG."""
        agent = create_test_agent(user_id="bob_001", name="Bob")
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "bob_001", "Bob")
        resp = await client.post("/mfa/bind/start", headers=_bearer(token))

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 0
        data = body["data"]
        # All three artifacts must be present.
        assert "secret" in data
        assert "otpauth_url" in data
        assert "qr_code_base64" in data
        # The secret is expected to be 32 characters long.
        assert len(data["secret"]) == 32
        # The provisioning URL uses the otpauth TOTP scheme.
        assert data["otpauth_url"].startswith("otpauth://totp/")
        # The QR code must decode to PNG bytes.
        assert _is_valid_png_base64(data["qr_code_base64"])

    @pytest.mark.asyncio
    async def test_bind_start_writes_secret_to_db(self, client, db_session):
        """After bind/start the secret is stored while MFA stays disabled and unbound."""
        agent = create_test_agent(user_id="carol_001", name="Carol")
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "carol_001", "Carol")
        resp = await client.post("/mfa/bind/start", headers=_bearer(token))
        assert resp.status_code == 200
        secret_returned = resp.json()["data"]["secret"]

        # Re-query the agent row through the session instead of reusing ``agent``.
        stmt = select(Agent).where(Agent.user_id == "carol_001")
        result = await db_session.execute(stmt)
        db_agent = result.scalars().first()

        assert db_agent.mfa_secret == secret_returned
        assert db_agent.mfa_enabled is False
        assert db_agent.mfa_bound_at is None

    @pytest.mark.asyncio
    async def test_bind_start_when_already_enabled_rejected(self, client, db_session):
        """A user whose MFA is already enabled gets a business error from bind/start."""
        # Explicit function-scope import instead of a dynamic __import__ call.
        from datetime import datetime

        agent = create_test_agent(user_id="dave_001", name="Dave")
        agent.mfa_secret = pyotp.random_base32()
        agent.mfa_enabled = True
        agent.mfa_bound_at = datetime.now()
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "dave_001", "Dave")
        resp = await client.post("/mfa/bind/start", headers=_bearer(token))

        # The rejection is reported as HTTP 200 with a non-zero business code.
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] != 0
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 3. POST /mfa/bind/confirm — 用 OTP 完成绑定
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFABindConfirm:
    """Behavior of POST /mfa/bind/confirm."""

    @pytest.mark.asyncio
    async def test_bind_confirm_correct_code_enables(self, client, db_session):
        """A valid OTP turns mfa_enabled on and stamps mfa_bound_at."""
        from datetime import datetime

        pending = create_test_agent(user_id="eve_001", name="Eve")
        shared_secret = pyotp.random_base32()
        pending.mfa_secret = shared_secret
        pending.mfa_enabled = False
        db_session.add(pending)
        await db_session.flush()

        # Compute the currently valid code for the stored secret.
        current_code = pyotp.TOTP(shared_secret).now()

        auth = _bearer(await _login_and_get_token(client, "eve_001", "Eve"))
        confirm_resp = await client.post(
            "/mfa/bind/confirm",
            headers=auth,
            json={"otp_code": current_code},
        )

        assert confirm_resp.status_code == 200
        result = confirm_resp.json()
        assert result["code"] == 0
        assert result["data"]["success"] is True

        # Persisted state after confirmation.
        rows = await db_session.execute(select(Agent).where(Agent.user_id == "eve_001"))
        stored = rows.scalars().first()
        assert stored.mfa_enabled is True
        assert stored.mfa_bound_at is not None
        assert isinstance(stored.mfa_bound_at, datetime)

    @pytest.mark.asyncio
    async def test_bind_confirm_wrong_code_rejected(self, client, db_session):
        """A wrong OTP yields a business error and leaves the binding untouched."""
        pending = create_test_agent(user_id="frank_001", name="Frank")
        pending.mfa_secret = pyotp.random_base32()
        pending.mfa_enabled = False
        db_session.add(pending)
        await db_session.flush()

        auth = _bearer(await _login_and_get_token(client, "frank_001", "Frank"))
        # Submit a fixed six-digit code that is not derived from the secret.
        confirm_resp = await client.post(
            "/mfa/bind/confirm",
            headers=auth,
            json={"otp_code": "000000"},
        )

        assert confirm_resp.status_code == 200
        result = confirm_resp.json()
        assert result["code"] != 0

        # The stored flags must be unchanged.
        rows = await db_session.execute(select(Agent).where(Agent.user_id == "frank_001"))
        stored = rows.scalars().first()
        assert stored.mfa_enabled is False
        assert stored.mfa_bound_at is None

    @pytest.mark.asyncio
    async def test_bind_confirm_without_start_rejected(self, client, db_session):
        """Confirming without a prior bind/start (no stored secret) is rejected."""
        # The agent is stored without any mfa_secret.
        db_session.add(create_test_agent(user_id="grace_001", name="Grace"))
        await db_session.flush()

        auth = _bearer(await _login_and_get_token(client, "grace_001", "Grace"))
        confirm_resp = await client.post(
            "/mfa/bind/confirm",
            headers=auth,
            json={"otp_code": "123456"},
        )

        assert confirm_resp.status_code == 200
        result = confirm_resp.json()
        assert result["code"] != 0
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 4. POST /mfa/verify — 验证 + 写 Redis 30 分钟
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFAVerify:
    """Behavior of POST /mfa/verify."""

    @pytest.mark.asyncio
    async def test_verify_correct_code_writes_redis(self, client, db_session, mock_redis):
        """A correct code returns verified=True and stores a Redis marker with the TTL."""
        # Explicit function-scope import instead of a dynamic __import__ call.
        from datetime import datetime

        agent = create_test_agent(user_id="henry_001", name="Henry")
        secret = pyotp.random_base32()
        agent.mfa_secret = secret
        agent.mfa_enabled = True
        agent.mfa_bound_at = datetime.now()
        db_session.add(agent)
        await db_session.flush()

        otp_code = pyotp.TOTP(secret).now()

        token = await _login_and_get_token(client, "henry_001", "Henry")
        resp = await client.post(
            "/mfa/verify",
            headers=_bearer(token),
            json={"otp_code": otp_code},
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 0
        data = body["data"]
        assert data["verified"] is True
        assert data["expires_in"] == MFA_VERIFIED_TTL_SECONDS

        # The verification marker must exist in the fake Redis store.
        key = "mfa:verified:henry_001"
        assert key in mock_redis._data, (
            f"key {key} 不在 mock_redis._data 中: {list(mock_redis._data.keys())}"
        )
        assert mock_redis._data[key] == "1"
        assert mock_redis._ttl.get(key) == MFA_VERIFIED_TTL_SECONDS

    @pytest.mark.asyncio
    async def test_verify_wrong_code_returns_false(self, client, db_session, mock_redis):
        """A wrong code returns verified=False and writes nothing to Redis."""
        agent = create_test_agent(user_id="ivy_001", name="Ivy")
        secret = pyotp.random_base32()
        agent.mfa_secret = secret
        agent.mfa_enabled = True
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "ivy_001", "Ivy")
        resp = await client.post(
            "/mfa/verify",
            headers=_bearer(token),
            json={"otp_code": "000000"},
        )

        # A failed verification is still a successful call (business code 0).
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 0
        assert body["data"]["verified"] is False

        # No verification marker may be present.
        assert await mock_redis.exists("mfa:verified:ivy_001") == 0

    @pytest.mark.asyncio
    async def test_verify_when_not_bound_returns_false(self, client, db_session):
        """Verifying for a user without a stored secret returns verified=False."""
        agent = create_test_agent(user_id="jack_001", name="Jack")
        # No mfa_secret is set on this agent.
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "jack_001", "Jack")
        resp = await client.post(
            "/mfa/verify",
            headers=_bearer(token),
            json={"otp_code": "123456"},
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["data"]["verified"] is False
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 5. POST /mfa/disable — 用户关闭 MFA
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFADisable:
    """Behavior of POST /mfa/disable."""

    @pytest.mark.asyncio
    async def test_disable_clears_secret_after_otp(self, client, db_session):
        """A correct OTP clears mfa_secret, mfa_bound_at and turns mfa_enabled off."""
        # Explicit function-scope import instead of a dynamic __import__ call.
        from datetime import datetime

        agent = create_test_agent(user_id="karen_001", name="Karen")
        secret = pyotp.random_base32()
        agent.mfa_secret = secret
        agent.mfa_enabled = True
        agent.mfa_bound_at = datetime.now()
        db_session.add(agent)
        await db_session.flush()

        otp_code = pyotp.TOTP(secret).now()

        token = await _login_and_get_token(client, "karen_001", "Karen")
        resp = await client.post(
            "/mfa/disable",
            headers=_bearer(token),
            json={"otp_code": otp_code},
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 0
        assert body["data"]["success"] is True

        # Persisted state after disabling.
        stmt = select(Agent).where(Agent.user_id == "karen_001")
        db_agent = (await db_session.execute(stmt)).scalars().first()
        assert db_agent.mfa_secret is None
        assert db_agent.mfa_enabled is False
        assert db_agent.mfa_bound_at is None

    @pytest.mark.asyncio
    async def test_disable_wrong_otp_rejected(self, client, db_session):
        """A wrong OTP yields a business error and MFA stays enabled."""
        agent = create_test_agent(user_id="liam_001", name="Liam")
        secret = pyotp.random_base32()
        agent.mfa_secret = secret
        agent.mfa_enabled = True
        db_session.add(agent)
        await db_session.flush()

        token = await _login_and_get_token(client, "liam_001", "Liam")
        resp = await client.post(
            "/mfa/disable",
            headers=_bearer(token),
            json={"otp_code": "000000"},
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] != 0

        # The stored flag must be unchanged.
        stmt = select(Agent).where(Agent.user_id == "liam_001")
        db_agent = (await db_session.execute(stmt)).scalars().first()
        assert db_agent.mfa_enabled is True

    @pytest.mark.asyncio
    async def test_status_after_disable_is_unbound(self, client, db_session):
        """After a successful disable, GET /mfa/status reports bound=false, enabled=false."""
        # Explicit function-scope import instead of a dynamic __import__ call.
        from datetime import datetime

        agent = create_test_agent(user_id="mia_001", name="Mia")
        secret = pyotp.random_base32()
        agent.mfa_secret = secret
        agent.mfa_enabled = True
        agent.mfa_bound_at = datetime.now()
        db_session.add(agent)
        await db_session.flush()

        otp_code = pyotp.TOTP(secret).now()
        token = await _login_and_get_token(client, "mia_001", "Mia")

        # Disable first, and fail here (not at the status check) if it did not succeed.
        disable_resp = await client.post(
            "/mfa/disable",
            headers=_bearer(token),
            json={"otp_code": otp_code},
        )
        assert disable_resp.status_code == 200
        assert disable_resp.json()["code"] == 0

        # Then query the status.
        resp = await client.get("/mfa/status", headers=_bearer(token))
        assert resp.status_code == 200
        data = resp.json()["data"]
        assert data["bound"] is False
        assert data["enabled"] is False
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 6. POST /admin/mfa/reset/{employee_id} — 管理员重置
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFAAdminReset:
    """Behavior of POST /admin/mfa/reset/{employee_id}."""

    @pytest.mark.asyncio
    async def test_admin_reset_clears_target_user(self, client, db_session):
        """An admin reset clears the target's mfa_secret/mfa_bound_at and disables MFA."""
        # Explicit function-scope import instead of a dynamic __import__ call.
        from datetime import datetime

        # 1. Seed the target user with MFA already bound.
        target = create_test_agent(user_id="nina_001", name="Nina")
        target.mfa_secret = pyotp.random_base32()
        target.mfa_enabled = True
        target.mfa_bound_at = datetime.now()
        db_session.add(target)

        # 2. Seed the administrator and give it the admin role in user_roles.
        admin = create_test_agent(user_id="oliver_admin", name="Oliver")
        admin.role = "admin"
        db_session.add(admin)
        await db_session.flush()
        await _seed_admin_role(db_session, "oliver_admin", "admin")

        # 3. Log in as the administrator.
        admin_token = await _login_and_get_token(client, "oliver_admin", "Oliver")

        # 4. Call the admin reset endpoint for the target.
        resp = await client.post(
            "/admin/mfa/reset/nina_001",
            headers=_bearer(admin_token),
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 0
        assert body["data"]["success"] is True

        # 5. The target's MFA fields must be cleared.
        stmt = select(Agent).where(Agent.user_id == "nina_001")
        target_db = (await db_session.execute(stmt)).scalars().first()
        assert target_db.mfa_secret is None
        assert target_db.mfa_enabled is False
        assert target_db.mfa_bound_at is None

    @pytest.mark.asyncio
    async def test_admin_reset_by_non_admin_forbidden(self, client, db_session):
        """A non-admin caller is rejected with the FORBIDDEN business code (HTTP stays 200)."""
        # Seed the target user.
        target = create_test_agent(user_id="peter_001", name="Peter")
        target.mfa_secret = pyotp.random_base32()
        target.mfa_enabled = True
        db_session.add(target)

        # Seed an ordinary agent; no admin role is assigned to it.
        normal = create_test_agent(user_id="quinn_agent", name="Quinn")
        db_session.add(normal)
        await db_session.flush()

        normal_token = await _login_and_get_token(client, "quinn_agent", "Quinn")

        resp = await client.post(
            "/admin/mfa/reset/peter_001",
            headers=_bearer(normal_token),
        )

        # The rejection is expected as HTTP 200 carrying a business error code.
        assert resp.status_code == 200, (
            f"预期 200(被全局处理器统一),实际 {resp.status_code}: {resp.text}"
        )
        body = resp.json()
        assert body["code"] == ErrorCode.FORBIDDEN.value, (
            f"预期 FORBIDDEN 业务码 {ErrorCode.FORBIDDEN.value},"
            f"实际 {body['code']}: {body}"
        )

    @pytest.mark.asyncio
    async def test_admin_reset_nonexistent_user_404(self, client, db_session):
        """Resetting an unknown user returns HTTP 200 with a non-zero business code."""
        admin = create_test_agent(user_id="rachel_admin", name="Rachel")
        admin.role = "admin"
        db_session.add(admin)
        await db_session.flush()
        await _seed_admin_role(db_session, "rachel_admin", "admin")

        admin_token = await _login_and_get_token(client, "rachel_admin", "Rachel")

        resp = await client.post(
            "/admin/mfa/reset/ghost_user_999",
            headers=_bearer(admin_token),
        )

        assert resp.status_code == 200
        body = resp.json()
        # Only "some business error" is asserted; the exact code is not pinned.
        assert body["code"] != 0
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 7. service 层单元测试(轻量覆盖)
|
||
|
|
# =============================================================================
|
||
|
|
class TestMFAServiceUnit:
    """Direct tests of MFAService helpers, called without the client/db fixtures."""

    def test_generate_secret_format(self):
        """generate_secret yields a 32-character string over the base32 alphabet."""
        import string

        generated = MFAService.generate_secret()
        assert isinstance(generated, str)
        assert len(generated) == 32

        # Base32 alphabet: A-Z plus the digits 2-7.
        base32_alphabet = set(string.ascii_uppercase + "234567")
        assert set(generated) <= base32_alphabet

    def test_verify_code_with_correct_code(self):
        """verify_code accepts the current code of the same secret."""
        secret = MFAService.generate_secret()
        current_code = pyotp.TOTP(secret).now()
        assert MFAService.verify_code(secret, current_code) is True

    def test_verify_code_with_wrong_code(self):
        """verify_code rejects a code that was not derived from the secret."""
        fresh_secret = MFAService.generate_secret()
        outcome = MFAService.verify_code(fresh_secret, "000000")
        assert outcome is False

    def test_verify_code_with_empty_secret(self):
        """verify_code returns False for an empty or missing secret instead of raising."""
        for missing_secret in ("", None):
            assert MFAService.verify_code(missing_secret, "123456") is False

    def test_start_binding_returns_all_three(self):
        """start_binding returns the (secret, otpauth_url, qr_base64) triple."""
        binding = MFAService.start_binding("test_user")
        secret, provisioning_url, png_b64 = binding

        assert isinstance(secret, str)
        assert len(secret) == 32
        assert provisioning_url.startswith("otpauth://totp/")

        # The decoded QR code must begin with the PNG signature.
        png_header = base64.b64decode(png_b64)[:8]
        assert png_header == b"\x89PNG\r\n\x1a\n"
|