update:使用标准化SSE统一进度推送逻辑

This commit is contained in:
xiamuceer
2026-01-09 20:57:20 +08:00
parent f4f2caa367
commit 1b32d87581
3 changed files with 69 additions and 81 deletions
+23 -24
View File
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
import json
from app.database import get_db
from app.utils.sse_response import SSEResponse, create_sse_response
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
from app.models.relationship import Organization, OrganizationMember
from app.models.character import Character
from app.models.project import Project
@@ -442,15 +442,16 @@ async def generate_organization_stream(
通过Server-Sent Events返回实时进度信息
"""
async def generate() -> AsyncGenerator[str, None]:
tracker = WizardProgressTracker("组织")
try:
# 验证用户权限和项目是否存在
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(gen_request.project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成组织...", 0)
yield await tracker.start()
# 获取已存在的角色和组织列表
yield await SSEResponse.send_progress("获取项目上下文...", 10)
yield await tracker.loading("获取项目上下文...", 0.3)
existing_chars_result = await db.execute(
select(Character)
@@ -497,7 +498,8 @@ async def generate_organization_stream(
- 其他要求:{gen_request.requirements or ''}
"""
yield await SSEResponse.send_progress("构建AI提示词...", 5)
yield await tracker.loading("项目上下文准备完成", 0.7)
yield await tracker.preparing("构建AI提示词...")
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_ORGANIZATION_GENERATION", user_id, db)
@@ -508,13 +510,14 @@ async def generate_organization_stream(
user_input=user_input
)
yield await SSEResponse.send_progress("调用AI服务生成组织...", 10)
yield await tracker.generating(0, max(3000, len(prompt) * 8), "调用AI服务生成组织...")
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
try:
# 使用流式生成替代非流式
ai_content = ""
chunk_count = 0
estimated_total = max(3000, len(prompt) * 8)
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
@@ -523,28 +526,24 @@ async def generate_organization_stream(
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新字数(5-95%AI生成占90%
# 定期更新字数(避免过于频繁
if chunk_count % 5 == 0:
progress = min(10 + (chunk_count // 5), 95)
yield await SSEResponse.send_progress(
f"AI生成组织中... ({len(ai_content)}字符)",
progress
)
yield await tracker.generating(len(ai_content), estimated_total)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await tracker.heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
yield await tracker.error(f"AI服务调用失败:{str(ai_error)}")
return
if not ai_content or not ai_content.strip():
yield await SSEResponse.send_error("AI服务返回空响应")
yield await tracker.error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 90)
yield await tracker.parsing("解析AI响应...", 0.5)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -554,10 +553,10 @@ async def generate_organization_stream(
except json.JSONDecodeError as e:
logger.error(f"❌ 组织JSON解析失败: {e}")
logger.error(f" 原始响应预览: {ai_content[:200]}")
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
yield await tracker.error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建组织记录...", 95)
yield await tracker.saving("创建组织记录...", 0.3)
# 创建角色记录(组织也是角色的一种)
character = Character(
@@ -584,7 +583,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
yield await SSEResponse.send_progress("创建组织详情...", 98)
yield await tracker.saving("创建组织详情...", 0.6)
# 自动创建Organization详情记录
organization = Organization(
@@ -601,7 +600,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
yield await SSEResponse.send_progress("保存生成历史...", 99)
yield await tracker.saving("保存生成历史...", 0.9)
# 记录生成历史
history = GenerationHistory(
@@ -617,10 +616,10 @@ async def generate_organization_stream(
logger.info(f"🎉 成功生成组织: {character.name}")
yield await SSEResponse.send_progress("组织生成完成!", 100, "success")
yield await tracker.complete("组织生成完成!")
# 发送结果数据
yield await SSEResponse.send_result({
yield await tracker.result({
"character": {
"id": character.id,
"name": character.name,
@@ -629,13 +628,13 @@ async def generate_organization_stream(
}
})
yield await SSEResponse.send_done()
yield await tracker.done()
except HTTPException as he:
logger.error(f"HTTP异常: {he.detail}")
yield await SSEResponse.send_error(he.detail, he.status_code)
yield await tracker.error(he.detail, he.status_code)
except Exception as e:
logger.error(f"生成组织失败: {str(e)}")
yield await SSEResponse.send_error(f"生成组织失败: {str(e)}")
yield await tracker.error(f"生成组织失败: {str(e)}")
return create_sse_response(generate())