926 lines
35 KiB
Python
926 lines
35 KiB
Python
# =============================================================================
|
||
# 企微IT智能服务台 — H5 OAuth2 认证流程测试
|
||
# =============================================================================
|
||
# 测试覆盖:
|
||
# 1. OAuth2 授权 URL 接口(GET /api/h5/oauth/authorize)
|
||
# 2. OAuth2 回调接口(POST /api/h5/oauth/callback)
|
||
# 3. Token 验证依赖函数 _get_current_employee
|
||
# 4. 获取当前员工信息(GET /api/h5/me)
|
||
# 5. 向后兼容(X-Employee-Id 头降级)
|
||
# 6. 错误处理(WecomService 失败、Redis 不可用)
# 7. Token TTL 与格式
# 8. Schema 验证
# 9. 端到端 OAuth2 流程
|
||
# =============================================================================
|
||
|
||
import json
from typing import AsyncIterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import Base, get_db
from app.models.conversation import Conversation
from app.models.funny_phrase import FunnyPhrase
from tests.conftest import MockRedis, create_test_conversation, test_engine, test_session_factory
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 专用 fixtures:带 h5 API Redis mock 的测试客户端
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@pytest_asyncio.fixture
async def h5_client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncIterator[AsyncClient]:
    """Provide an async HTTP test client for the H5 OAuth2 API.

    Like the ``client`` fixture in conftest.py, but additionally patches
    ``app.api.h5._get_redis`` and ``redis.asyncio.from_url`` so that every
    Redis call made during the OAuth2 flow goes to the in-memory ``mock_redis``.
    The ``get_db`` dependency is overridden to yield the test ``db_session``.

    Yields:
        An ``AsyncClient`` bound to a fresh app instance via ``ASGITransport``,
        with ``base_url="http://test"``.
    """

    async def _override_get_db():
        yield db_session

    # Imported inside the fixture rather than at module level.
    from app.main import create_app

    app = create_app()
    app.dependency_overrides[get_db] = _override_get_db

    try:
        with patch("app.api.h5._get_redis", return_value=mock_redis):
            with patch("redis.asyncio.from_url", return_value=mock_redis):
                transport = ASGITransport(app=app)
                async with AsyncClient(transport=transport, base_url="http://test") as ac:
                    yield ac
    finally:
        # Drop the overrides even if client setup or teardown raises.
        app.dependency_overrides.clear()
|
||
|
||
|
||
@pytest.fixture
def mock_redis_fresh() -> MockRedis:
    """Return a brand-new in-memory MockRedis.

    The fixture is function-scoped (the pytest default), so every test that
    requests it receives its own independent instance.
    """
    fresh_store = MockRedis()
    return fresh_store
