update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词

This commit is contained in:
xiamuceer
2025-12-28 19:35:23 +08:00
parent f32e51b594
commit 89848e2258
40 changed files with 2752 additions and 1824 deletions
+120 -20
View File
@@ -662,10 +662,10 @@ async def generate_character_stream(
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(request.project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成角色...", 0)
yield await SSEResponse.send_progress("开始生成角色...", 1)
# 获取已存在的角色列表
yield await SSEResponse.send_progress("获取项目上下文...", 10)
yield await SSEResponse.send_progress("获取项目上下文...", 2)
existing_chars_result = await db.execute(
select(Character)
@@ -757,7 +757,7 @@ async def generate_character_stream(
- 其他要求:{request.requirements or ''}
"""
yield await SSEResponse.send_progress("构建AI提示词...", 20)
yield await SSEResponse.send_progress("构建AI提示词...", 3)
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
@@ -768,11 +768,14 @@ async def generate_character_stream(
user_input=user_input
)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
try:
# 🔧 MCP工具增强:静默检查并收集参考资料
ai_response = ""
chunk_count = 0
if user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
@@ -789,7 +792,7 @@ async def generate_character_stream(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -797,22 +800,119 @@ async def generate_character_stream(
if isinstance(result, dict):
ai_response = result.get('content', '')
if result.get('tool_calls_made', 0) > 0:
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
finish_reason = result.get('finish_reason', '')
tool_calls_made = result.get('tool_calls_made', 0)
# 🔧 修复:检查工具调用是否真正成功
if tool_calls_made > 0:
if finish_reason == 'tool_error':
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
# 工具调用失败,重新用基础模式生成
ai_response = ""
elif not ai_response.strip():
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
# 工具调用成功但返回空内容,重新生成
ai_response = ""
else:
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
# MCP成功且有内容,模拟流式输出(分块发送)
chunk_size = 50
for i in range(0, len(ai_response), chunk_size):
chunk = ai_response[i:i+chunk_size]
chunk_count += 1
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 3 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
)
# 跳过后续的流式生成
ai_response = result.get('content', '')
else:
ai_response = result
# 如果MCP调用失败或返回空,继续走流式生成
if not ai_response or not ai_response.strip():
logger.info(f"🔄 开始流式生成...")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as mcp_error:
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"未登录用户,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
@@ -823,7 +923,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 60)
yield await SSEResponse.send_progress("解析AI响应...", 96)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -836,7 +936,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建角色记录...", 75)
yield await SSEResponse.send_progress("创建角色记录...", 97)
# 转换traits
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
# 如果是组织,创建Organization详情
if is_organization:
yield await SSEResponse.send_progress("创建组织详情...", 85)
yield await SSEResponse.send_progress("创建组织详情...", 98)
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
yield await SSEResponse.send_progress("保存生成历史...", 95)
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
project_id=request.project_id,
prompt=prompt,
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
generated_content=ai_response,
model=user_ai_service.default_model
)
db.add(history)