update:使用标准化SSE统一进度推送逻辑

This commit is contained in:
xiamuceer
2026-01-09 20:57:20 +08:00
parent f4f2caa367
commit 1b32d87581
3 changed files with 69 additions and 81 deletions
+23 -30
View File
@@ -7,7 +7,7 @@ import json
from typing import AsyncGenerator
from app.database import get_db
from app.utils.sse_response import SSEResponse, create_sse_response
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
from app.models.career import Career, CharacterCareer
from app.models.character import Character
from app.models.project import Project
@@ -190,15 +190,16 @@ async def generate_career_system(
通过Server-Sent Events返回实时进度信息
"""
async def generate() -> AsyncGenerator[str, None]:
tracker = WizardProgressTracker("职业体系")
try:
# 验证用户权限和项目是否存在
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成新职业...", 0)
yield await tracker.start()
# 获取已有职业列表
yield await SSEResponse.send_progress("分析已有职业...", 5)
yield await tracker.loading("分析已有职业...", 0.3)
existing_careers_result = await db.execute(
select(Career).where(Career.project_id == project_id)
@@ -228,7 +229,7 @@ async def generate_career_system(
existing_careers_text = "\n当前还没有任何职业,这是第一次创建职业体系。"
# 构建项目上下文
yield await SSEResponse.send_progress("分析项目世界观...", 15)
yield await tracker.loading("分析项目世界观...", 0.6)
project_context = f"""
项目信息:
@@ -253,7 +254,7 @@ async def generate_career_system(
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
"""
yield await SSEResponse.send_progress("构建AI提示词...", 20)
yield await tracker.preparing("构建AI提示词...")
# 构建提示词
prompt = f"""{project_context}
@@ -309,14 +310,14 @@ async def generate_career_system(
7. 只返回纯JSON,不要添加任何解释文字
"""
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
yield await tracker.generating(0, max(3000, len(prompt) * 8), "调用AI生成新职业...")
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
try:
# 使用流式生成替代非流式
ai_response = ""
chunk_count = 0
last_progress = 10
estimated_total = max(3000, len(prompt) * 8)
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
@@ -325,32 +326,24 @@ async def generate_career_system(
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 平滑更新进度(10-90%AI生成占60%
# 每10个chunk增加约1%的进度,最多到90%
# 平滑更新进度(避免过于频繁
if chunk_count % 10 == 0:
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
current_progress = min(10 + (chunk_count // 10), 90)
if current_progress > last_progress:
last_progress = current_progress
yield await SSEResponse.send_progress(
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
current_progress
)
yield await tracker.generating(len(ai_response), estimated_total)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await tracker.heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
yield await tracker.error(f"AI服务调用失败:{str(ai_error)}")
return
if not ai_response or not ai_response.strip():
yield await SSEResponse.send_error("AI服务返回空响应")
yield await tracker.error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 91)
yield await tracker.parsing("解析AI响应...", 0.5)
# 清洗并解析JSON
try:
@@ -360,10 +353,10 @@ async def generate_career_system(
except json.JSONDecodeError as e:
logger.error(f"❌ 职业体系JSON解析失败: {e}")
logger.error(f" 原始响应预览: {ai_response[:200]}")
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
yield await tracker.error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
yield await tracker.saving("保存主职业到数据库...", 0.3)
# 保存主职业
main_careers_created = []
@@ -395,7 +388,7 @@ async def generate_career_system(
logger.error(f" ❌ 创建主职业失败:{str(e)}")
continue
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
yield await tracker.saving("保存副职业到数据库...", 0.6)
# 保存副职业
sub_careers_created = []
@@ -435,24 +428,24 @@ async def generate_career_system(
logger.info(f"🎉 新职业生成完成:新增主职业{len(main_careers_created)}个,新增副职业{len(sub_careers_created)}")
logger.info(f" 职业体系总数:主职业{total_main}个,副职业{total_sub}")
yield await SSEResponse.send_progress(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)", 100, "success")
yield await tracker.complete(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)")
# 发送结果数据
yield await SSEResponse.send_result({
yield await tracker.result({
"main_careers_count": len(main_careers_created),
"sub_careers_count": len(sub_careers_created),
"main_careers": main_careers_created,
"sub_careers": sub_careers_created
})
yield await SSEResponse.send_done()
yield await tracker.done()
except HTTPException as he:
logger.error(f"HTTP异常: {he.detail}")
yield await SSEResponse.send_error(he.detail, he.status_code)
yield await tracker.error(he.detail, he.status_code)
except Exception as e:
logger.error(f"生成职业体系失败: {str(e)}")
yield await SSEResponse.send_error(f"生成新职业失败: {str(e)}")
yield await tracker.error(f"生成新职业失败: {str(e)}")
return create_sse_response(generate())
@@ -935,4 +928,4 @@ async def remove_sub_career(
logger.info(f"✅ 删除副职业成功:角色{character.name}移除职业{career_id}")
return {"message": "副职业删除成功"}
return {"message": "副职业删除成功"}