update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能

This commit is contained in:
xiamuceer
2025-11-04 14:38:59 +08:00
parent 1cde345ed9
commit e4f90d5da0
26 changed files with 6722 additions and 84 deletions
+62 -12
View File
@@ -1,5 +1,5 @@
"""大纲管理API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List, AsyncGenerator, Dict, Any
@@ -21,6 +21,7 @@ from app.schemas.outline import (
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.memory_service import memory_service
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import SSEResponse, create_sse_response
@@ -328,6 +329,7 @@ async def reorder_outlines(
@router.post("/generate", response_model=OutlineListResponse, summary="AI生成/续写大纲")
async def generate_outline(
request: OutlineGenerateRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -377,8 +379,10 @@ async def generate_outline(
detail="续写模式需要已有大纲,当前项目没有大纲"
)
# 获取用户ID用于记忆检索
user_id = getattr(http_request.state, "user_id", "system")
return await _continue_outline(
request, project, existing_outlines, db, user_ai_service
request, project, existing_outlines, db, user_ai_service, user_id
)
else:
@@ -478,9 +482,10 @@ async def _continue_outline(
project: Project,
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章"""
"""续写大纲 - 分批生成,每批5章(记忆增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)}")
# 分析已有大纲
@@ -545,7 +550,25 @@ async def _continue_outline(
for o in latest_outlines
])
# 使用标准续写提示词模板
# 🧠 构建记忆增强上下文(仅续写模式需要)
memory_context = None
try:
logger.info(f"🧠 为第{batch_num + 1}批构建记忆上下文...")
# 使用最近一章的大纲作为查询
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project.id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文构建完成: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}")
memory_context = None
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=request.theme or project.theme or "未设定",
@@ -563,7 +586,8 @@ async def _continue_outline(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or ""
requirements=request.requirements or "",
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -834,9 +858,10 @@ async def new_outline_generator(
async def continue_outline_generator(
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
user_ai_service: AIService,
user_id: str = "system"
) -> AsyncGenerator[str, None]:
"""大纲续写SSE生成器 - 分批生成,推送进度"""
"""大纲续写SSE生成器 - 分批生成,推送进度(记忆增强版)"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始续写大纲...", 5)
@@ -940,12 +965,32 @@ async def continue_outline_generator(
for o in latest_outlines
])
# 🧠 构建记忆增强上下文
memory_context = None
try:
yield await SSEResponse.send_progress(
f"🧠 构建记忆上下文...",
batch_progress + 3
)
query_outline = recent_outlines[-1].content if recent_outlines else ""
memory_context = await memory_service.build_context_for_generation(
user_id=user_id,
project_id=project_id,
current_chapter=current_start_chapter,
chapter_outline=query_outline,
character_names=[c.name for c in characters] if characters else None
)
logger.info(f"✅ 记忆上下文: {memory_context['stats']}")
except Exception as e:
logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}")
memory_context = None
yield await SSEResponse.send_progress(
f"🤖 调用AI生成第{str(batch_num + 1)}批...",
f" 调用AI生成第{str(batch_num + 1)}批...",
batch_progress + 5
)
# 使用标准续写提示词模板
# 使用标准续写提示词模板(支持记忆增强)
prompt = prompt_service.get_outline_continue_prompt(
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
@@ -963,7 +1008,8 @@ async def continue_outline_generator(
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
story_direction=data.get("story_direction", "自然延续"),
requirements=data.get("requirements", "")
requirements=data.get("requirements", ""),
memory_context=memory_context
)
# 调用AI生成当前批次
@@ -1062,6 +1108,7 @@ async def continue_outline_generator(
@router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)")
async def generate_outline_stream(
data: Dict[str, Any],
request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -1111,6 +1158,9 @@ async def generate_outline_stream(
mode = "continue" if existing_outlines else "new"
logger.info(f"自动判断模式:{'续写' if existing_outlines else '新建'}")
# 获取用户ID
user_id = getattr(request.state, "user_id", "system")
# 根据模式选择生成器
if mode == "new":
return create_sse_response(new_outline_generator(data, db, user_ai_service))
@@ -1120,7 +1170,7 @@ async def generate_outline_stream(
status_code=400,
detail="续写模式需要已有大纲,当前项目没有大纲"
)
return create_sse_response(continue_outline_generator(data, db, user_ai_service))
return create_sse_response(continue_outline_generator(data, db, user_ai_service, user_id))
else:
raise HTTPException(
status_code=400,