|
||
|
||
|
||
# ===========================================================================
|
||
# 1. OAuth2 授权 URL 接口
|
||
# ===========================================================================
|
||
|
||
class TestOAuthAuthorizeURL:
    """Tests for GET /api/h5/oauth/authorize — fetch the WeCom OAuth2 authorize URL."""

    @staticmethod
    async def _fetch_authorize_url(client, **request_kwargs):
        """GET the authorize endpoint, assert a successful envelope, return the URL.

        ``request_kwargs`` are forwarded unchanged to ``client.get`` (for
        example ``params=`` or ``headers=``).
        """
        response = await client.get("/h5/oauth/authorize", **request_kwargs)
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        return data["data"]["authorize_url"]

    @pytest.mark.asyncio
    async def test_authorize_url_returns_correct_structure(self, h5_client):
        """The response envelope is successful and carries an authorize_url field."""
        response = await h5_client.get("/h5/oauth/authorize")
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert "authorize_url" in data["data"]

    @pytest.mark.asyncio
    async def test_authorize_url_contains_correct_base(self, h5_client):
        """The authorize URL starts with the WeCom OAuth2 base address."""
        url = await self._fetch_authorize_url(h5_client)
        assert url.startswith("https://open.weixin.qq.com/connect/oauth2/authorize")

    @pytest.mark.asyncio
    async def test_authorize_url_contains_appid(self, h5_client):
        """The authorize URL carries the appid parameter (the WeCom CorpID)."""
        from app.config import settings

        url = await self._fetch_authorize_url(h5_client)
        # corp_id comes from the live settings (possibly overridden by .env).
        assert f"appid={settings.wecom_corp_id}" in url

    @pytest.mark.asyncio
    async def test_authorize_url_contains_scope_snsapi_base(self, h5_client):
        """The authorize URL requests the snsapi_base scope (silent authorization)."""
        url = await self._fetch_authorize_url(h5_client)
        assert "scope=snsapi_base" in url

    @pytest.mark.asyncio
    async def test_authorize_url_contains_response_type_code(self, h5_client):
        """The authorize URL contains response_type=code."""
        url = await self._fetch_authorize_url(h5_client)
        assert "response_type=code" in url

    @pytest.mark.asyncio
    async def test_authorize_url_contains_wechat_redirect(self, h5_client):
        """The authorize URL ends with the #wechat_redirect fragment."""
        url = await self._fetch_authorize_url(h5_client)
        assert url.endswith("#wechat_redirect")

    @pytest.mark.asyncio
    async def test_authorize_url_with_redirect_uri_param(self, h5_client):
        """An explicit redirect_uri query parameter is embedded, percent-encoded."""
        from urllib.parse import quote

        custom_uri = "https://myapp.example.com/h5/"
        url = await self._fetch_authorize_url(h5_client, params={"redirect_uri": custom_uri})
        # safe="" also encodes ":" and "/".
        encoded = quote(custom_uri, safe="")
        assert f"redirect_uri={encoded}" in url

    @pytest.mark.asyncio
    async def test_authorize_url_with_host_header(self, h5_client):
        """Without redirect_uri, the default callback is derived from the Host header."""
        from urllib.parse import quote

        url = await self._fetch_authorize_url(h5_client, headers={"Host": "myapp.example.com"})
        # The Host-derived callback URL is expected to use the https scheme.
        expected_redirect = quote("https://myapp.example.com/h5/", safe="")
        assert f"redirect_uri={expected_redirect}" in url

    @pytest.mark.asyncio
    async def test_authorize_url_without_redirect_uri_uses_default(self, h5_client):
        """Without a redirect_uri parameter the URL still carries a redirect_uri.

        Note: httpx always sends a Host header (``test`` here, taken from the
        client's ``base_url``), so this covers the default redirect built
        without an explicit redirect_uri — not a request lacking a Host header.
        """
        url = await self._fetch_authorize_url(h5_client)
        assert "redirect_uri=" in url
