feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
@@ -1,13 +1,231 @@
|
||||
"""Server-Sent Events (SSE) 响应工具类"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ProgressStage(Enum):
|
||||
"""标准化进度阶段枚举"""
|
||||
# 初始化阶段 (0-5%)
|
||||
INIT = "init"
|
||||
# 加载数据阶段 (5-15%)
|
||||
LOADING = "loading"
|
||||
# 准备提示词阶段 (15-20%)
|
||||
PREPARING = "preparing"
|
||||
# AI生成阶段 (20-85%)
|
||||
GENERATING = "generating"
|
||||
# 解析数据阶段 (85-92%)
|
||||
PARSING = "parsing"
|
||||
# 保存数据阶段 (92-98%)
|
||||
SAVING = "saving"
|
||||
# 完成阶段 (100%)
|
||||
COMPLETE = "complete"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageConfig:
|
||||
"""阶段配置"""
|
||||
start: int # 起始进度
|
||||
end: int # 结束进度
|
||||
default_message: str # 默认消息
|
||||
|
||||
|
||||
# 标准进度阶段配置
|
||||
STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = {
|
||||
ProgressStage.INIT: StageConfig(0, 5, "开始处理..."),
|
||||
ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."),
|
||||
ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."),
|
||||
ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."),
|
||||
ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."),
|
||||
ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."),
|
||||
ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"),
|
||||
}
|
||||
|
||||
|
||||
class WizardProgressTracker:
|
||||
"""
|
||||
向导进度追踪器 - 标准化管理SSE进度推送
|
||||
|
||||
使用示例:
|
||||
tracker = WizardProgressTracker("世界观")
|
||||
yield await tracker.start()
|
||||
yield await tracker.loading("加载项目信息")
|
||||
yield await tracker.preparing()
|
||||
async for chunk in ai_stream:
|
||||
yield await tracker.generating_chunk(chunk, len(accumulated))
|
||||
yield await tracker.parsing()
|
||||
yield await tracker.saving("保存世界观数据")
|
||||
yield await tracker.complete()
|
||||
"""
|
||||
|
||||
def __init__(self, task_name: str = "任务"):
|
||||
"""
|
||||
初始化进度追踪器
|
||||
|
||||
Args:
|
||||
task_name: 任务名称,用于消息前缀
|
||||
"""
|
||||
self.task_name = task_name
|
||||
self.current_stage = ProgressStage.INIT
|
||||
self.current_progress = 0
|
||||
self._last_generating_progress = 20 # 生成阶段的最后进度值
|
||||
|
||||
def _get_stage_progress(
|
||||
self,
|
||||
stage: ProgressStage,
|
||||
sub_progress: float = 0.0
|
||||
) -> int:
|
||||
"""
|
||||
计算阶段内的进度值
|
||||
|
||||
Args:
|
||||
stage: 当前阶段
|
||||
sub_progress: 阶段内子进度 (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
总进度值 (0-100)
|
||||
"""
|
||||
config = STAGE_CONFIGS[stage]
|
||||
if sub_progress <= 0:
|
||||
return config.start
|
||||
if sub_progress >= 1:
|
||||
return config.end
|
||||
return config.start + int((config.end - config.start) * sub_progress)
|
||||
|
||||
async def start(self, message: str = None) -> str:
|
||||
"""开始阶段"""
|
||||
self.current_stage = ProgressStage.INIT
|
||||
self.current_progress = 0
|
||||
msg = message or f"开始生成{self.task_name}..."
|
||||
return await SSEResponse.send_progress(msg, 0, "processing")
|
||||
|
||||
async def loading(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""加载数据阶段"""
|
||||
self.current_stage = ProgressStage.LOADING
|
||||
progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def preparing(self, message: str = None) -> str:
|
||||
"""准备提示词阶段"""
|
||||
self.current_stage = ProgressStage.PREPARING
|
||||
progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5)
|
||||
self.current_progress = progress
|
||||
msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def generating(
|
||||
self,
|
||||
current_chars: int = 0,
|
||||
estimated_total: int = 5000,
|
||||
message: str = None,
|
||||
retry_count: int = 0,
|
||||
max_retries: int = 3
|
||||
) -> str:
|
||||
"""
|
||||
AI生成阶段进度更新
|
||||
|
||||
Args:
|
||||
current_chars: 当前已生成字符数
|
||||
estimated_total: 预估总字符数
|
||||
message: 自定义消息
|
||||
retry_count: 当前重试次数
|
||||
max_retries: 最大重试次数
|
||||
"""
|
||||
self.current_stage = ProgressStage.GENERATING
|
||||
|
||||
# 计算生成进度 (0.0-1.0)
|
||||
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
|
||||
progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress)
|
||||
|
||||
# 确保进度单调递增
|
||||
if progress < self._last_generating_progress:
|
||||
progress = self._last_generating_progress
|
||||
else:
|
||||
self._last_generating_progress = progress
|
||||
|
||||
self.current_progress = progress
|
||||
|
||||
# 构建消息
|
||||
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
|
||||
if message:
|
||||
msg = f"{message}{retry_suffix}"
|
||||
else:
|
||||
msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
|
||||
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def generating_chunk(self, chunk: str) -> str:
|
||||
"""发送生成的内容块"""
|
||||
return await SSEResponse.send_chunk(chunk)
|
||||
|
||||
async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""解析数据阶段"""
|
||||
self.current_stage = ProgressStage.PARSING
|
||||
progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or f"解析{self.task_name}数据..."
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def saving(self, message: str = None, sub_progress: float = 0.5) -> str:
|
||||
"""保存数据阶段"""
|
||||
self.current_stage = ProgressStage.SAVING
|
||||
progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress)
|
||||
self.current_progress = progress
|
||||
msg = message or f"保存{self.task_name}到数据库..."
|
||||
return await SSEResponse.send_progress(msg, progress, "processing")
|
||||
|
||||
async def complete(self, message: str = None) -> str:
|
||||
"""完成阶段"""
|
||||
self.current_stage = ProgressStage.COMPLETE
|
||||
self.current_progress = 100
|
||||
msg = message or f"{self.task_name}生成完成!"
|
||||
return await SSEResponse.send_progress(msg, 100, "success")
|
||||
|
||||
async def warning(self, message: str) -> str:
|
||||
"""发送警告消息(保持当前进度)"""
|
||||
return await SSEResponse.send_progress(
|
||||
f"⚠️ {message}",
|
||||
self.current_progress,
|
||||
"warning"
|
||||
)
|
||||
|
||||
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str:
|
||||
"""发送重试消息"""
|
||||
return await SSEResponse.send_progress(
|
||||
f"⚠️ {reason}... ({retry_count}/{max_retries})",
|
||||
self.current_progress,
|
||||
"warning"
|
||||
)
|
||||
|
||||
async def error(self, error_message: str, code: int = 500) -> str:
|
||||
"""发送错误消息"""
|
||||
return await SSEResponse.send_error(error_message, code)
|
||||
|
||||
async def result(self, data: Dict[str, Any]) -> str:
|
||||
"""发送结果数据"""
|
||||
return await SSEResponse.send_result(data)
|
||||
|
||||
async def done(self) -> str:
|
||||
"""发送完成信号"""
|
||||
return await SSEResponse.send_done()
|
||||
|
||||
async def heartbeat(self) -> str:
|
||||
"""发送心跳"""
|
||||
return await SSEResponse.send_heartbeat()
|
||||
|
||||
def reset_generating_progress(self):
|
||||
"""重置生成阶段进度(用于重试时)"""
|
||||
self._last_generating_progress = 20
|
||||
|
||||
|
||||
class SSEResponse:
|
||||
"""SSE响应构建器"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user