feat: 重构MCP功能和AI服务提供者架构

This commit is contained in:
xiamuceer-j
2026-01-09 17:13:19 +08:00
parent f3c224261d
commit 77c5489ff8
49 changed files with 4763 additions and 4307 deletions
+219 -1
View File
@@ -1,13 +1,231 @@
"""Server-Sent Events (SSE) 响应工具类"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional
from enum import Enum
from typing import AsyncGenerator, Dict, Any, Optional, Callable
from dataclasses import dataclass
from fastapi.responses import StreamingResponse
from app.logger import get_logger
logger = get_logger(__name__)
class ProgressStage(Enum):
"""标准化进度阶段枚举"""
# 初始化阶段 (0-5%)
INIT = "init"
# 加载数据阶段 (5-15%)
LOADING = "loading"
# 准备提示词阶段 (15-20%)
PREPARING = "preparing"
# AI生成阶段 (20-85%)
GENERATING = "generating"
# 解析数据阶段 (85-92%)
PARSING = "parsing"
# 保存数据阶段 (92-98%)
SAVING = "saving"
# 完成阶段 (100%)
COMPLETE = "complete"
@dataclass
class StageConfig:
"""阶段配置"""
start: int # 起始进度
end: int # 结束进度
default_message: str # 默认消息
# 标准进度阶段配置
STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = {
ProgressStage.INIT: StageConfig(0, 5, "开始处理..."),
ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."),
ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."),
ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."),
ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."),
ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."),
ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"),
}
class WizardProgressTracker:
"""
向导进度追踪器 - 标准化管理SSE进度推送
使用示例:
tracker = WizardProgressTracker("世界观")
yield await tracker.start()
yield await tracker.loading("加载项目信息")
yield await tracker.preparing()
async for chunk in ai_stream:
yield await tracker.generating_chunk(chunk, len(accumulated))
yield await tracker.parsing()
yield await tracker.saving("保存世界观数据")
yield await tracker.complete()
"""
def __init__(self, task_name: str = "任务"):
"""
初始化进度追踪器
Args:
task_name: 任务名称,用于消息前缀
"""
self.task_name = task_name
self.current_stage = ProgressStage.INIT
self.current_progress = 0
self._last_generating_progress = 20 # 生成阶段的最后进度值
def _get_stage_progress(
self,
stage: ProgressStage,
sub_progress: float = 0.0
) -> int:
"""
计算阶段内的进度值
Args:
stage: 当前阶段
sub_progress: 阶段内子进度 (0.0-1.0)
Returns:
总进度值 (0-100)
"""
config = STAGE_CONFIGS[stage]
if sub_progress <= 0:
return config.start
if sub_progress >= 1:
return config.end
return config.start + int((config.end - config.start) * sub_progress)
async def start(self, message: str = None) -> str:
"""开始阶段"""
self.current_stage = ProgressStage.INIT
self.current_progress = 0
msg = message or f"开始生成{self.task_name}..."
return await SSEResponse.send_progress(msg, 0, "processing")
async def loading(self, message: str = None, sub_progress: float = 0.5) -> str:
"""加载数据阶段"""
self.current_stage = ProgressStage.LOADING
progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress)
self.current_progress = progress
msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message
return await SSEResponse.send_progress(msg, progress, "processing")
async def preparing(self, message: str = None) -> str:
"""准备提示词阶段"""
self.current_stage = ProgressStage.PREPARING
progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5)
self.current_progress = progress
msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message
return await SSEResponse.send_progress(msg, progress, "processing")
async def generating(
self,
current_chars: int = 0,
estimated_total: int = 5000,
message: str = None,
retry_count: int = 0,
max_retries: int = 3
) -> str:
"""
AI生成阶段进度更新
Args:
current_chars: 当前已生成字符数
estimated_total: 预估总字符数
message: 自定义消息
retry_count: 当前重试次数
max_retries: 最大重试次数
"""
self.current_stage = ProgressStage.GENERATING
# 计算生成进度 (0.0-1.0)
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress)
# 确保进度单调递增
if progress < self._last_generating_progress:
progress = self._last_generating_progress
else:
self._last_generating_progress = progress
self.current_progress = progress
# 构建消息
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
if message:
msg = f"{message}{retry_suffix}"
else:
msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
return await SSEResponse.send_progress(msg, progress, "processing")
async def generating_chunk(self, chunk: str) -> str:
"""发送生成的内容块"""
return await SSEResponse.send_chunk(chunk)
async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str:
"""解析数据阶段"""
self.current_stage = ProgressStage.PARSING
progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress)
self.current_progress = progress
msg = message or f"解析{self.task_name}数据..."
return await SSEResponse.send_progress(msg, progress, "processing")
async def saving(self, message: str = None, sub_progress: float = 0.5) -> str:
"""保存数据阶段"""
self.current_stage = ProgressStage.SAVING
progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress)
self.current_progress = progress
msg = message or f"保存{self.task_name}到数据库..."
return await SSEResponse.send_progress(msg, progress, "processing")
async def complete(self, message: str = None) -> str:
"""完成阶段"""
self.current_stage = ProgressStage.COMPLETE
self.current_progress = 100
msg = message or f"{self.task_name}生成完成!"
return await SSEResponse.send_progress(msg, 100, "success")
async def warning(self, message: str) -> str:
"""发送警告消息(保持当前进度)"""
return await SSEResponse.send_progress(
f"⚠️ {message}",
self.current_progress,
"warning"
)
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str:
"""发送重试消息"""
return await SSEResponse.send_progress(
f"⚠️ {reason}... ({retry_count}/{max_retries})",
self.current_progress,
"warning"
)
async def error(self, error_message: str, code: int = 500) -> str:
"""发送错误消息"""
return await SSEResponse.send_error(error_message, code)
async def result(self, data: Dict[str, Any]) -> str:
"""发送结果数据"""
return await SSEResponse.send_result(data)
async def done(self) -> str:
"""发送完成信号"""
return await SSEResponse.send_done()
async def heartbeat(self) -> str:
"""发送心跳"""
return await SSEResponse.send_heartbeat()
def reset_generating_progress(self):
"""重置生成阶段进度(用于重试时)"""
self._last_generating_progress = 20
class SSEResponse:
"""SSE响应构建器"""