|
||
|
||
|
||
# ===========================================================================
|
||
# 2. OAuth2 回调接口
|
||
# ===========================================================================
|
||
|
||
class TestOAuthCallback:
    """Tests for POST /api/h5/oauth/callback — OAuth2 callback handling."""

    @staticmethod
    def _make_wecom_mock(userid="", user_info=None, *, oauth_error=None, user_info_error=None):
        """Build an AsyncMock that stands in for a WecomService instance.

        Args:
            userid: ``userid`` returned by ``get_oauth_user_info`` (ignored
                when ``oauth_error`` is given).
            user_info: dict returned by ``get_user_info``. When neither this
                nor ``user_info_error`` is given, ``get_user_info`` is left as
                the default AsyncMock child.
            oauth_error: exception raised by ``get_oauth_user_info``.
            user_info_error: exception raised by ``get_user_info``.
        """
        mock_wecom = AsyncMock()
        if oauth_error is not None:
            mock_wecom.get_oauth_user_info = AsyncMock(side_effect=oauth_error)
        else:
            mock_wecom.get_oauth_user_info = AsyncMock(
                return_value={"userid": userid, "user_ticket": ""}
            )
        if user_info_error is not None:
            mock_wecom.get_user_info = AsyncMock(side_effect=user_info_error)
        elif user_info is not None:
            mock_wecom.get_user_info = AsyncMock(return_value=user_info)
        mock_wecom.close = AsyncMock()
        return mock_wecom

    @staticmethod
    async def _post_callback(client, mock_wecom, code):
        """POST ``{"code": code}`` to the callback with WecomService patched to ``mock_wecom``."""
        with patch("app.api.h5.WecomService", return_value=mock_wecom):
            return await client.post("/h5/oauth/callback", json={"code": code})

    @pytest.mark.asyncio
    async def test_callback_returns_token_and_employee_info(self, h5_client, mock_redis):
        """A successful callback returns a token plus the employee's profile."""
        mock_wecom = self._make_wecom_mock(
            "test_user_001",
            {
                "name": "测试员工",
                "department": [1, 2],
                "position": "工程师",
                "avatar": "https://avatar.example.com/test.jpg",
            },
        )
        response = await self._post_callback(h5_client, mock_wecom, "valid_auth_code")

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        payload = data["data"]
        assert "token" in payload
        assert payload["employee_id"] == "test_user_001"
        assert payload["employee_name"] == "测试员工"
        # The department id list is returned as a comma-separated string.
        assert payload["department"] == "1,2"
        assert payload["position"] == "工程师"
        assert payload["avatar"] == "https://avatar.example.com/test.jpg"

    @pytest.mark.asyncio
    async def test_callback_stores_token_in_redis(self, h5_client, mock_redis):
        """The token is stored in Redis under ``employee:token:{token}``."""
        mock_wecom = self._make_wecom_mock(
            "redis_test_user",
            {"name": "Redis测试", "department": [], "position": "", "avatar": ""},
        )
        response = await self._post_callback(h5_client, mock_wecom, "valid_auth_code")

        assert response.status_code == 200
        token = response.json()["data"]["token"]

        stored = await mock_redis.get(f"employee:token:{token}")
        assert stored is not None
        assert stored == b"redis_test_user"

    @pytest.mark.asyncio
    async def test_callback_caches_employee_info_in_redis(self, h5_client, mock_redis):
        """The employee profile is cached in Redis under ``employee:info:{userid}``."""
        mock_wecom = self._make_wecom_mock(
            "cache_test_user",
            {
                "name": "缓存测试",
                "department": [3],
                "position": "经理",
                "avatar": "https://avatar.example.com/cache.jpg",
            },
        )
        response = await self._post_callback(h5_client, mock_wecom, "valid_auth_code")
        assert response.status_code == 200

        cached = await mock_redis.get("employee:info:cache_test_user")
        assert cached is not None
        cached_info = json.loads(cached)
        assert cached_info["employee_id"] == "cache_test_user"
        assert cached_info["employee_name"] == "缓存测试"
        assert cached_info["department"] == "3"

    @pytest.mark.asyncio
    async def test_callback_with_empty_userid_returns_error(self, h5_client, mock_redis):
        """An empty UserID from the OAuth2 exchange yields an error and no token."""
        mock_wecom = self._make_wecom_mock("")
        response = await self._post_callback(h5_client, mock_wecom, "bad_code")

        data = response.json()
        assert data["code"] != 0
        # A failed exchange must not hand out a session token.
        assert "token" not in (data.get("data") or {})

    @pytest.mark.asyncio
    async def test_callback_wecom_service_failure(self, h5_client, mock_redis):
        """A WecomService exception yields an error response and no token."""
        mock_wecom = self._make_wecom_mock(oauth_error=Exception("企微API不可用"))
        response = await self._post_callback(h5_client, mock_wecom, "will_fail")

        data = response.json()
        assert data["code"] != 0
        assert "token" not in (data.get("data") or {})

    @pytest.mark.asyncio
    async def test_callback_detail_fetch_failure_still_returns_token(self, h5_client, mock_redis):
        """If fetching profile details fails, the token is still issued (degraded mode)."""
        mock_wecom = self._make_wecom_mock(
            "degrade_user",
            user_info_error=Exception("通讯录API失败"),
        )
        response = await self._post_callback(h5_client, mock_wecom, "valid_code")

        assert response.status_code == 200
        data = response.json()
        # Login still succeeds with a token and the employee id.
        assert data["code"] == 0
        assert "token" in data["data"]
        assert data["data"]["employee_id"] == "degrade_user"
        # Profile fields fall back to empty strings.
        assert data["data"]["employee_name"] == ""
        assert data["data"]["department"] == ""

    @pytest.mark.asyncio
    async def test_callback_missing_code_field(self, h5_client, mock_redis):
        """A body without ``code`` is rejected by request validation (422)."""
        response = await h5_client.post("/h5/oauth/callback", json={})
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_callback_empty_code_field(self, h5_client, mock_redis):
        """An empty ``code`` string is rejected by request validation (422)."""
        response = await h5_client.post("/h5/oauth/callback", json={"code": ""})
        assert response.status_code == 422
