update:使用标准化SSE统一进度推送逻辑
This commit is contained in:
+23
-30
@@ -7,7 +7,7 @@ import json
|
|||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||||
from app.models.career import Career, CharacterCareer
|
from app.models.career import Career, CharacterCareer
|
||||||
from app.models.character import Character
|
from app.models.character import Character
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
@@ -190,15 +190,16 @@ async def generate_career_system(
|
|||||||
通过Server-Sent Events返回实时进度信息
|
通过Server-Sent Events返回实时进度信息
|
||||||
"""
|
"""
|
||||||
async def generate() -> AsyncGenerator[str, None]:
|
async def generate() -> AsyncGenerator[str, None]:
|
||||||
|
tracker = WizardProgressTracker("职业体系")
|
||||||
try:
|
try:
|
||||||
# 验证用户权限和项目是否存在
|
# 验证用户权限和项目是否存在
|
||||||
user_id = getattr(http_request.state, 'user_id', None)
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
project = await verify_project_access(project_id, user_id, db)
|
project = await verify_project_access(project_id, user_id, db)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("开始生成新职业...", 0)
|
yield await tracker.start()
|
||||||
|
|
||||||
# 获取已有职业列表
|
# 获取已有职业列表
|
||||||
yield await SSEResponse.send_progress("分析已有职业...", 5)
|
yield await tracker.loading("分析已有职业...", 0.3)
|
||||||
|
|
||||||
existing_careers_result = await db.execute(
|
existing_careers_result = await db.execute(
|
||||||
select(Career).where(Career.project_id == project_id)
|
select(Career).where(Career.project_id == project_id)
|
||||||
@@ -228,7 +229,7 @@ async def generate_career_system(
|
|||||||
existing_careers_text = "\n当前还没有任何职业,这是第一次创建职业体系。"
|
existing_careers_text = "\n当前还没有任何职业,这是第一次创建职业体系。"
|
||||||
|
|
||||||
# 构建项目上下文
|
# 构建项目上下文
|
||||||
yield await SSEResponse.send_progress("分析项目世界观...", 15)
|
yield await tracker.loading("分析项目世界观...", 0.6)
|
||||||
|
|
||||||
project_context = f"""
|
project_context = f"""
|
||||||
项目信息:
|
项目信息:
|
||||||
@@ -253,7 +254,7 @@ async def generate_career_system(
|
|||||||
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
|
- 副职业可以更加自由灵活,包含生产、辅助、特殊类型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
yield await tracker.preparing("构建AI提示词...")
|
||||||
|
|
||||||
# 构建提示词
|
# 构建提示词
|
||||||
prompt = f"""{project_context}
|
prompt = f"""{project_context}
|
||||||
@@ -309,14 +310,14 @@ async def generate_career_system(
|
|||||||
7. 只返回纯JSON,不要添加任何解释文字
|
7. 只返回纯JSON,不要添加任何解释文字
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
|
yield await tracker.generating(0, max(3000, len(prompt) * 8), "调用AI生成新职业...")
|
||||||
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用流式生成替代非流式
|
# 使用流式生成替代非流式
|
||||||
ai_response = ""
|
ai_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
last_progress = 10
|
estimated_total = max(3000, len(prompt) * 8)
|
||||||
|
|
||||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
@@ -325,32 +326,24 @@ async def generate_career_system(
|
|||||||
# 发送内容块
|
# 发送内容块
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
# 平滑更新进度(10-90%,AI生成占60%)
|
# 平滑更新进度(避免过于频繁)
|
||||||
# 每10个chunk增加约1%的进度,最多到90%
|
|
||||||
if chunk_count % 10 == 0:
|
if chunk_count % 10 == 0:
|
||||||
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
|
yield await tracker.generating(len(ai_response), estimated_total)
|
||||||
current_progress = min(10 + (chunk_count // 10), 90)
|
|
||||||
if current_progress > last_progress:
|
|
||||||
last_progress = current_progress
|
|
||||||
yield await SSEResponse.send_progress(
|
|
||||||
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
|
|
||||||
current_progress
|
|
||||||
)
|
|
||||||
|
|
||||||
# 心跳
|
# 心跳
|
||||||
if chunk_count % 20 == 0:
|
if chunk_count % 20 == 0:
|
||||||
yield await SSEResponse.send_heartbeat()
|
yield await tracker.heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
yield await tracker.error(f"AI服务调用失败:{str(ai_error)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not ai_response or not ai_response.strip():
|
if not ai_response or not ai_response.strip():
|
||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await tracker.error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 91)
|
yield await tracker.parsing("解析AI响应...", 0.5)
|
||||||
|
|
||||||
# 清洗并解析JSON
|
# 清洗并解析JSON
|
||||||
try:
|
try:
|
||||||
@@ -360,10 +353,10 @@ async def generate_career_system(
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
||||||
logger.error(f" 原始响应预览: {ai_response[:200]}")
|
logger.error(f" 原始响应预览: {ai_response[:200]}")
|
||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await tracker.error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
|
yield await tracker.saving("保存主职业到数据库...", 0.3)
|
||||||
|
|
||||||
# 保存主职业
|
# 保存主职业
|
||||||
main_careers_created = []
|
main_careers_created = []
|
||||||
@@ -395,7 +388,7 @@ async def generate_career_system(
|
|||||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
|
yield await tracker.saving("保存副职业到数据库...", 0.6)
|
||||||
|
|
||||||
# 保存副职业
|
# 保存副职业
|
||||||
sub_careers_created = []
|
sub_careers_created = []
|
||||||
@@ -435,24 +428,24 @@ async def generate_career_system(
|
|||||||
logger.info(f"🎉 新职业生成完成:新增主职业{len(main_careers_created)}个,新增副职业{len(sub_careers_created)}个")
|
logger.info(f"🎉 新职业生成完成:新增主职业{len(main_careers_created)}个,新增副职业{len(sub_careers_created)}个")
|
||||||
logger.info(f" 职业体系总数:主职业{total_main}个,副职业{total_sub}个")
|
logger.info(f" 职业体系总数:主职业{total_main}个,副职业{total_sub}个")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)", 100, "success")
|
yield await tracker.complete(f"新职业生成完成!(主职业{total_main}个,副职业{total_sub}个)")
|
||||||
|
|
||||||
# 发送结果数据
|
# 发送结果数据
|
||||||
yield await SSEResponse.send_result({
|
yield await tracker.result({
|
||||||
"main_careers_count": len(main_careers_created),
|
"main_careers_count": len(main_careers_created),
|
||||||
"sub_careers_count": len(sub_careers_created),
|
"sub_careers_count": len(sub_careers_created),
|
||||||
"main_careers": main_careers_created,
|
"main_careers": main_careers_created,
|
||||||
"sub_careers": sub_careers_created
|
"sub_careers": sub_careers_created
|
||||||
})
|
})
|
||||||
|
|
||||||
yield await SSEResponse.send_done()
|
yield await tracker.done()
|
||||||
|
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
logger.error(f"HTTP异常: {he.detail}")
|
logger.error(f"HTTP异常: {he.detail}")
|
||||||
yield await SSEResponse.send_error(he.detail, he.status_code)
|
yield await tracker.error(he.detail, he.status_code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成职业体系失败: {str(e)}")
|
logger.error(f"生成职业体系失败: {str(e)}")
|
||||||
yield await SSEResponse.send_error(f"生成新职业失败: {str(e)}")
|
yield await tracker.error(f"生成新职业失败: {str(e)}")
|
||||||
|
|
||||||
return create_sse_response(generate())
|
return create_sse_response(generate())
|
||||||
|
|
||||||
@@ -935,4 +928,4 @@ async def remove_sub_career(
|
|||||||
|
|
||||||
logger.info(f"✅ 删除副职业成功:角色{character.name}移除职业{career_id}")
|
logger.info(f"✅ 删除副职业成功:角色{character.name}移除职业{career_id}")
|
||||||
|
|
||||||
return {"message": "副职业删除成功"}
|
return {"message": "副职业删除成功"}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import json
|
|||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||||
from app.models.character import Character
|
from app.models.character import Character
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
from app.models.generation_history import GenerationHistory
|
from app.models.generation_history import GenerationHistory
|
||||||
@@ -660,15 +660,16 @@ async def generate_character_stream(
|
|||||||
通过Server-Sent Events返回实时进度信息
|
通过Server-Sent Events返回实时进度信息
|
||||||
"""
|
"""
|
||||||
async def generate() -> AsyncGenerator[str, None]:
|
async def generate() -> AsyncGenerator[str, None]:
|
||||||
|
tracker = WizardProgressTracker("角色")
|
||||||
try:
|
try:
|
||||||
# 验证用户权限和项目是否存在
|
# 验证用户权限和项目是否存在
|
||||||
user_id = getattr(http_request.state, 'user_id', None)
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
project = await verify_project_access(request.project_id, user_id, db)
|
project = await verify_project_access(request.project_id, user_id, db)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("开始生成角色...", 1)
|
yield await tracker.start()
|
||||||
|
|
||||||
# 获取已存在的角色列表
|
# 获取已存在的角色列表
|
||||||
yield await SSEResponse.send_progress("获取项目上下文...", 2)
|
yield await tracker.loading("获取项目上下文...", 0.3)
|
||||||
|
|
||||||
existing_chars_result = await db.execute(
|
existing_chars_result = await db.execute(
|
||||||
select(Character)
|
select(Character)
|
||||||
@@ -760,7 +761,8 @@ async def generate_character_stream(
|
|||||||
- 其他要求:{request.requirements or '无'}
|
- 其他要求:{request.requirements or '无'}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("构建AI提示词...", 3)
|
yield await tracker.loading("项目上下文准备完成", 0.7)
|
||||||
|
yield await tracker.preparing("构建AI提示词...")
|
||||||
|
|
||||||
# 获取自定义提示词模板
|
# 获取自定义提示词模板
|
||||||
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
||||||
@@ -771,16 +773,17 @@ async def generate_character_stream(
|
|||||||
user_input=user_input
|
user_input=user_input
|
||||||
)
|
)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
|
yield await tracker.generating(0, max(3000, len(prompt) * 8), "调用AI服务生成角色...")
|
||||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 直接使用 AIService 流式生成
|
# 直接使用 AIService 流式生成
|
||||||
ai_response = ""
|
ai_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
estimated_total = max(3000, len(prompt) * 8)
|
||||||
|
|
||||||
logger.info(f"🎯 开始生成角色(流式模式)...")
|
logger.info(f"🎯 开始生成角色(流式模式)...")
|
||||||
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
|
yield await tracker.generating(0, estimated_total, "开始生成角色...")
|
||||||
|
|
||||||
async for chunk in user_ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -802,29 +805,22 @@ async def generate_character_stream(
|
|||||||
current_len = len(ai_response)
|
current_len = len(ai_response)
|
||||||
if current_len >= chunk_count * 500:
|
if current_len >= chunk_count * 500:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
|
yield await tracker.generating(current_len, estimated_total)
|
||||||
# 估算最终字符数约为提示词的8倍,最少3000字符
|
|
||||||
estimated_total = max(3000, len(prompt) * 8)
|
|
||||||
progress = min(15 + int(current_len / estimated_total * 70), 85)
|
|
||||||
yield await SSEResponse.send_progress(
|
|
||||||
f"AI生成角色中... ({current_len}字符)",
|
|
||||||
progress
|
|
||||||
)
|
|
||||||
|
|
||||||
# 心跳
|
# 心跳
|
||||||
if chunk_count % 20 == 0:
|
if chunk_count % 20 == 0:
|
||||||
yield await SSEResponse.send_heartbeat()
|
yield await tracker.heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
yield await tracker.error(f"AI服务调用失败:{str(ai_error)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not ai_response or not ai_response.strip():
|
if not ai_response or not ai_response.strip():
|
||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await tracker.error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
yield await tracker.parsing("解析AI响应...", 0.5)
|
||||||
|
|
||||||
# ✅ 使用统一的 JSON 清洗方法
|
# ✅ 使用统一的 JSON 清洗方法
|
||||||
try:
|
try:
|
||||||
@@ -834,10 +830,10 @@ async def generate_character_stream(
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"❌ 角色JSON解析失败: {e}")
|
logger.error(f"❌ 角色JSON解析失败: {e}")
|
||||||
logger.error(f" 原始响应预览: {ai_response[:200]}")
|
logger.error(f" 原始响应预览: {ai_response[:200]}")
|
||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await tracker.error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建角色记录...", 97)
|
yield await tracker.saving("创建角色记录...", 0.3)
|
||||||
|
|
||||||
# 转换traits
|
# 转换traits
|
||||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||||
@@ -1002,7 +998,7 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
# 如果是组织,创建Organization详情
|
# 如果是组织,创建Organization详情
|
||||||
if is_organization:
|
if is_organization:
|
||||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
yield await tracker.saving("创建组织详情...", 0.6)
|
||||||
|
|
||||||
org_check = await db.execute(
|
org_check = await db.execute(
|
||||||
select(Organization).where(Organization.character_id == character.id)
|
select(Organization).where(Organization.character_id == character.id)
|
||||||
@@ -1169,7 +1165,7 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
yield await tracker.saving("保存生成历史...", 0.9)
|
||||||
|
|
||||||
# 记录生成历史
|
# 记录生成历史
|
||||||
history = GenerationHistory(
|
history = GenerationHistory(
|
||||||
@@ -1185,10 +1181,10 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
logger.info(f"🎉 成功生成角色: {character.name}")
|
logger.info(f"🎉 成功生成角色: {character.name}")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("角色生成完成!", 100, "success")
|
yield await tracker.complete("角色生成完成!")
|
||||||
|
|
||||||
# 发送结果数据
|
# 发送结果数据
|
||||||
yield await SSEResponse.send_result({
|
yield await tracker.result({
|
||||||
"character": {
|
"character": {
|
||||||
"id": character.id,
|
"id": character.id,
|
||||||
"name": character.name,
|
"name": character.name,
|
||||||
@@ -1197,14 +1193,14 @@ async def generate_character_stream(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
yield await SSEResponse.send_done()
|
yield await tracker.done()
|
||||||
|
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
logger.error(f"HTTP异常: {he.detail}")
|
logger.error(f"HTTP异常: {he.detail}")
|
||||||
yield await SSEResponse.send_error(he.detail, he.status_code)
|
yield await tracker.error(he.detail, he.status_code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成角色失败: {str(e)}")
|
logger.error(f"生成角色失败: {str(e)}")
|
||||||
yield await SSEResponse.send_error(f"生成角色失败: {str(e)}")
|
yield await tracker.error(f"生成角色失败: {str(e)}")
|
||||||
|
|
||||||
return create_sse_response(generate())
|
return create_sse_response(generate())
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||||
from app.models.relationship import Organization, OrganizationMember
|
from app.models.relationship import Organization, OrganizationMember
|
||||||
from app.models.character import Character
|
from app.models.character import Character
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
@@ -442,15 +442,16 @@ async def generate_organization_stream(
|
|||||||
通过Server-Sent Events返回实时进度信息
|
通过Server-Sent Events返回实时进度信息
|
||||||
"""
|
"""
|
||||||
async def generate() -> AsyncGenerator[str, None]:
|
async def generate() -> AsyncGenerator[str, None]:
|
||||||
|
tracker = WizardProgressTracker("组织")
|
||||||
try:
|
try:
|
||||||
# 验证用户权限和项目是否存在
|
# 验证用户权限和项目是否存在
|
||||||
user_id = getattr(http_request.state, 'user_id', None)
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
project = await verify_project_access(gen_request.project_id, user_id, db)
|
project = await verify_project_access(gen_request.project_id, user_id, db)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("开始生成组织...", 0)
|
yield await tracker.start()
|
||||||
|
|
||||||
# 获取已存在的角色和组织列表
|
# 获取已存在的角色和组织列表
|
||||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
yield await tracker.loading("获取项目上下文...", 0.3)
|
||||||
|
|
||||||
existing_chars_result = await db.execute(
|
existing_chars_result = await db.execute(
|
||||||
select(Character)
|
select(Character)
|
||||||
@@ -497,7 +498,8 @@ async def generate_organization_stream(
|
|||||||
- 其他要求:{gen_request.requirements or '无'}
|
- 其他要求:{gen_request.requirements or '无'}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("构建AI提示词...", 5)
|
yield await tracker.loading("项目上下文准备完成", 0.7)
|
||||||
|
yield await tracker.preparing("构建AI提示词...")
|
||||||
|
|
||||||
# 获取自定义提示词模板
|
# 获取自定义提示词模板
|
||||||
template = await PromptService.get_template("SINGLE_ORGANIZATION_GENERATION", user_id, db)
|
template = await PromptService.get_template("SINGLE_ORGANIZATION_GENERATION", user_id, db)
|
||||||
@@ -508,13 +510,14 @@ async def generate_organization_stream(
|
|||||||
user_input=user_input
|
user_input=user_input
|
||||||
)
|
)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("调用AI服务生成组织...", 10)
|
yield await tracker.generating(0, max(3000, len(prompt) * 8), "调用AI服务生成组织...")
|
||||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用流式生成替代非流式
|
# 使用流式生成替代非流式
|
||||||
ai_content = ""
|
ai_content = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
estimated_total = max(3000, len(prompt) * 8)
|
||||||
|
|
||||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
@@ -523,28 +526,24 @@ async def generate_organization_stream(
|
|||||||
# 发送内容块
|
# 发送内容块
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
# 定期更新字数(5-95%,AI生成占90%)
|
# 定期更新字数(避免过于频繁)
|
||||||
if chunk_count % 5 == 0:
|
if chunk_count % 5 == 0:
|
||||||
progress = min(10 + (chunk_count // 5), 95)
|
yield await tracker.generating(len(ai_content), estimated_total)
|
||||||
yield await SSEResponse.send_progress(
|
|
||||||
f"AI生成组织中... ({len(ai_content)}字符)",
|
|
||||||
progress
|
|
||||||
)
|
|
||||||
|
|
||||||
# 心跳
|
# 心跳
|
||||||
if chunk_count % 20 == 0:
|
if chunk_count % 20 == 0:
|
||||||
yield await SSEResponse.send_heartbeat()
|
yield await tracker.heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
yield await tracker.error(f"AI服务调用失败:{str(ai_error)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not ai_content or not ai_content.strip():
|
if not ai_content or not ai_content.strip():
|
||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await tracker.error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 90)
|
yield await tracker.parsing("解析AI响应...", 0.5)
|
||||||
|
|
||||||
# ✅ 使用统一的 JSON 清洗方法
|
# ✅ 使用统一的 JSON 清洗方法
|
||||||
try:
|
try:
|
||||||
@@ -554,10 +553,10 @@ async def generate_organization_stream(
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"❌ 组织JSON解析失败: {e}")
|
logger.error(f"❌ 组织JSON解析失败: {e}")
|
||||||
logger.error(f" 原始响应预览: {ai_content[:200]}")
|
logger.error(f" 原始响应预览: {ai_content[:200]}")
|
||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await tracker.error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建组织记录...", 95)
|
yield await tracker.saving("创建组织记录...", 0.3)
|
||||||
|
|
||||||
# 创建角色记录(组织也是角色的一种)
|
# 创建角色记录(组织也是角色的一种)
|
||||||
character = Character(
|
character = Character(
|
||||||
@@ -584,7 +583,7 @@ async def generate_organization_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
yield await tracker.saving("创建组织详情...", 0.6)
|
||||||
|
|
||||||
# 自动创建Organization详情记录
|
# 自动创建Organization详情记录
|
||||||
organization = Organization(
|
organization = Organization(
|
||||||
@@ -601,7 +600,7 @@ async def generate_organization_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
yield await tracker.saving("保存生成历史...", 0.9)
|
||||||
|
|
||||||
# 记录生成历史
|
# 记录生成历史
|
||||||
history = GenerationHistory(
|
history = GenerationHistory(
|
||||||
@@ -617,10 +616,10 @@ async def generate_organization_stream(
|
|||||||
|
|
||||||
logger.info(f"🎉 成功生成组织: {character.name}")
|
logger.info(f"🎉 成功生成组织: {character.name}")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("组织生成完成!", 100, "success")
|
yield await tracker.complete("组织生成完成!")
|
||||||
|
|
||||||
# 发送结果数据
|
# 发送结果数据
|
||||||
yield await SSEResponse.send_result({
|
yield await tracker.result({
|
||||||
"character": {
|
"character": {
|
||||||
"id": character.id,
|
"id": character.id,
|
||||||
"name": character.name,
|
"name": character.name,
|
||||||
@@ -629,13 +628,13 @@ async def generate_organization_stream(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
yield await SSEResponse.send_done()
|
yield await tracker.done()
|
||||||
|
|
||||||
except HTTPException as he:
|
except HTTPException as he:
|
||||||
logger.error(f"HTTP异常: {he.detail}")
|
logger.error(f"HTTP异常: {he.detail}")
|
||||||
yield await SSEResponse.send_error(he.detail, he.status_code)
|
yield await tracker.error(he.detail, he.status_code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成组织失败: {str(e)}")
|
logger.error(f"生成组织失败: {str(e)}")
|
||||||
yield await SSEResponse.send_error(f"生成组织失败: {str(e)}")
|
yield await tracker.error(f"生成组织失败: {str(e)}")
|
||||||
|
|
||||||
return create_sse_response(generate())
|
return create_sse_response(generate())
|
||||||
|
|||||||
Reference in New Issue
Block a user