fix: MCP插件TimeoutError修复 + 多项Bug修复和性能优化
- fix: MCP插件管理接口改为后台任务,修复TimeoutError - fix: MCP连接失败后上下文清理的cancel scope错误 - feat: MCP插件后台注册添加重试机制 - fix: 限制每章自动创建伏笔数量上限 - fix: 修复JSON非法转义字符清洗 - fix: SSE流式生成添加心跳保活 - fix: 职业生成改用POST请求避免URL长度限制 - perf: 使用torch CPU版本加速Docker构建 - fix: 自动修复JSON字符串值中的裸换行符 - feat: 集成json5容错解析器
This commit is contained in:
+19
-10
@@ -7,7 +7,7 @@ import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.database import get_db
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, wrap_stream_with_heartbeat, HEARTBEAT
|
||||
from app.models.career import Career, CharacterCareer
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
@@ -25,6 +25,7 @@ from app.schemas.career import (
|
||||
CareerStage
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.api.common import verify_project_access
|
||||
@@ -155,14 +156,10 @@ async def create_career(
|
||||
raise HTTPException(status_code=500, detail=f"创建职业失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/generate-system", summary="AI生成新职业(增量式,流式)")
|
||||
@router.post("/generate-system", summary="AI生成新职业(增量式,流式)")
|
||||
async def generate_career_system(
|
||||
project_id: str,
|
||||
main_career_count: int = 3,
|
||||
sub_career_count: int = 6,
|
||||
user_requirements: str = "",
|
||||
enable_mcp: bool = False,
|
||||
http_request: Request = None,
|
||||
request_data: CareerGenerateRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -176,6 +173,10 @@ async def generate_career_system(
|
||||
try:
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project_id = request_data.project_id
|
||||
main_career_count = request_data.main_career_count
|
||||
sub_career_count = request_data.sub_career_count
|
||||
user_requirements = request_data.user_requirements
|
||||
project = await verify_project_access(project_id, user_id, db)
|
||||
|
||||
yield await tracker.start()
|
||||
@@ -316,7 +317,15 @@ async def generate_career_system(
|
||||
chunk_count = 0
|
||||
estimated_total = max(3000, len(prompt) * 8)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
async for chunk in wrap_stream_with_heartbeat(
|
||||
user_ai_service.generate_text_stream(prompt=prompt),
|
||||
heartbeat_interval=15.0
|
||||
):
|
||||
# 心跳哨兵:发送心跳保活,不混入AI响应
|
||||
if chunk is HEARTBEAT:
|
||||
yield await tracker.heartbeat()
|
||||
continue
|
||||
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
@@ -345,7 +354,7 @@ async def generate_career_system(
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(ai_response)
|
||||
career_data = json.loads(cleaned_response)
|
||||
career_data = loads_json(cleaned_response)
|
||||
logger.info(f"✅ 职业体系JSON解析成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
||||
|
||||
@@ -7,7 +7,7 @@ import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.database import get_db
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, wrap_stream_with_heartbeat, HEARTBEAT
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.models.generation_history import GenerationHistory
|
||||
@@ -20,6 +20,7 @@ from app.schemas.character import (
|
||||
CharacterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.services.import_export_service import ImportExportService
|
||||
from app.schemas.import_export import CharactersExportRequest, CharactersImportResult
|
||||
@@ -947,10 +948,18 @@ async def generate_character_stream(
|
||||
logger.info(f"🎯 开始生成角色(流式模式)...")
|
||||
yield await tracker.generating(0, estimated_total, "开始生成角色...")
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
tool_choice="required",
|
||||
async for chunk in wrap_stream_with_heartbeat(
|
||||
user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
tool_choice="required",
|
||||
),
|
||||
heartbeat_interval=15.0
|
||||
):
|
||||
# 心跳哨兵:发送心跳保活,不混入AI响应
|
||||
if chunk is HEARTBEAT:
|
||||
yield await tracker.heartbeat()
|
||||
continue
|
||||
|
||||
# chunk 现在可能是 dict 或 str,提取 content 字段
|
||||
if isinstance(chunk, dict):
|
||||
content = chunk.get("content", "")
|
||||
@@ -987,7 +996,7 @@ async def generate_character_stream(
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(ai_response)
|
||||
character_data = json.loads(cleaned_response)
|
||||
character_data = loads_json(cleaned_response)
|
||||
logger.info(f"✅ 角色JSON解析成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 角色JSON解析失败: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
@@ -166,7 +167,7 @@ async def generate_options(
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
|
||||
result = json.loads(cleaned_content)
|
||||
result = loads_json(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
@@ -343,7 +344,7 @@ async def refine_options(
|
||||
# 解析JSON
|
||||
try:
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
result = json.loads(cleaned_content)
|
||||
result = loads_json(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
@@ -466,7 +467,7 @@ async def quick_generate(
|
||||
# 使用统一的JSON清洗方法
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
|
||||
result = json.loads(cleaned_content)
|
||||
result = loads_json(cleaned_content)
|
||||
|
||||
# 合并用户已提供的信息(用户输入优先)
|
||||
final_result = {
|
||||
@@ -487,4 +488,4 @@ async def quick_generate(
|
||||
logger.error(f"智能补全失败: {e}", exc_info=True)
|
||||
return {
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
+119
-99
@@ -54,65 +54,75 @@ async def _register_plugin_background(
|
||||
plugin_type: str,
|
||||
server_url: str,
|
||||
headers: Optional[dict],
|
||||
config: Optional[dict]
|
||||
config: Optional[dict],
|
||||
max_retries: int = 2,
|
||||
retry_delay: float = 3.0
|
||||
):
|
||||
"""
|
||||
后台任务:注册MCP插件并更新数据库状态
|
||||
后台任务:注册MCP插件并更新数据库状态(带重试)
|
||||
|
||||
在独立的任务中执行MCP连接,避免阻塞请求处理
|
||||
在独立的任务中执行MCP连接,避免阻塞请求处理。
|
||||
连接失败时会自动重试,提高对临时网络问题的容错性。
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
if attempt > 0:
|
||||
logger.info(f"后台注册MCP插件重试 ({attempt}/{max_retries}): {plugin_name}")
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
if success:
|
||||
# 更新数据库状态为active
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status="active", last_error=None)
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
logger.info(f"后台注册MCP插件成功: {plugin_name}")
|
||||
return
|
||||
else:
|
||||
last_error = "连接失败"
|
||||
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
logger.warning(f"后台注册MCP插件异常 (尝试 {attempt + 1}/{max_retries + 1}): {plugin_name}, 错误: {e}")
|
||||
|
||||
# 所有重试都失败,更新数据库状态为error
|
||||
logger.error(f"后台注册MCP插件最终失败 (已重试{max_retries}次): {plugin_name}, 错误: {last_error}")
|
||||
try:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新数据库状态
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(
|
||||
status="active" if success else "error",
|
||||
last_error=None if success else "连接失败"
|
||||
)
|
||||
.values(status="error", last_error=str(last_error)[:500] if last_error else "连接失败")
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
if success:
|
||||
logger.info(f"后台注册MCP插件成功: {plugin_name}")
|
||||
else:
|
||||
logger.warning(f"后台注册MCP插件失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"后台注册MCP插件异常: {plugin_name}, 错误: {e}")
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status="error", last_error=str(e))
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
except Exception as db_error:
|
||||
logger.error(f"更新插件状态失败: {db_error}")
|
||||
except Exception as db_error:
|
||||
logger.error(f"更新插件状态失败: {db_error}")
|
||||
|
||||
|
||||
async def _unregister_plugin_safe(user_id: str, plugin_name: str):
|
||||
@@ -215,22 +225,26 @@ async def create_plugin(
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
# 如果启用,设为pending状态等待后台连接
|
||||
if plugin.enabled:
|
||||
plugin.status = "pending"
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果启用,注册到统一门面
|
||||
# 如果启用,后台注册到统一门面(避免MCP操作阻塞导致超时)
|
||||
if plugin.enabled:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
if success:
|
||||
plugin.status = "active"
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
plugin_type=plugin.plugin_type,
|
||||
server_url=plugin.server_url,
|
||||
headers=plugin.headers,
|
||||
config=plugin.config
|
||||
))
|
||||
|
||||
logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}")
|
||||
logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}(MCP注册在后台执行)")
|
||||
return plugin
|
||||
|
||||
|
||||
@@ -438,15 +452,29 @@ async def update_plugin(
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
# 如果启用,设为pending状态等待后台连接
|
||||
if plugin.enabled:
|
||||
plugin.status = "pending"
|
||||
plugin.last_error = None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 如果插件已启用,重新注册
|
||||
# 如果插件已启用,后台重新注册MCP连接
|
||||
if plugin.enabled:
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
await _register_plugin_to_facade(plugin, user.user_id)
|
||||
# 先后台注销旧连接
|
||||
asyncio.create_task(_unregister_plugin_safe(user.user_id, plugin.plugin_name))
|
||||
# 再后台注册新连接
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
plugin_type=plugin.plugin_type,
|
||||
server_url=plugin.server_url,
|
||||
headers=plugin.headers,
|
||||
config=plugin.config
|
||||
))
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}(MCP操作在后台执行)")
|
||||
return plugin
|
||||
|
||||
|
||||
@@ -470,15 +498,19 @@ async def delete_plugin(
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
# 从统一门面注销
|
||||
await mcp_client.unregister(user.user_id, plugin.plugin_name)
|
||||
# 保存插件信息用于后台注销
|
||||
plugin_name = plugin.plugin_name
|
||||
user_id = user.user_id
|
||||
|
||||
# 删除数据库记录
|
||||
# 先删除数据库记录
|
||||
await db.delete(plugin)
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"用户 {user.user_id} 删除插件: {plugin.plugin_name}")
|
||||
return {"message": "插件已删除", "plugin_name": plugin.plugin_name}
|
||||
# 后台从统一门面注销(避免MCP操作阻塞导致超时)
|
||||
asyncio.create_task(_unregister_plugin_safe(user_id, plugin_name))
|
||||
|
||||
logger.info(f"用户 {user.user_id} 删除插件: {plugin_name}(MCP注销在后台执行)")
|
||||
return {"message": "插件已删除", "plugin_name": plugin_name}
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/toggle", response_model=MCPPluginResponse)
|
||||
@@ -490,6 +522,10 @@ async def toggle_plugin(
|
||||
):
|
||||
"""
|
||||
启用或禁用插件
|
||||
|
||||
启用时:先更新数据库状态为pending,再通过后台任务注册MCP连接,
|
||||
避免长时间持有数据库会话导致超时。
|
||||
禁用时:先更新数据库状态,再通过后台任务注销MCP连接。
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
@@ -509,51 +545,35 @@ async def toggle_plugin(
|
||||
headers = plugin.headers
|
||||
config = plugin.config
|
||||
|
||||
# 先更新数据库状态
|
||||
# 更新数据库状态
|
||||
plugin.enabled = enabled
|
||||
if not enabled:
|
||||
if enabled:
|
||||
# 启用时先设为pending状态,等待后台MCP连接完成
|
||||
plugin.status = "pending"
|
||||
plugin.last_error = None
|
||||
else:
|
||||
plugin.status = "inactive"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库操作完成后,再进行MCP操作
|
||||
# 数据库操作完成后,通过后台任务进行MCP操作(避免长时间持有数据库会话)
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新状态
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
# 启用:后台注册到统一门面
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
plugin_type=plugin_type,
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
config=config
|
||||
))
|
||||
else:
|
||||
# 禁用:从统一门面注销(不影响数据库状态)
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
|
||||
# 禁用:后台从统一门面注销(不影响数据库状态)
|
||||
asyncio.create_task(_unregister_plugin_safe(user.user_id, plugin_name))
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
|
||||
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}(MCP操作在后台执行)")
|
||||
return plugin
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, wrap_stream_with_heartbeat, HEARTBEAT
|
||||
from app.models.relationship import Organization, OrganizationMember
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
@@ -24,6 +24,7 @@ from app.schemas.relationship import (
|
||||
)
|
||||
from app.schemas.character import CharacterResponse
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
@@ -500,7 +501,15 @@ async def generate_organization_stream(
|
||||
chunk_count = 0
|
||||
estimated_total = max(3000, len(prompt) * 8)
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
async for chunk in wrap_stream_with_heartbeat(
|
||||
user_ai_service.generate_text_stream(prompt=prompt),
|
||||
heartbeat_interval=15.0
|
||||
):
|
||||
# 心跳哨兵:发送心跳保活,不混入AI响应
|
||||
if chunk is HEARTBEAT:
|
||||
yield await tracker.heartbeat()
|
||||
continue
|
||||
|
||||
chunk_count += 1
|
||||
ai_content += chunk
|
||||
|
||||
@@ -529,7 +538,7 @@ async def generate_organization_stream(
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(ai_content)
|
||||
organization_data = json.loads(cleaned_response)
|
||||
organization_data = loads_json(cleaned_response)
|
||||
logger.info(f"✅ 组织JSON解析成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 组织JSON解析失败: {e}")
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.schemas.outline import (
|
||||
CreateChaptersFromPlansResponse
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_expansion_service import PlotExpansionService
|
||||
@@ -850,7 +851,7 @@ def _parse_ai_response(ai_response: str, raise_on_error: bool = False) -> list:
|
||||
ai_service_temp = AIService()
|
||||
cleaned_text = ai_service_temp._clean_json_response(ai_response)
|
||||
|
||||
outline_data = json.loads(cleaned_text)
|
||||
outline_data = loads_json(cleaned_text)
|
||||
|
||||
# 确保是列表格式
|
||||
if not isinstance(outline_data, list):
|
||||
@@ -1447,6 +1448,31 @@ async def continue_outline_generator(
|
||||
message=f"🤖 调用AI生成第{str(batch_num + 1)}批..."
|
||||
)
|
||||
|
||||
# 获取伏笔提醒信息(用于大纲续写)
|
||||
foreshadow_reminders_text = "暂无需要关注的伏笔"
|
||||
try:
|
||||
foreshadow_context = await foreshadow_service.build_chapter_context(
|
||||
db=db,
|
||||
project_id=project_id,
|
||||
chapter_number=current_start_chapter,
|
||||
include_pending=False,
|
||||
include_overdue=True,
|
||||
lookahead=10
|
||||
)
|
||||
if foreshadow_context and foreshadow_context.get("context_text"):
|
||||
foreshadow_reminders_text = foreshadow_context["context_text"]
|
||||
logger.info(f"✅ 大纲续写获取到伏笔提醒: {len(foreshadow_reminders_text)}字符")
|
||||
# 追加伏笔统计信息
|
||||
foreshadow_stats = await foreshadow_service.get_stats(db, project_id)
|
||||
if foreshadow_stats:
|
||||
planted = foreshadow_stats.get('planted', 0)
|
||||
resolved = foreshadow_stats.get('resolved', 0)
|
||||
partial = foreshadow_stats.get('partially_resolved', 0)
|
||||
pending = foreshadow_stats.get('pending', 0)
|
||||
foreshadow_reminders_text += f"\n【📊 伏笔统计】已埋设:{planted} 已回收:{resolved} 部分回收:{partial} 待埋入:{pending}"
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ 获取大纲续写伏笔提醒失败: {str(e)}")
|
||||
|
||||
# 使用标准续写提示词模板(简化版)
|
||||
template = await PromptService.get_template("OUTLINE_CONTINUE", user_id, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
@@ -1463,6 +1489,8 @@ async def continue_outline_generator(
|
||||
# 上下文信息
|
||||
recent_outlines=context['recent_outlines'],
|
||||
characters_info=context['characters_info'],
|
||||
# 伏笔提醒
|
||||
foreshadow_reminders=foreshadow_reminders_text,
|
||||
# 续写参数
|
||||
chapter_count=current_batch_size,
|
||||
start_chapter=current_start_chapter,
|
||||
@@ -2482,4 +2510,4 @@ async def create_chapters_from_existing_plans(
|
||||
except Exception as e:
|
||||
logger.error(f"根据已有规划创建章节失败: {str(e)}", exc_info=True)
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建章节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创建章节失败: {str(e)}")
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.models.relationship import CharacterRelationship, Organization, Organiz
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.json_helper import loads_json
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.services.plot_expansion_service import PlotExpansionService
|
||||
from app.logger import get_logger
|
||||
@@ -169,7 +170,7 @@ async def world_building_generator(
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
world_data = loads_json(cleaned_text)
|
||||
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True # 解析成功,标记完成
|
||||
|
||||
@@ -433,7 +434,7 @@ async def career_system_generator(
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(career_response)
|
||||
career_data = json.loads(cleaned_response)
|
||||
career_data = loads_json(cleaned_response)
|
||||
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||
|
||||
yield await tracker.saving("保存职业数据...")
|
||||
@@ -771,7 +772,7 @@ async def characters_generator(
|
||||
|
||||
# 解析批次结果 - 使用统一的JSON清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
characters_data = json.loads(cleaned_text)
|
||||
characters_data = loads_json(cleaned_text)
|
||||
if not isinstance(characters_data, list):
|
||||
characters_data = [characters_data]
|
||||
|
||||
@@ -1362,7 +1363,7 @@ async def outline_generator(
|
||||
|
||||
try:
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
outline_data = json.loads(cleaned_text)
|
||||
outline_data = loads_json(cleaned_text)
|
||||
if not isinstance(outline_data, list):
|
||||
outline_data = [outline_data]
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -1668,7 +1669,7 @@ async def world_building_regenerate_generator(
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
world_data = loads_json(cleaned_text)
|
||||
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user