|
||
|
||
|
||
# ===========================================================================
|
||
# 3. Token 验证依赖函数 _get_current_employee
|
||
# ===========================================================================
|
||
|
||
class TestGetCurrentEmployee:
    """Tests for the _get_current_employee dependency, exercised through /api/h5/me."""

    # Lifetime (seconds) used when seeding tokens and cached profiles: 8 hours.
    _TTL_SECONDS = 28800

    @classmethod
    async def _seed_session(cls, redis, token, employee_id, **profile):
        """Store a token -> employee_id mapping and a cached profile in ``redis``.

        Keys use the ``employee:token:{token}`` and ``employee:info:{employee_id}``
        layout. ``profile`` overrides individual fields of a profile whose
        fields otherwise default to empty strings.
        """
        await redis.setex(f"employee:token:{token}", cls._TTL_SECONDS, employee_id)
        employee_info = {
            "employee_id": employee_id,
            "employee_name": "",
            "department": "",
            "position": "",
            "mobile": "",
            "email": "",
            "avatar": "",
        }
        employee_info.update(profile)
        await redis.setex(
            f"employee:info:{employee_id}",
            cls._TTL_SECONDS,
            json.dumps(employee_info, ensure_ascii=False),
        )

    @pytest.mark.asyncio
    async def test_valid_bearer_token(self, h5_client, mock_redis):
        """A valid Bearer token resolves to the employee_id it was issued for."""
        await self._seed_session(
            mock_redis,
            "test_valid_token",
            "authed_user_001",
            employee_name="认证测试用户",
            department="IT部",
            position="工程师",
        )

        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": "Bearer test_valid_token"},
        )

        assert response.status_code == 200
        data = response.json()
        # Success means authentication passed; the identity must be the token's owner.
        assert data["code"] == 0
        assert data["data"]["employee_id"] == "authed_user_001"

    @pytest.mark.asyncio
    async def test_invalid_token_returns_unauthorized(self, h5_client, mock_redis):
        """An unknown token is rejected with business code 1002 (unauthorized)."""
        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": "Bearer non_existent_token"},
        )

        data = response.json()
        assert data["code"] == 1002
        assert "未授权" in data["message"]

    @pytest.mark.asyncio
    async def test_missing_authorization_header(self, h5_client, mock_redis):
        """A request without an Authorization header is unauthorized."""
        response = await h5_client.get("/h5/me")

        data = response.json()
        assert data["code"] == 1002

    @pytest.mark.asyncio
    async def test_empty_authorization_header(self, h5_client, mock_redis):
        """An empty Authorization header is unauthorized."""
        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": ""},
        )

        data = response.json()
        assert data["code"] == 1002

    @pytest.mark.asyncio
    async def test_bearer_prefix_extraction(self, h5_client, mock_redis):
        """The token is correctly extracted from a ``Bearer <token>`` header."""
        await self._seed_session(
            mock_redis,
            "my_token_123",
            "prefix_test_user",
            employee_name="前缀测试",
        )

        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": "Bearer my_token_123"},
        )

        data = response.json()
        assert data["code"] == 0
        assert data["data"]["employee_id"] == "prefix_test_user"

    @pytest.mark.asyncio
    async def test_token_without_bearer_prefix(self, h5_client, mock_redis):
        """A raw token without the ``Bearer`` prefix is also accepted (compatibility)."""
        await self._seed_session(
            mock_redis,
            "raw_token_456",
            "raw_token_user",
            employee_name="原始Token测试",
        )

        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": "raw_token_456"},
        )

        data = response.json()
        # A header value not starting with "Bearer " is used whole as the token.
        assert data["code"] == 0
        assert data["data"]["employee_id"] == "raw_token_user"

    @pytest.mark.asyncio
    async def test_expired_token_returns_unauthorized(self, h5_client, mock_redis):
        """A token absent from Redis (e.g. expired) is unauthorized."""
        # Nothing is seeded, which is what an expired token looks like.
        response = await h5_client.get(
            "/h5/me",
            headers={"Authorization": "Bearer expired_token_xyz"},
        )

        data = response.json()
        assert data["code"] == 1002
