update:1.更新mcp插件功能,目前只支持remote调用
This commit is contained in:
@@ -823,12 +823,14 @@ async def generate_chapter_content_stream(
|
||||
请求体参数:
|
||||
- style_id: 可选,指定使用的写作风格ID。不提供则不使用任何风格
|
||||
- target_word_count: 可选,目标字数,默认3000字,范围500-10000字
|
||||
- enable_mcp: 可选,是否启用MCP工具增强,默认True
|
||||
|
||||
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
|
||||
以避免流式响应期间的连接泄漏问题
|
||||
"""
|
||||
style_id = generate_request.style_id
|
||||
target_word_count = generate_request.target_word_count or 3000
|
||||
enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
@@ -1002,7 +1004,60 @@ async def generate_chapter_content_stream(
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格和记忆增强
|
||||
# 🔧 MCP工具增强:收集章节参考资料
|
||||
mcp_reference_materials = ""
|
||||
if enable_mcp and current_user_id:
|
||||
try:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 尝试使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。
|
||||
|
||||
【章节大纲】
|
||||
{outline.content if outline else current_chapter.summary or '暂无大纲'}
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
- 时代背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。
|
||||
你可以查询:
|
||||
1. 该章节涉及的历史事件或时代背景
|
||||
2. 地理环境和场景描写参考
|
||||
3. 相关领域的专业知识(如武术、科技、魔法等)
|
||||
4. 文化习俗和生活细节
|
||||
|
||||
请根据章节内容,有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=current_user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
tool_count = planning_result["tool_calls_made"]
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': 'ℹ️ 未使用MCP工具(无可用工具或不需要)', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 32}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格、记忆增强和MCP参考资料
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
@@ -1021,7 +1076,8 @@ async def generate_chapter_content_stream(
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
@@ -1040,9 +1096,13 @@ async def generate_chapter_content_stream(
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
if mcp_reference_materials:
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 流式生成内容
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""角色管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
@@ -221,6 +221,7 @@ async def delete_character(
|
||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成角色")
|
||||
async def generate_character(
|
||||
request: CharacterGenerateRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -294,18 +295,42 @@ async def generate_character(
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
# 调用AI生成角色
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色")
|
||||
# 获取user_id用于MCP工具调用
|
||||
user_id = http_request.state.user_id if hasattr(http_request.state, 'user_id') else 'default_user'
|
||||
|
||||
# 调用AI生成角色(支持MCP工具)
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP)")
|
||||
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
||||
logger.info(f" - 角色定位:{request.role_type}")
|
||||
logger.info(f" - 背景设定:{request.background or '无'}")
|
||||
logger.info(f" - AI提供商:{user_ai_service.api_provider}")
|
||||
logger.info(f" - AI模型:{user_ai_service.default_model}")
|
||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||
logger.info(f" - 用户ID:{user_id}")
|
||||
|
||||
try:
|
||||
ai_response = await user_ai_service.generate_text(prompt=prompt)
|
||||
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
|
||||
# 使用支持MCP的生成方法
|
||||
result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None, # 使用AIService初始化时的配置
|
||||
model=None # 使用AIService初始化时的配置
|
||||
)
|
||||
|
||||
# 提取内容
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
logger.info(f"✅ AI响应接收完成(MCP增强),长度:{len(ai_response)} 字符")
|
||||
if result.get('tool_calls'):
|
||||
logger.info(f" - 工具调用:{len(result['tool_calls'])} 次")
|
||||
else:
|
||||
ai_response = result
|
||||
logger.info(f"✅ AI响应接收完成,长度:{len(ai_response) if ai_response else 0} 字符")
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
raise HTTPException(
|
||||
@@ -559,7 +584,7 @@ async def generate_character(
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
|
||||
model=user_ai_service.default_model
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
@@ -0,0 +1,862 @@
|
||||
"""MCP插件管理API"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.schemas.mcp_plugin import (
|
||||
MCPPluginCreate,
|
||||
MCPPluginSimpleCreate,
|
||||
MCPPluginUpdate,
|
||||
MCPPluginResponse,
|
||||
MCPToolCall,
|
||||
MCPTestResult
|
||||
)
|
||||
import json
|
||||
from app.user_manager import User
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.models.settings import Settings as UserSettings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||
|
||||
|
||||
def require_login(request: Request) -> User:
|
||||
"""依赖:要求用户已登录"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
|
||||
@router.get("", response_model=List[MCPPluginResponse])
|
||||
async def list_plugins(
|
||||
enabled_only: bool = Query(False, description="只返回启用的插件"),
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用户的所有MCP插件
|
||||
"""
|
||||
query = select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(MCPPlugin.enabled == True)
|
||||
|
||||
if category:
|
||||
query = query.where(MCPPlugin.category == category)
|
||||
|
||||
query = query.order_by(MCPPlugin.sort_order, MCPPlugin.created_at)
|
||||
|
||||
result = await db.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 查询插件列表,共 {len(plugins)} 个")
|
||||
return plugins
|
||||
|
||||
|
||||
@router.post("", response_model=MCPPluginResponse)
|
||||
async def create_plugin(
|
||||
data: MCPPluginCreate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新的MCP插件
|
||||
"""
|
||||
# 检查插件名是否已存在
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user.user_id,
|
||||
MCPPlugin.plugin_name == data.plugin_name
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"插件名已存在: {data.plugin_name}")
|
||||
|
||||
# 创建插件数据
|
||||
plugin_data = data.model_dump()
|
||||
|
||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||
if not plugin_data.get("display_name"):
|
||||
plugin_data["display_name"] = plugin_data["plugin_name"]
|
||||
|
||||
# 创建插件
|
||||
plugin = MCPPlugin(
|
||||
user_id=user.user_id,
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.post("/simple", response_model=MCPPluginResponse)
|
||||
async def create_plugin_simple(
|
||||
data: MCPPluginSimpleCreate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
通过标准MCP配置JSON创建或更新插件(简化版)
|
||||
|
||||
接受格式:
|
||||
{
|
||||
"config_json": '{"mcpServers": {"exa": {"type": "http", "url": "...", "headers": {}}}}',
|
||||
"category": "search"
|
||||
}
|
||||
|
||||
自动从mcpServers中提取插件名称(取第一个键)
|
||||
如果插件已存在,则更新;否则创建新插件
|
||||
"""
|
||||
try:
|
||||
# 解析配置JSON
|
||||
config = json.loads(data.config_json)
|
||||
|
||||
# 验证格式
|
||||
if "mcpServers" not in config:
|
||||
raise HTTPException(status_code=400, detail="配置JSON必须包含mcpServers字段")
|
||||
|
||||
servers = config["mcpServers"]
|
||||
if not servers or len(servers) == 0:
|
||||
raise HTTPException(status_code=400, detail="mcpServers不能为空")
|
||||
|
||||
# 自动提取第一个插件名称
|
||||
plugin_name = list(servers.keys())[0]
|
||||
server_config = servers[plugin_name]
|
||||
|
||||
logger.info(f"从配置中提取插件名称: {plugin_name}")
|
||||
|
||||
# 提取配置
|
||||
server_type = server_config.get("type", "http")
|
||||
|
||||
if server_type not in ["http", "stdio"]:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
|
||||
|
||||
# 检查插件名是否已存在
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user.user_id,
|
||||
MCPPlugin.plugin_name == plugin_name
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
# 构建插件数据
|
||||
plugin_data = {
|
||||
"plugin_name": plugin_name,
|
||||
"display_name": plugin_name,
|
||||
"plugin_type": server_type,
|
||||
"enabled": data.enabled,
|
||||
"category": data.category,
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type == "http":
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
plugin_data["args"] = server_config.get("args", [])
|
||||
plugin_data["env"] = server_config.get("env", {})
|
||||
|
||||
if not plugin_data["command"]:
|
||||
raise HTTPException(status_code=400, detail="Stdio类型插件必须提供command字段")
|
||||
|
||||
if existing:
|
||||
# 更新现有插件
|
||||
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
|
||||
|
||||
# 先卸载旧插件
|
||||
if existing.enabled:
|
||||
await mcp_registry.unload_plugin(user.user_id, existing.plugin_name)
|
||||
|
||||
# 更新字段
|
||||
for key, value in plugin_data.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
plugin = existing
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,重新加载
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
|
||||
else:
|
||||
# 创建新插件
|
||||
plugin = MCPPlugin(
|
||||
user_id=user.user_id,
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,加载到注册表
|
||||
if plugin.enabled:
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
|
||||
|
||||
return plugin
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置JSON格式错误: {str(e)}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建插件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创建插件失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=MCPPluginResponse)
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取插件详情
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
return plugin
|
||||
|
||||
|
||||
@router.put("/{plugin_id}", response_model=MCPPluginResponse)
|
||||
async def update_plugin(
|
||||
plugin_id: str,
|
||||
data: MCPPluginUpdate,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新插件配置
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果插件已启用,重新加载
|
||||
if plugin.enabled:
|
||||
await mcp_registry.reload_plugin(plugin)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}")
|
||||
async def delete_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除插件
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 删除数据库记录
|
||||
await db.delete(plugin)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 删除插件: {plugin.plugin_name}")
|
||||
return {"message": "插件已删除", "plugin_name": plugin.plugin_name}
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/toggle", response_model=MCPPluginResponse)
|
||||
async def toggle_plugin(
|
||||
plugin_id: str,
|
||||
enabled: bool = Query(..., description="启用或禁用"),
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
启用或禁用插件
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
plugin.enabled = enabled
|
||||
|
||||
if enabled:
|
||||
# 启用:加载到注册表
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
else:
|
||||
# 禁用:从注册表卸载
|
||||
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
|
||||
plugin.status = "inactive"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}")
|
||||
return plugin
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/test", response_model=MCPTestResult)
|
||||
async def test_plugin(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
测试流程:
|
||||
1. 测试MCP服务器连接
|
||||
2. 获取工具列表
|
||||
3. 自动选择一个工具进行实际调用测试
|
||||
4. 返回完整测试结果
|
||||
"""
|
||||
import time
|
||||
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件未启用",
|
||||
error="请先启用插件",
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. 确保插件已加载
|
||||
if not mcp_registry.get_client(user.user_id, plugin.plugin_name):
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件加载失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
|
||||
# 2. 测试连接并获取工具列表
|
||||
test_result = await mcp_registry.test_plugin(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not test_result["success"]:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = test_result.get("error", "连接测试失败")
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
return MCPTestResult(**test_result)
|
||||
|
||||
tools = test_result.get("tools", [])
|
||||
|
||||
if not tools:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "插件没有提供任何工具"
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件没有提供任何工具",
|
||||
error="工具列表为空",
|
||||
response_time_ms=test_result.get("response_time_ms"),
|
||||
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
|
||||
)
|
||||
|
||||
# 3. 使用AI智能选择工具并生成测试参数
|
||||
logger.info(f"使用AI分析工具并生成测试计划...")
|
||||
|
||||
# 获取用户的AI设置
|
||||
settings_result = await db.execute(
|
||||
select(UserSettings).where(UserSettings.user_id == user.user_id)
|
||||
)
|
||||
user_settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not user_settings or not user_settings.api_key:
|
||||
# 如果没有AI配置,回退到简单测试
|
||||
logger.warning("用户未配置AI服务,使用简单连接测试")
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
|
||||
response_time_ms=test_result.get("response_time_ms"),
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"连接测试: 成功",
|
||||
f"可用工具数: {len(tools)}",
|
||||
"提示: 配置AI服务后可进行智能工具调用测试"
|
||||
]
|
||||
)
|
||||
|
||||
# 使用AI的标准Function Calling机制选择工具
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=user_settings.api_provider,
|
||||
api_key=user_settings.api_key,
|
||||
api_base_url=user_settings.api_base_url,
|
||||
model_name=user_settings.llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 将MCP工具格式转换为OpenAI Function Calling格式
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
}
|
||||
# 将 inputSchema 转换为 parameters
|
||||
if "inputSchema" in tool:
|
||||
openai_tool["function"]["parameters"] = tool["inputSchema"]
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
|
||||
logger.info(f"转换了 {len(openai_tools)} 个MCP工具为OpenAI格式")
|
||||
logger.info(f"工具列表: {[t['function']['name'] for t in openai_tools]}")
|
||||
|
||||
# 使用标准的Function Calling,将转换后的工具传递给AI
|
||||
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
|
||||
|
||||
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
|
||||
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
|
||||
|
||||
现在开始测试这个插件。"""
|
||||
|
||||
system_prompt = "你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。"
|
||||
|
||||
# 调用AI的Function Calling
|
||||
logger.info(f"📞 准备调用AI Function Calling")
|
||||
logger.info(f" - Provider: {user_settings.api_provider}")
|
||||
logger.info(f" - Model: {user_settings.llm_model}")
|
||||
logger.info(f" - Tools count: {len(openai_tools)}")
|
||||
logger.debug(f" - Tools: {json.dumps(openai_tools, ensure_ascii=False, indent=2)}")
|
||||
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
tools=openai_tools, # 传递转换后的OpenAI格式工具
|
||||
tool_choice="required" # 要求AI必须选择一个工具
|
||||
)
|
||||
|
||||
logger.info(f"📥 收到AI响应")
|
||||
logger.info(f" - Response keys: {list(ai_response.keys())}")
|
||||
logger.debug(f" - Full response: {json.dumps(ai_response, ensure_ascii=False, indent=2)}")
|
||||
|
||||
# 检查AI是否请求调用工具
|
||||
if not ai_response.get("tool_calls"):
|
||||
# AI未调用工具,记录详细信息
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
logger.error(f" - Response: {ai_response}")
|
||||
logger.error(f" - Content: {ai_response.get('content', 'N/A')}")
|
||||
logger.error(f" - Finish reason: {ai_response.get('finish_reason', 'N/A')}")
|
||||
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "AI未返回工具调用请求"
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
"OpenAI: 需要gpt-4, gpt-3.5-turbo等模型",
|
||||
"Anthropic: 需要claude-3系列模型",
|
||||
f"当前Provider: {user_settings.api_provider}",
|
||||
f"当前模型: {user_settings.llm_model}",
|
||||
f"AI返回内容: {ai_response.get('content', 'N/A')[:100]}"
|
||||
]
|
||||
)
|
||||
|
||||
# 获取第一个工具调用
|
||||
tool_call = ai_response["tool_calls"][0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
# AI返回的arguments可能是JSON字符串,需要解析
|
||||
if isinstance(test_arguments, str):
|
||||
try:
|
||||
test_arguments = json.loads(test_arguments)
|
||||
logger.info(f"✅ 解析AI返回的JSON字符串参数")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析AI参数失败: {e}")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI返回的参数格式错误",
|
||||
error=f"无法解析参数JSON: {str(e)}",
|
||||
tools_count=len(tools),
|
||||
suggestions=["AI返回的参数不是有效的JSON格式"]
|
||||
)
|
||||
|
||||
logger.info(f"🤖 AI通过Function Calling选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
logger.info(f"📝 参数类型: {type(test_arguments).__name__}")
|
||||
|
||||
# 4. 使用AI选择的工具和参数调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
tool_name,
|
||||
test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
call_time = round((call_end - call_start) * 1000, 2)
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
# 6. 测试成功,更新插件状态
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools # 缓存工具列表
|
||||
await db.commit()
|
||||
|
||||
# 格式化工具结果用于显示
|
||||
result_str = str(tool_result)
|
||||
|
||||
# 如果结果太长,截取前800字符
|
||||
if len(result_str) > 800:
|
||||
result_preview = result_str[:800] + "\n...(结果已截断,完整结果请查看日志)"
|
||||
else:
|
||||
result_preview = result_str
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"🤖 AI (Function Calling) 选择: {tool_name}",
|
||||
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"⏱️ 调用耗时: {call_time}ms",
|
||||
f"📊 返回结果:\n{result_preview}"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as call_error:
|
||||
call_end = time.time()
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
|
||||
|
||||
# 工具调用失败,但连接成功
|
||||
plugin.status = "active" # 仍标记为active,因为连接是成功的
|
||||
plugin.last_error = f"工具调用测试失败: {str(call_error)}"
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=True, # 连接成功就算测试通过
|
||||
message=f"⚠️ 连接成功,但工具调用失败",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
|
||||
suggestions=[
|
||||
f"✅ 连接测试: 成功",
|
||||
f"❌ 工具调用测试: 失败",
|
||||
f"🤖 AI (Function Calling) 选择: {tool_name}",
|
||||
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"❌ 错误: {str(call_error)}",
|
||||
"💡 可能原因: API Key无效、参数错误或服务限制"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
total_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=total_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=["请检查服务器是否在线", "请确认配置正确", "请检查API Key是否有效"]
|
||||
)
|
||||
|
||||
|
||||
def _build_test_arguments(tool_name: str, input_schema: dict, plugin_name: str) -> dict:
|
||||
"""
|
||||
根据工具schema智能构造测试参数
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
input_schema: 输入schema
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
测试参数字典
|
||||
"""
|
||||
# 针对常见MCP工具的默认测试参数
|
||||
test_cases = {
|
||||
# Exa搜索工具
|
||||
"search": {
|
||||
"query": "AI technology",
|
||||
"num_results": 3
|
||||
},
|
||||
"search_and_contents": {
|
||||
"query": "artificial intelligence",
|
||||
"num_results": 2
|
||||
},
|
||||
# Brave搜索
|
||||
"brave_web_search": {
|
||||
"query": "AI news",
|
||||
"count": 3
|
||||
},
|
||||
# Filesystem工具
|
||||
"read_file": {
|
||||
"path": "README.md"
|
||||
},
|
||||
"list_directory": {
|
||||
"path": "."
|
||||
},
|
||||
}
|
||||
|
||||
# 如果有针对特定工具的测试用例,使用它
|
||||
if tool_name in test_cases:
|
||||
logger.info(f"使用预定义测试参数: {test_cases[tool_name]}")
|
||||
return test_cases[tool_name]
|
||||
|
||||
# 否则根据schema自动构造
|
||||
properties = input_schema.get("properties", {})
|
||||
required = input_schema.get("required", [])
|
||||
|
||||
test_args = {}
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
# 只填充必需的参数
|
||||
if prop_name not in required:
|
||||
continue
|
||||
|
||||
prop_type = prop_schema.get("type", "string")
|
||||
|
||||
# 根据参数名称和类型猜测合适的测试值
|
||||
if "query" in prop_name.lower() or "search" in prop_name.lower():
|
||||
test_args[prop_name] = "test query"
|
||||
elif "url" in prop_name.lower():
|
||||
test_args[prop_name] = "https://example.com"
|
||||
elif "path" in prop_name.lower():
|
||||
test_args[prop_name] = "."
|
||||
elif "count" in prop_name.lower() or "limit" in prop_name.lower() or "num" in prop_name.lower():
|
||||
test_args[prop_name] = 3
|
||||
elif prop_type == "string":
|
||||
test_args[prop_name] = "test"
|
||||
elif prop_type == "number" or prop_type == "integer":
|
||||
test_args[prop_name] = 1
|
||||
elif prop_type == "boolean":
|
||||
test_args[prop_name] = True
|
||||
elif prop_type == "array":
|
||||
test_args[prop_name] = []
|
||||
elif prop_type == "object":
|
||||
test_args[prop_name] = {}
|
||||
|
||||
logger.info(f"自动构造测试参数: {test_args}")
|
||||
return test_args
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/tools")
|
||||
async def get_plugin_tools(
|
||||
plugin_id: str,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取插件提供的工具列表
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tools": tools,
|
||||
"count": len(tools)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/call")
|
||||
async def call_mcp_tool(
|
||||
data: MCPToolCall,
|
||||
user: User = Depends(require_login),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
调用MCP工具
|
||||
"""
|
||||
# 获取插件
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == data.plugin_id,
|
||||
MCPPlugin.user_id == user.user_id
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
if not plugin.enabled:
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 调用工具
|
||||
result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
data.tool_name,
|
||||
data.arguments
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"tool_name": data.tool_name,
|
||||
"result": result
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
@@ -38,6 +38,7 @@ class OrganizationGenerateRequest(BaseModel):
|
||||
organization_type: Optional[str] = Field(None, description="组织类型")
|
||||
background: Optional[str] = Field(None, description="组织背景")
|
||||
requirements: Optional[str] = Field(None, description="特殊要求")
|
||||
enable_mcp: bool = Field(True, description="是否启用MCP工具增强(搜索组织架构参考)")
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||
|
||||
+255
-23
@@ -404,8 +404,8 @@ async def _generate_new_outline(
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, keep_existing: {request.keep_existing}")
|
||||
"""全新生成大纲(MCP增强版)"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
@@ -418,7 +418,59 @@ async def _generate_new_outline(
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 使用完整提示词
|
||||
# 🔍 MCP工具增强:收集情节设计参考资料
|
||||
mcp_reference_materials = ""
|
||||
if request.enable_mcp:
|
||||
try:
|
||||
logger.info(f"🔍 尝试使用MCP工具收集大纲设计参考资料...")
|
||||
|
||||
# 构建资料收集查询
|
||||
planning_query = f"""你正在为小说《{project.title}》设计完整大纲。
|
||||
项目信息:
|
||||
- 主题:{request.theme or project.theme}
|
||||
- 类型:{request.genre or project.genre}
|
||||
- 章节数:{request.chapter_count}
|
||||
- 叙事视角:{request.narrative_perspective}
|
||||
- 目标字数:{request.target_words}
|
||||
|
||||
世界观设定:
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
|
||||
角色信息:
|
||||
{characters_info or '暂无角色'}
|
||||
|
||||
请搜索:
|
||||
1. 该类型小说的经典情节结构和套路
|
||||
2. 适合该主题的冲突设计思路
|
||||
3. 符合世界观的情节元素和场景设计灵感
|
||||
|
||||
请有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_query,
|
||||
user_id="system", # 全新生成时可能没有用户上下文
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
else:
|
||||
logger.info(f"ℹ️ MCP工具未进行调用,继续正常生成")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,继续使用常规模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用完整提示词(插入MCP参考资料)
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
@@ -431,18 +483,22 @@ async def _generate_new_outline(
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
requirements=request.requirements or ""
|
||||
requirements=request.requirements or "",
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
# 调用AI生成大纲
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 全新生成模式:必须删除旧大纲和章节
|
||||
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
|
||||
@@ -463,7 +519,7 @@ async def _generate_new_outline(
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
generated_content=json.dumps(ai_response, ensure_ascii=False) if isinstance(ai_response, dict) else ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
@@ -571,8 +627,8 @@ async def _continue_outline(
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲 - 分批生成,每批5章(记忆增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章")
|
||||
"""续写大纲 - 分批生成,每批5章(记忆+MCP增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}")
|
||||
|
||||
# 分析已有大纲
|
||||
current_chapter_count = len(existing_outlines)
|
||||
@@ -664,7 +720,57 @@ async def _continue_outline(
|
||||
logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}")
|
||||
memory_context = None
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆增强)
|
||||
# 🔍 MCP工具增强:收集续写参考资料
|
||||
mcp_reference_materials = ""
|
||||
if request.enable_mcp:
|
||||
try:
|
||||
logger.info(f"🔍 第{batch_num + 1}批:尝试使用MCP工具收集续写参考资料...")
|
||||
|
||||
# 构建资料收集查询
|
||||
latest_summary = latest_outlines[-1].content if latest_outlines else ""
|
||||
planning_query = f"""你正在为小说《{project.title}》续写大纲。
|
||||
当前进度:已有{len(latest_outlines)}章,即将续写第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章
|
||||
|
||||
项目信息:
|
||||
- 主题:{request.theme or project.theme}
|
||||
- 类型:{request.genre or project.genre}
|
||||
- 叙事视角:{request.narrative_perspective}
|
||||
- 情节阶段:{request.plot_stage}
|
||||
- 故事发展方向:{request.story_direction or '自然延续'}
|
||||
|
||||
最近章节概要:
|
||||
{latest_summary[:200]}
|
||||
|
||||
请搜索:
|
||||
1. 该情节阶段的经典处理手法和技巧
|
||||
2. 适合该发展方向的情节转折和冲突设计
|
||||
3. 符合类型特点的场景设计和剧情元素
|
||||
|
||||
请有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_query,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 第{batch_num + 1}批MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
else:
|
||||
logger.info(f"ℹ️ 第{batch_num + 1}批MCP工具未进行调用,继续正常生成")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,继续使用常规模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
@@ -683,7 +789,8 @@ async def _continue_outline(
|
||||
start_chapter=current_start_chapter,
|
||||
story_direction=request.story_direction or "自然延续",
|
||||
requirements=request.requirements or "",
|
||||
memory_context=memory_context
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
@@ -694,8 +801,11 @@ async def _continue_outline(
|
||||
model=request.model
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 保存当前批次的大纲
|
||||
batch_outlines = await _save_outlines(
|
||||
@@ -706,7 +816,7 @@ async def _continue_outline(
|
||||
history = GenerationHistory(
|
||||
project_id=project.id,
|
||||
prompt=f"[批次{batch_num + 1}/{total_batches}] {str(prompt)[:500]}",
|
||||
generated_content=ai_response,
|
||||
generated_content=json.dumps(ai_response, ensure_ascii=False) if isinstance(ai_response, dict) else ai_response,
|
||||
model=request.model or "default"
|
||||
)
|
||||
db.add(history)
|
||||
@@ -820,7 +930,7 @@ async def new_outline_generator(
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""全新生成大纲SSE生成器"""
|
||||
"""全新生成大纲SSE生成器(MCP增强版)"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始生成大纲...", 5)
|
||||
@@ -828,6 +938,7 @@ async def new_outline_generator(
|
||||
project_id = data.get("project_id")
|
||||
# 确保chapter_count是整数(前端可能传字符串)
|
||||
chapter_count = int(data.get("chapter_count", 10))
|
||||
enable_mcp = data.get("enable_mcp", True)
|
||||
|
||||
# 验证项目
|
||||
yield await SSEResponse.send_progress("加载项目信息...", 10)
|
||||
@@ -852,7 +963,61 @@ async def new_outline_generator(
|
||||
for char in characters
|
||||
])
|
||||
|
||||
# 使用完整提示词
|
||||
# 🔍 MCP工具增强:收集情节设计参考资料
|
||||
mcp_reference_materials = ""
|
||||
if enable_mcp:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 使用MCP工具收集参考资料...", 18)
|
||||
logger.info(f"🔍 尝试使用MCP工具收集大纲设计参考资料...")
|
||||
|
||||
# 构建资料收集查询
|
||||
planning_query = f"""你正在为小说《{project.title}》设计完整大纲。
|
||||
项目信息:
|
||||
- 主题:{data.get('theme') or project.theme}
|
||||
- 类型:{data.get('genre') or project.genre}
|
||||
- 章节数:{chapter_count}
|
||||
- 叙事视角:{data.get('narrative_perspective') or '第三人称'}
|
||||
- 目标字数:{data.get('target_words') or project.target_words or 100000}
|
||||
|
||||
世界观设定:
|
||||
- 时间背景:{project.world_time_period or '未设定'}
|
||||
- 地理位置:{project.world_location or '未设定'}
|
||||
- 氛围基调:{project.world_atmosphere or '未设定'}
|
||||
|
||||
角色信息:
|
||||
{characters_info or '暂无角色'}
|
||||
|
||||
请搜索:
|
||||
1. 该类型小说的经典情节结构和套路
|
||||
2. 适合该主题的冲突设计思路
|
||||
3. 符合世界观的情节元素和场景设计灵感
|
||||
|
||||
请有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_query,
|
||||
user_id="system",
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
yield await SSEResponse.send_progress(f"📚 MCP收集到参考资料 ({len(mcp_reference_materials)}字符)", 19)
|
||||
else:
|
||||
logger.info(f"ℹ️ MCP工具未进行调用,继续正常生成")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,继续使用常规模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用完整提示词(插入MCP参考资料)
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 20)
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
title=project.title,
|
||||
@@ -866,7 +1031,8 @@ async def new_outline_generator(
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
requirements=data.get("requirements") or ""
|
||||
requirements=data.get("requirements") or "",
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
@@ -879,8 +1045,11 @@ async def new_outline_generator(
|
||||
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 删除旧大纲和章节
|
||||
yield await SSEResponse.send_progress("清理旧数据...", 75)
|
||||
@@ -902,7 +1071,7 @@ async def new_outline_generator(
|
||||
history = GenerationHistory(
|
||||
project_id=project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_response,
|
||||
generated_content=json.dumps(ai_response, ensure_ascii=False) if isinstance(ai_response, dict) else ai_response,
|
||||
model=data.get("model") or "default"
|
||||
)
|
||||
db.add(history)
|
||||
@@ -957,7 +1126,7 @@ async def continue_outline_generator(
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆增强版)"""
|
||||
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆+MCP增强版)"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始续写大纲...", 5)
|
||||
@@ -1090,13 +1259,72 @@ async def continue_outline_generator(
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}")
|
||||
memory_context = None
|
||||
# 🔍 MCP工具增强:收集续写参考资料
|
||||
mcp_reference_materials = ""
|
||||
enable_mcp = data.get("enable_mcp", True)
|
||||
if enable_mcp:
|
||||
try:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"🔍 第{str(batch_num + 1)}批:使用MCP工具收集参考资料...",
|
||||
batch_progress + 4
|
||||
)
|
||||
logger.info(f"🔍 第{batch_num + 1}批:尝试使用MCP工具收集续写参考资料...")
|
||||
|
||||
# 构建资料收集查询
|
||||
latest_summary = latest_outlines[-1].content if latest_outlines else ""
|
||||
planning_query = f"""你正在为小说《{project.title}》续写大纲。
|
||||
当前进度:已有{len(latest_outlines)}章,即将续写第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章
|
||||
|
||||
项目信息:
|
||||
- 主题:{data.get('theme') or project.theme}
|
||||
- 类型:{data.get('genre') or project.genre}
|
||||
- 叙事视角:{data.get('narrative_perspective') or project.narrative_perspective or '第三人称'}
|
||||
- 情节阶段:{data.get('plot_stage', 'development')}
|
||||
- 故事发展方向:{data.get('story_direction', '自然延续')}
|
||||
|
||||
最近章节概要:
|
||||
{latest_summary[:200]}
|
||||
|
||||
请搜索:
|
||||
1. 该情节阶段的经典处理手法和技巧
|
||||
2. 适合该发展方向的情节转折和冲突设计
|
||||
3. 符合类型特点的场景设计和剧情元素
|
||||
|
||||
请有针对性地查询1-2个最关键的问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_query,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
mcp_reference_materials = planning_result.get("content", "")
|
||||
logger.info(f"📚 第{batch_num + 1}批MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
|
||||
yield await SSEResponse.send_progress(
|
||||
f"📚 第{str(batch_num + 1)}批收集到参考资料 ({len(mcp_reference_materials)}字符)",
|
||||
batch_progress + 4.5
|
||||
)
|
||||
else:
|
||||
logger.info(f"ℹ️ 第{batch_num + 1}批MCP工具未进行调用,继续正常生成")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,继续使用常规模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
|
||||
yield await SSEResponse.send_progress(
|
||||
f" 调用AI生成第{str(batch_num + 1)}批...",
|
||||
batch_progress + 5
|
||||
)
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆增强)
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
title=project.title,
|
||||
theme=data.get("theme") or project.theme or "未设定",
|
||||
@@ -1115,7 +1343,8 @@ async def continue_outline_generator(
|
||||
start_chapter=current_start_chapter,
|
||||
story_direction=data.get("story_direction", "自然延续"),
|
||||
requirements=data.get("requirements", ""),
|
||||
memory_context=memory_context
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
@@ -1130,8 +1359,11 @@ async def continue_outline_generator(
|
||||
batch_progress + 10
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_response)
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 保存当前批次的大纲
|
||||
batch_outlines = await _save_outlines(
|
||||
@@ -1142,7 +1374,7 @@ async def continue_outline_generator(
|
||||
history = GenerationHistory(
|
||||
project_id=project_id,
|
||||
prompt=f"[续写批次{batch_num + 1}/{total_batches}] {str(prompt)[:500]}",
|
||||
generated_content=ai_response,
|
||||
generated_content=json.dumps(ai_response, ensure_ascii=False) if isinstance(ai_response, dict) else ai_response,
|
||||
model=data.get("model") or "default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
@@ -359,7 +359,10 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
|
||||
logger.info(f"✅ API 测试成功")
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - 响应内容: {response[:100] if response else 'N/A'}")
|
||||
|
||||
# 安全地处理响应内容(确保是字符串)
|
||||
response_str = str(response) if response else 'N/A'
|
||||
logger.info(f" - 响应内容: {response_str[:100]}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -367,7 +370,7 @@ async def test_api_connection(data: ApiTestRequest):
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": llm_model,
|
||||
"response_preview": response[:100] if response and len(response) > 100 else response,
|
||||
"response_preview": response_str[:100] if len(response_str) > 100 else response_str,
|
||||
"details": {
|
||||
"api_available": True,
|
||||
"model_accessible": True,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""项目创建向导流式API - 使用SSE避免超时"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
@@ -15,6 +15,7 @@ from app.models.relationship import CharacterRelationship, Organization, Organiz
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.mcp_tool_service import MCPToolService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
@@ -29,7 +30,7 @@ async def world_building_generator(
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""世界构建流式生成器"""
|
||||
"""世界构建流式生成器 - 支持MCP工具增强"""
|
||||
# 标记数据库会话是否已提交
|
||||
db_committed = False
|
||||
try:
|
||||
@@ -47,27 +48,94 @@ async def world_building_generator(
|
||||
character_count = data.get("character_count")
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
if not title or not description or not theme or not genre:
|
||||
yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400)
|
||||
return
|
||||
|
||||
# 获取提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 20)
|
||||
prompt = prompt_service.get_world_building_prompt(
|
||||
# 获取基础提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre
|
||||
)
|
||||
|
||||
# 流式调用AI
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
# MCP工具增强:收集参考资料
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{title}》设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{genre}
|
||||
- 主题:{theme}
|
||||
- 简介:{description}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
||||
你可以查询:
|
||||
1. 历史背景(如果是历史题材)
|
||||
2. 地理环境和文化特征
|
||||
3. 相关领域的专业知识
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
||||
|
||||
# 构建增强提示词
|
||||
if reference_materials:
|
||||
enhanced_prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
||||
|
||||
{reference_materials}
|
||||
|
||||
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
||||
final_prompt = enhanced_prompt
|
||||
yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 30)
|
||||
else:
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
@@ -190,6 +258,7 @@ async def world_building_generator(
|
||||
|
||||
@router.post("/world-building", summary="流式生成世界构建")
|
||||
async def generate_world_building_stream(
|
||||
request: Request,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
@@ -198,6 +267,10 @@ async def generate_world_building_stream(
|
||||
使用SSE流式生成世界构建,避免超时
|
||||
前端使用EventSource接收实时进度和结果
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(world_building_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
@@ -206,7 +279,7 @@ async def characters_generator(
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""角色批量生成流式生成器 - 优化版:分批+重试"""
|
||||
"""角色批量生成流式生成器 - 优化版:分批+重试+MCP工具增强"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 5)
|
||||
@@ -219,6 +292,8 @@ async def characters_generator(
|
||||
requirements = data.get("requirements", "")
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
# 验证项目
|
||||
yield await SSEResponse.send_progress("验证项目...", 10)
|
||||
@@ -239,6 +314,57 @@ async def characters_generator(
|
||||
"rules": project.world_rules or "未设定"
|
||||
}
|
||||
|
||||
# MCP工具增强:收集角色参考资料
|
||||
character_reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8)
|
||||
|
||||
# 构建角色资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》设计角色。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{genre or project.genre}
|
||||
- 主题:{theme or project.theme}
|
||||
- 时代背景:{world_context.get('time_period', '未设定')}
|
||||
- 地理位置:{world_context.get('location', '未设定')}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关参考资料,帮助设计更真实、更有深度的角色。
|
||||
你可以查询:
|
||||
1. 该时代/地域的真实历史人物特征
|
||||
2. 文化背景和社会习俗
|
||||
3. 职业特点和生活方式
|
||||
4. 相关领域的人物原型
|
||||
|
||||
请根据题材特点,有针对性地查询1-2个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
12
|
||||
)
|
||||
character_reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 12)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 12)
|
||||
|
||||
# 优化的分批策略:每批生成3个,平衡效率和成功率
|
||||
BATCH_SIZE = 3 # 每批生成3个角色
|
||||
MAX_RETRIES = 3 # 每批最多重试3次
|
||||
@@ -291,7 +417,8 @@ async def characters_generator(
|
||||
else:
|
||||
batch_requirements += "\n主要是配角(supporting)和反派(antagonist)"
|
||||
|
||||
prompt = prompt_service.get_characters_batch_prompt(
|
||||
# 构建基础提示词
|
||||
base_prompt = prompt_service.get_characters_batch_prompt(
|
||||
count=current_batch_size, # 传递精确数量
|
||||
time_period=world_context.get("time_period", ""),
|
||||
location=world_context.get("location", ""),
|
||||
@@ -302,6 +429,19 @@ async def characters_generator(
|
||||
requirements=batch_requirements
|
||||
)
|
||||
|
||||
# 如果有MCP参考资料,增强提示词
|
||||
if character_reference_materials:
|
||||
prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息设计更真实的角色:
|
||||
|
||||
{character_reference_materials}
|
||||
|
||||
请结合上述资料,设计符合历史/文化背景的角色。"""
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 流式生成
|
||||
accumulated_text = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
@@ -708,13 +848,19 @@ async def characters_generator(
|
||||
|
||||
@router.post("/characters", summary="流式批量生成角色")
|
||||
async def generate_characters_stream(
|
||||
request: Request,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式批量生成角色,避免超时
|
||||
支持MCP工具增强
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(characters_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
@@ -1071,7 +1217,7 @@ async def regenerate_world_building_generator(
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""重新生成世界观流式生成器"""
|
||||
"""重新生成世界观流式生成器 - 支持MCP工具增强"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
|
||||
@@ -1087,23 +1233,89 @@ async def regenerate_world_building_generator(
|
||||
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
# 获取世界构建提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 20)
|
||||
prompt = prompt_service.get_world_building_prompt(
|
||||
# 获取基础提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or "",
|
||||
genre=project.genre or ""
|
||||
)
|
||||
|
||||
# 流式调用AI
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
# MCP工具增强:收集参考资料
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
||||
你可以查询:
|
||||
1. 历史背景(如果是历史题材)
|
||||
2. 地理环境和文化特征
|
||||
3. 相关领域的专业知识
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
||||
|
||||
# 构建增强提示词
|
||||
if reference_materials:
|
||||
enhanced_prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
||||
|
||||
{reference_materials}
|
||||
|
||||
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
||||
final_prompt = enhanced_prompt
|
||||
yield await SSEResponse.send_progress("💡 已整合参考资料,开始重新生成世界观...", 30)
|
||||
else:
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
@@ -1187,6 +1399,7 @@ async def regenerate_world_building_generator(
|
||||
|
||||
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
|
||||
async def regenerate_world_building_stream(
|
||||
request: Request,
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -1200,6 +1413,10 @@ async def regenerate_world_building_stream(
|
||||
"model": "模型名称(可选)"
|
||||
}
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user