update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
+120
-20
@@ -662,10 +662,10 @@ async def generate_character_stream(
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(request.project_id, user_id, db)
|
||||
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 0)
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 1)
|
||||
|
||||
# 获取已存在的角色列表
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 2)
|
||||
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
@@ -757,7 +757,7 @@ async def generate_character_stream(
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 3)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
||||
@@ -768,11 +768,14 @@ async def generate_character_stream(
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
@@ -789,7 +792,7 @@ async def generate_character_stream(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -797,22 +800,119 @@ async def generate_character_stream(
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
if result.get('tool_calls_made', 0) > 0:
|
||||
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
@@ -823,7 +923,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
@@ -836,7 +936,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 75)
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 97)
|
||||
|
||||
# 转换traits
|
||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
|
||||
|
||||
# 如果是组织,创建Organization详情
|
||||
if is_organization:
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
org_check = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
|
||||
generated_content=ai_response,
|
||||
model=user_ai_service.default_model
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
Reference in New Issue
Block a user