|
||
|
||
|
||
# ===========================================================================
|
||
# 4. GET /api/h5/me 接口
|
||
# ===========================================================================
|
||
|
||
class TestGetCurrentEmployeeInfo:
    """Tests for GET /api/h5/me — return the authenticated employee's full profile."""

    @pytest.mark.asyncio
    async def test_me_returns_employee_info_from_cache(self, h5_client, mock_redis):
        """With a cached profile in Redis, /me serves that profile."""
        cached_profile = dict(
            employee_id="me_cache_user",
            employee_name="缓存用户",
            department="技术部",
            position="开发",
            mobile="13800138000",
            email="cache@test.com",
            avatar="https://avatar.example.com/me.jpg",
        )
        serialized = json.dumps(cached_profile, ensure_ascii=False)

        await mock_redis.setex("employee:token:cache_me_token", 28800, "me_cache_user")
        await mock_redis.setex("employee:info:me_cache_user", 28800, serialized)

        resp = await h5_client.get("/h5/me", headers={"Authorization": "Bearer cache_me_token"})
        body = resp.json()

        assert body["code"] == 0
        profile = body["data"]
        assert profile["employee_id"] == "me_cache_user"
        assert profile["employee_name"] == "缓存用户"
        # is_vip is not stored in the cache; the endpoint adds it.
        assert profile["is_vip"] is False

    @pytest.mark.asyncio
    async def test_me_falls_back_to_wecom_api(self, h5_client, mock_redis):
        """Without a cached profile, /me falls back to the WeCom directory API."""
        # Only the token exists; the profile cache is deliberately left empty.
        await mock_redis.setex("employee:token:nocache_me_token", 28800, "me_nocache_user")

        directory_entry = dict(
            name="API用户",
            department=[5],
            position="测试",
            avatar="https://avatar.example.com/api.jpg",
            mobile="13900139000",
            email="api@test.com",
        )
        fake_service = AsyncMock()
        fake_service.get_user_info = AsyncMock(return_value=directory_entry)
        fake_service.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=fake_service):
            resp = await h5_client.get(
                "/h5/me", headers={"Authorization": "Bearer nocache_me_token"}
            )

        body = resp.json()
        assert body["code"] == 0
        profile = body["data"]
        assert profile["employee_id"] == "me_nocache_user"
        assert profile["employee_name"] == "API用户"
        assert profile["department"] == "5"
        assert profile["mobile"] == "13900139000"
        assert profile["is_vip"] is False

    @pytest.mark.asyncio
    async def test_me_unauthenticated_returns_401(self, h5_client, mock_redis):
        """Without credentials, /me reports unauthorized (business code 1002)."""
        body = (await h5_client.get("/h5/me")).json()
        assert body["code"] == 1002
