feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
+28
-130
@@ -775,148 +775,46 @@ async def generate_character_stream(
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
# 直接使用 AIService 流式生成
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
available_tools = await mcp_tool_service.get_user_enabled_tools(
|
||||
user_id=user_id,
|
||||
db_session=db
|
||||
)
|
||||
|
||||
# 只在有工具时才调用
|
||||
if available_tools:
|
||||
logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...")
|
||||
result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
logger.info(f"🎯 开始生成角色(流式模式)...")
|
||||
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
tool_choice="required",
|
||||
):
|
||||
# chunk 现在可能是 dict 或 str,提取 content 字段
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
else:
|
||||
content = chunk
|
||||
|
||||
if content:
|
||||
ai_response += content
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
yield await SSEResponse.send_chunk(content)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
# 定期更新进度(每收到约500字符更新一次,避免过于频繁)
|
||||
current_len = len(ai_response)
|
||||
if current_len >= chunk_count * 500:
|
||||
chunk_count += 1
|
||||
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
|
||||
# 估算最终字符数约为提示词的8倍,最少3000字符
|
||||
estimated_total = max(3000, len(prompt) * 8)
|
||||
progress = min(15 + int(current_len / estimated_total * 70), 85)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
f"AI生成角色中... ({current_len}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
|
||||
Reference in New Issue
Block a user