|
||
|
||
|
||
# ===========================================================================
|
||
# 5. 向后兼容
|
||
# ===========================================================================
|
||
|
||
class TestBackwardCompatibility:
    """Compatibility checks: legacy X-Employee-Id header versus Bearer-token auth."""

    @pytest.mark.asyncio
    async def test_x_employee_id_header_still_works_for_old_endpoints(self, h5_client, db_session, mock_redis):
        """The pre-existing /h5/conversations/current endpoint works with Bearer auth.

        Despite its name, this test does NOT send an X-Employee-Id header.
        /h5/conversations/current authenticates through _get_current_employee,
        so it is exercised with a Bearer token seeded in Redis. The test checks
        that the token's employee can still fetch their queued conversation
        from this older endpoint under the new authentication scheme.
        """
        # Seed a token mapping for the employee who owns the conversation.
        await mock_redis.setex("employee:token:compat_token", 28800, "compat_user")

        # Create the employee's current (queued) conversation.
        conv = create_test_conversation(
            employee_id="compat_user",
            status="queued",
        )
        db_session.add(conv)
        await db_session.flush()

        response = await h5_client.get(
            "/h5/conversations/current",
            headers={"Authorization": "Bearer compat_token"},
        )

        data = response.json()
        assert data["code"] == 0
        assert data["data"] is not None

    @pytest.mark.asyncio
    async def test_old_x_employee_id_header_not_accepted_by_new_auth(self, h5_client, mock_redis):
        """An X-Employee-Id header alone (no Bearer token) is rejected by the new auth.

        _get_current_employee only trusts Bearer tokens and ignores a
        client-supplied X-Employee-Id, which is the intended security behavior.
        """
        response = await h5_client.get(
            "/h5/me",
            headers={"X-Employee-Id": "old_style_user"},
        )

        data = response.json()
        # X-Employee-Id must not authenticate the request.
        assert data["code"] == 1002
|
||
|
||
|
||
# ===========================================================================
|
||
# 6. 错误处理
|
||
# ===========================================================================
|
||
|
||
class TestErrorHandling:
    """Error-handling scenarios: Redis outages and WeCom API failures."""

    @pytest.mark.asyncio
    async def test_redis_unavailable_during_token_validation(self, h5_client, mock_redis):
        """When Redis reads fail, token validation degrades to "unauthorized"."""
        broken_get = AsyncMock(side_effect=Exception("Redis connection refused"))

        # patch.object restores mock_redis.get on exit, even if the request raises.
        with patch.object(mock_redis, "get", new=broken_get):
            response = await h5_client.get(
                "/h5/me",
                headers={"Authorization": "Bearer some_token"},
            )

        data = response.json()
        # A Redis outage must not authenticate anyone.
        assert data["code"] == 1002

    @pytest.mark.asyncio
    async def test_redis_write_failure_during_callback(self, h5_client, mock_redis):
        """If Redis writes fail, the OAuth2 callback still completes (degraded mode)."""
        mock_wecom = AsyncMock()
        mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "redis_fail_user", "user_ticket": ""})
        mock_wecom.get_user_info = AsyncMock(return_value={
            "name": "Redis故障测试",
            "department": [],
            "position": "",
            "avatar": "",
        })
        mock_wecom.close = AsyncMock()

        # AsyncMock accepts any call signature (positional or keyword), unlike a
        # hand-written replacement with a fixed parameter list.
        broken_setex = AsyncMock(side_effect=Exception("Redis write failed"))

        with patch.object(mock_redis, "setex", new=broken_setex):
            with patch("app.api.h5.WecomService", return_value=mock_wecom):
                response = await h5_client.post(
                    "/h5/oauth/callback",
                    json={"code": "valid_code"},
                )

        data = response.json()
        # A Redis write failure must not block the callback; a token is still
        # returned even though it was not persisted.
        assert data["code"] == 0
        assert "token" in data["data"]
        assert data["data"]["employee_id"] == "redis_fail_user"

    @pytest.mark.asyncio
    async def test_wecom_oauth_failure_returns_error(self, h5_client, mock_redis):
        """A WeCom OAuth2 exception produces an error response."""
        mock_wecom = AsyncMock()
        mock_wecom.get_oauth_user_info = AsyncMock(
            side_effect=Exception("企微API超时")
        )
        mock_wecom.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=mock_wecom):
            response = await h5_client.post(
                "/h5/oauth/callback",
                json={"code": "timeout_code"},
            )

        data = response.json()
        assert data["code"] != 0

    @pytest.mark.asyncio
    async def test_me_wecom_api_failure(self, h5_client, mock_redis):
        """/me returns an error when the profile is uncached and the WeCom API fails."""
        # Seed the token only; no profile cache.
        await mock_redis.setex("employee:token:wecom_fail_token", 28800, "wecom_fail_user")

        mock_wecom = AsyncMock()
        mock_wecom.get_user_info = AsyncMock(
            side_effect=Exception("通讯录API失败")
        )
        mock_wecom.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=mock_wecom):
            response = await h5_client.get(
                "/h5/me",
                headers={"Authorization": "Bearer wecom_fail_token"},
            )

        data = response.json()
        # Cache miss combined with an API failure must surface as an error.
        assert data["code"] != 0
|
||
|
||
|
||
# ===========================================================================
|
||
# 7. Token TTL 与格式
|
||
# ===========================================================================
|
||
|
||
class TestTokenTTLAndFormat:
    """Tests for the session token's Redis TTL and its textual format."""

    # Expected lifetime of both the token and the cached profile: 8 hours.
    _EXPECTED_TTL = 28800

    @staticmethod
    async def _login(client, userid, name, code):
        """Run the OAuth2 callback for ``userid`` with a mocked WecomService.

        The mocked directory profile has the given ``name`` and empty
        department/position/avatar. Returns the raw callback response.
        """
        mock_wecom = AsyncMock()
        mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": userid, "user_ticket": ""})
        mock_wecom.get_user_info = AsyncMock(return_value={
            "name": name,
            "department": [],
            "position": "",
            "avatar": "",
        })
        mock_wecom.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=mock_wecom):
            return await client.post("/h5/oauth/callback", json={"code": code})

    @pytest.mark.asyncio
    async def test_token_stored_with_correct_ttl(self, h5_client, mock_redis):
        """The token is written to Redis with an 8-hour (28800 s) TTL."""
        response = await self._login(h5_client, "ttl_user", "TTL测试", "ttl_test_code")
        token = response.json()["data"]["token"]

        # TTLs are read from MockRedis's internal ``_ttl`` map.
        ttl = mock_redis._ttl.get(f"employee:token:{token}")
        assert ttl == self._EXPECTED_TTL

    @pytest.mark.asyncio
    async def test_employee_info_cache_has_same_ttl(self, h5_client, mock_redis):
        """The cached profile uses the same TTL as the token it belongs to."""
        response = await self._login(h5_client, "info_ttl_user", "InfoTTL测试", "info_ttl_code")
        token = response.json()["data"]["token"]

        token_ttl = mock_redis._ttl.get(f"employee:token:{token}")
        info_ttl = mock_redis._ttl.get("employee:info:info_ttl_user")
        assert info_ttl == self._EXPECTED_TTL
        # Compare directly so the two lifetimes cannot silently drift apart.
        assert info_ttl == token_ttl

    @pytest.mark.asyncio
    async def test_token_is_urlsafe(self, h5_client, mock_redis):
        """The generated token uses only the URL-safe base64 alphabet."""
        import re

        response = await self._login(h5_client, "fmt_user", "格式测试", "fmt_test_code")
        token = response.json()["data"]["token"]

        assert isinstance(token, str)
        assert len(token) > 0
        # URL-safe base64 alphabet: A-Z, a-z, 0-9, "-", "_". fullmatch is used
        # because "$" under re.match would also accept a trailing newline.
        assert re.fullmatch(r"[A-Za-z0-9_-]+", token), f"Token '{token}' is not URL-safe"
|
||
|
||
|
||
# ===========================================================================
|
||
# 8. Schema 验证
|
||
# ===========================================================================
|
||
|
||
class TestSchemaValidation:
    """Request-schema (Pydantic) validation for the OAuth2 callback body."""

    @pytest.mark.asyncio
    async def test_oauth_callback_request_requires_code(self, h5_client, mock_redis):
        """A body without ``code`` fails validation with HTTP 422."""
        resp = await h5_client.post("/h5/oauth/callback", json={})
        assert resp.status_code == 422

    @pytest.mark.asyncio
    async def test_oauth_callback_request_code_min_length(self, h5_client, mock_redis):
        """``code`` must be at least one character long (empty string -> 422)."""
        resp = await h5_client.post("/h5/oauth/callback", json={"code": ""})
        assert resp.status_code == 422

    @pytest.mark.asyncio
    async def test_oauth_callback_request_valid_code(self, h5_client, mock_redis):
        """A well-formed ``code`` passes validation and reaches the handler."""
        blank_profile = dict(name="", department=[], position="", avatar="")

        fake_service = AsyncMock()
        fake_service.get_oauth_user_info = AsyncMock(
            return_value=dict(userid="schema_user", user_ticket="")
        )
        fake_service.get_user_info = AsyncMock(return_value=blank_profile)
        fake_service.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=fake_service):
            resp = await h5_client.post(
                "/h5/oauth/callback", json={"code": "valid_code_here"}
            )

        # A well-formed body is handled normally: 200, not 422.
        assert resp.status_code == 200
|
||
|
||
|
||
# ===========================================================================
|
||
# 9. 端到端 OAuth2 流程
|
||
# ===========================================================================
|
||
|
||
class TestOAuth2EndToEnd:
    """End-to-end walk through the complete OAuth2 login flow."""

    @pytest.mark.asyncio
    async def test_full_oauth2_flow(self, h5_client, mock_redis):
        """Authorize URL -> callback issues a token -> the token authenticates /me."""
        # Step 1: request the authorize URL.
        authorize_body = (await h5_client.get("/h5/oauth/authorize")).json()
        assert authorize_body["code"] == 0
        assert "snsapi_base" in authorize_body["data"]["authorize_url"]

        # Step 2: simulate WeCom redirecting back with a code.
        wecom_stub = AsyncMock()
        wecom_stub.get_oauth_user_info = AsyncMock(
            return_value=dict(userid="e2e_user", user_ticket="")
        )
        wecom_stub.get_user_info = AsyncMock(
            return_value=dict(
                name="端到端用户",
                department=[1, 2],
                position="架构师",
                avatar="https://avatar.example.com/e2e.jpg",
            )
        )
        wecom_stub.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=wecom_stub):
            callback_resp = await h5_client.post(
                "/h5/oauth/callback", json={"code": "e2e_auth_code"}
            )

        callback_body = callback_resp.json()
        assert callback_body["code"] == 0
        session_token = callback_body["data"]["token"]
        assert session_token  # a non-empty token was issued

        # Step 3: use the token against /me.
        auth_header = {"Authorization": f"Bearer {session_token}"}
        me_body = (await h5_client.get("/h5/me", headers=auth_header)).json()
        assert me_body["code"] == 0
        me = me_body["data"]
        assert me["employee_id"] == "e2e_user"
        assert me["employee_name"] == "端到端用户"
        assert me["is_vip"] is False

    @pytest.mark.asyncio
    async def test_full_flow_with_cached_info(self, h5_client, mock_redis):
        """After login, /me is served from the profile cache without calling WeCom."""
        # Step 1: log in through the callback.
        wecom_stub = AsyncMock()
        wecom_stub.get_oauth_user_info = AsyncMock(
            return_value=dict(userid="cached_flow_user", user_ticket="")
        )
        wecom_stub.get_user_info = AsyncMock(
            return_value=dict(
                name="缓存流程用户",
                department=[10],
                position="产品",
                avatar="",
            )
        )
        wecom_stub.close = AsyncMock()

        with patch("app.api.h5.WecomService", return_value=wecom_stub):
            callback_resp = await h5_client.post(
                "/h5/oauth/callback", json={"code": "cached_flow_code"}
            )

        session_token = callback_resp.json()["data"]["token"]

        # Step 2: the first /me call must be a cache hit.
        with patch("app.api.h5.WecomService") as service_cls:
            me_resp = await h5_client.get(
                "/h5/me", headers={"Authorization": f"Bearer {session_token}"}
            )
            # On a cache hit WecomService is never instantiated.
            service_cls.assert_not_called()

        me_body = me_resp.json()
        assert me_body["code"] == 0
        assert me_body["data"]["employee_name"] == "缓存流程用户"
|