Files
MuMuAINovel/backend/app/api/memories.py
T
2025-11-10 21:16:55 +08:00

430 lines
14 KiB
Python

"""记忆管理API - 提供记忆的查询、分析等接口"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, desc
from typing import List, Optional
from app.database import get_db
from app.models.memory import StoryMemory, PlotAnalysis
from app.models.chapter import Chapter
from app.models.project import Project
from app.services.memory_service import memory_service
from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings
from app.logger import get_logger
import uuid
# Logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
# Router for every endpoint in this module; all routes share the /api/memories prefix.
router = APIRouter(prefix="/api/memories", tags=["memories"])
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
    """Return the project ``project_id`` if it is owned by ``user_id``.

    Args:
        project_id: Identifier of the project being accessed.
        user_id: Identifier of the calling user; a falsy value means "not logged in".
        db: Async database session used for the ownership lookup.

    Returns:
        The matching ``Project`` row.

    Raises:
        HTTPException: 401 when ``user_id`` is falsy; 404 when no project
            matches both the id and the owner.
    """
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # A single query checks existence and ownership at once, so a project
    # owned by someone else is indistinguishable from a missing one.
    ownership_query = select(Project).where(
        Project.id == project_id,
        Project.user_id == user_id,
    )
    owned_project = (await db.execute(ownership_query)).scalar_one_or_none()
    if owned_project:
        return owned_project

    logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
    raise HTTPException(status_code=404, detail="项目不存在或无权访问")
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
    project_id: str,
    chapter_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Analyze a chapter and generate story memories from the result.

    Runs the plot analyzer on the chapter text, stores the analysis as a
    ``PlotAnalysis`` row (replacing any earlier row for the same chapter),
    then persists each extracted memory both as a ``StoryMemory`` row and
    through ``memory_service.add_memory``.

    Args:
        project_id: Project the chapter must belong to.
        chapter_id: Chapter to analyze.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A dict with ``success``, ``message``, the serialized ``analysis`` and
        ``memories_count``.

    Raises:
        HTTPException: 401/404 from the access check; 404 when the chapter is
            missing; 400 when the chapter has no content or no AI settings
            exist; 500 when the analyzer returns nothing or any other error
            occurs.
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        # Load the chapter, scoped to the project.
        result = await db.execute(
            select(Chapter).where(
                and_(
                    Chapter.id == chapter_id,
                    Chapter.project_id == project_id
                )
            )
        )
        chapter = result.scalar_one_or_none()
        if not chapter:
            raise HTTPException(status_code=404, detail="章节不存在")
        if not chapter.content:
            raise HTTPException(status_code=400, detail="章节内容为空,无法分析")

        # Load the AI settings.
        # NOTE(review): this query is not filtered by user; if Settings is
        # per-user it may pick up another user's credentials or raise when
        # several rows exist — confirm against the Settings model.
        settings_result = await db.execute(select(Settings))
        settings = settings_result.scalar_one_or_none()
        if not settings:
            raise HTTPException(status_code=400, detail="请先配置AI设置")

        # Build the AI service from the stored settings.
        ai_service = create_user_ai_service(
            api_provider=settings.api_provider,
            api_key=settings.api_key,
            api_base_url=settings.api_base_url,
            model_name=settings.llm_model,
            temperature=settings.temperature,
            max_tokens=settings.max_tokens
        )

        # Run the plot analysis.
        analyzer = get_plot_analyzer(ai_service)
        analysis_result = await analyzer.analyze_chapter(
            chapter_number=chapter.chapter_number,
            title=chapter.title,
            content=chapter.content,
            word_count=chapter.word_count or len(chapter.content)
        )
        if not analysis_result:
            raise HTTPException(status_code=500, detail="剧情分析失败")

        # ``.get(key, {})`` only covers a missing key; ``or`` also covers a
        # key that is present with a None value, which would otherwise break
        # the derived counts and averages below.
        conflict = analysis_result.get('conflict') or {}
        emotional_arc = analysis_result.get('emotional_arc') or {}
        scores = analysis_result.get('scores') or {}
        hooks = analysis_result.get('hooks') or []
        foreshadows = analysis_result.get('foreshadows') or []
        plot_points = analysis_result.get('plot_points') or []

        # Build the analysis row to persist.
        plot_analysis = PlotAnalysis(
            id=str(uuid.uuid4()),
            project_id=project_id,
            chapter_id=chapter_id,
            plot_stage=analysis_result.get('plot_stage'),
            conflict_level=conflict.get('level'),
            conflict_types=conflict.get('types'),
            emotional_tone=emotional_arc.get('primary_emotion'),
            emotional_intensity=(emotional_arc.get('intensity') or 0) / 10,
            emotional_curve=analysis_result.get('emotional_arc'),
            hooks=analysis_result.get('hooks'),
            hooks_count=len(hooks),
            # max(..., 1) avoids division by zero when there are no hooks.
            hooks_avg_strength=sum(h.get('strength', 0) for h in hooks) / max(len(hooks), 1),
            foreshadows=analysis_result.get('foreshadows'),
            foreshadows_planted=sum(1 for f in foreshadows if f.get('type') == 'planted'),
            foreshadows_resolved=sum(1 for f in foreshadows if f.get('type') == 'resolved'),
            plot_points=analysis_result.get('plot_points'),
            plot_points_count=len(plot_points),
            character_states=analysis_result.get('character_states'),
            scenes=analysis_result.get('scenes'),
            pacing=analysis_result.get('pacing'),
            dialogue_ratio=analysis_result.get('dialogue_ratio'),
            description_ratio=analysis_result.get('description_ratio'),
            overall_quality_score=scores.get('overall'),
            pacing_score=scores.get('pacing'),
            engagement_score=scores.get('engagement'),
            coherence_score=scores.get('coherence'),
            analysis_report=analyzer.generate_analysis_summary(analysis_result),
            suggestions=analysis_result.get('suggestions'),
            word_count=chapter.word_count
        )

        # Remove any earlier analysis for this chapter. The rows are fetched
        # once and kept: a Result can only be consumed a single time, so
        # reading it again to obtain the row to delete would fail.
        existing_result = await db.execute(
            select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
        )
        stale_analyses = existing_result.scalars().all()
        for stale in stale_analyses:
            await db.delete(stale)
        if stale_analyses:
            # Emit the DELETEs before the replacement row is inserted, in
            # case the table enforces uniqueness per chapter.
            await db.flush()

        db.add(plot_analysis)
        await db.commit()

        # Turn the analysis result into memory fragments.
        memories_data = analyzer.extract_memories_from_analysis(
            analysis_result,
            chapter_id,
            chapter.chapter_number
        )

        # Persist each memory to the relational database and the vector store.
        # NOTE(review): memories from an earlier analysis of this chapter are
        # not removed here, so re-analysis may accumulate duplicates.
        saved_count = 0
        for mem_data in memories_data:
            memory_id = str(uuid.uuid4())

            # Relational row; the same id is reused as the vector id.
            memory = StoryMemory(
                id=memory_id,
                project_id=project_id,
                chapter_id=chapter_id,
                memory_type=mem_data['type'],
                title=mem_data.get('title', ''),
                content=mem_data['content'],
                story_timeline=chapter.chapter_number,
                vector_id=memory_id,
                **mem_data['metadata']
            )
            db.add(memory)

            # Vector store entry.
            await memory_service.add_memory(
                user_id=user_id,
                project_id=project_id,
                memory_id=memory_id,
                content=mem_data['content'],
                memory_type=mem_data['type'],
                metadata=mem_data['metadata']
            )
            saved_count += 1

        await db.commit()
        logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")
        return {
            "success": True,
            "message": f"分析完成,提取了{saved_count}条记忆",
            "analysis": plot_analysis.to_dict(),
            "memories_count": saved_count
        }
    except HTTPException:
        # Deliberate HTTP errors keep their status code.
        raise
    except Exception as e:
        logger.error(f"❌ 章节分析失败: {str(e)}")
        await db.rollback()
        raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
@router.get("/projects/{project_id}/memories")
async def get_project_memories(
    project_id: str,
    request: Request,
    memory_type: Optional[str] = None,
    chapter_id: Optional[str] = None,
    limit: int = 50,
    db: AsyncSession = Depends(get_db)
):
    """List a project's stored memories.

    Args:
        project_id: Project whose memories are listed.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        memory_type: When given, only memories of this type are returned.
        chapter_id: When given, only memories of this chapter are returned.
        limit: Maximum number of rows returned.
        db: Async database session.

    Returns:
        A dict with ``success``, the serialized ``memories`` and ``total``
        (the number of rows returned, not the number stored).

    Raises:
        HTTPException: 401/404 from the access check; 500 on any other error.
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        # Build the query, applying the optional filters.
        query = select(StoryMemory).where(StoryMemory.project_id == project_id)
        if memory_type:
            query = query.where(StoryMemory.memory_type == memory_type)
        if chapter_id:
            query = query.where(StoryMemory.chapter_id == chapter_id)

        # Highest importance first, newest first within equal importance.
        query = query.order_by(
            desc(StoryMemory.importance_score),
            desc(StoryMemory.created_at)
        ).limit(limit)

        result = await db.execute(query)
        memories = result.scalars().all()
        return {
            "success": True,
            "memories": [mem.to_dict() for mem in memories],
            "total": len(memories)
        }
    except HTTPException:
        # Keep the 401/404 from the access check instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ 获取记忆失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/analysis/{chapter_id}")
async def get_chapter_analysis(
    project_id: str,
    chapter_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Return the stored plot analysis of a chapter.

    Args:
        project_id: Project the analysis must belong to.
        chapter_id: Chapter whose analysis is requested.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A dict with ``success`` and the serialized ``analysis``.

    Raises:
        HTTPException: 401/404 from the access check; 404 when the chapter
            has no analysis yet; 500 on any other error.
    """
    try:
        current_user = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, current_user, db)

        lookup = select(PlotAnalysis).where(
            and_(
                PlotAnalysis.project_id == project_id,
                PlotAnalysis.chapter_id == chapter_id,
            )
        )
        stored_analysis = (await db.execute(lookup)).scalar_one_or_none()
        if stored_analysis:
            return {"success": True, "analysis": stored_analysis.to_dict()}

        raise HTTPException(status_code=404, detail="该章节还未进行分析")
    except HTTPException:
        # Deliberate HTTP errors keep their status code.
        raise
    except Exception as e:
        logger.error(f"❌ 获取分析失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/projects/{project_id}/search")
async def search_memories(
    project_id: str,
    request: Request,
    query: str,
    memory_types: Optional[List[str]] = None,
    limit: int = 10,
    min_importance: float = 0.0,
    db: AsyncSession = Depends(get_db)
):
    """Semantic search over a project's memories.

    Delegates the search to ``memory_service.search_memories``.

    Args:
        project_id: Project to search in.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        query: Search text.
        memory_types: Optional list of memory types, passed through to the service.
        limit: Maximum number of results, passed through to the service.
        min_importance: Importance threshold, passed through to the service.
        db: Async database session (used for the access check).

    Returns:
        A dict with ``success``, the echoed ``query``, the ``memories``
        returned by the service and ``total`` (their count).

    Raises:
        HTTPException: 401/404 from the access check; 500 on any other error.
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        memories = await memory_service.search_memories(
            user_id=user_id,
            project_id=project_id,
            query=query,
            memory_types=memory_types,
            limit=limit,
            min_importance=min_importance
        )
        return {
            "success": True,
            "query": query,
            "memories": memories,
            "total": len(memories)
        }
    except HTTPException:
        # Keep the 401/404 from the access check instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ 搜索记忆失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/foreshadows")
async def get_unresolved_foreshadows(
    project_id: str,
    request: Request,
    current_chapter: int,
    db: AsyncSession = Depends(get_db)
):
    """Return the project's unresolved foreshadows.

    Delegates the lookup to ``memory_service.find_unresolved_foreshadows``.

    Args:
        project_id: Project to inspect.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        current_chapter: Chapter number passed through to the service.
        db: Async database session (used for the access check).

    Returns:
        A dict with ``success``, the ``foreshadows`` returned by the service
        and ``total`` (their count).

    Raises:
        HTTPException: 401/404 from the access check; 500 on any other error.
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        foreshadows = await memory_service.find_unresolved_foreshadows(
            user_id=user_id,
            project_id=project_id,
            current_chapter=current_chapter
        )
        return {
            "success": True,
            "foreshadows": foreshadows,
            "total": len(foreshadows)
        }
    except HTTPException:
        # Keep the 401/404 from the access check instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ 获取伏笔失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/projects/{project_id}/stats")
async def get_memory_stats(
    project_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Return memory statistics for a project.

    Delegates to ``memory_service.get_memory_stats``.

    Args:
        project_id: Project to report on.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session (used for the access check).

    Returns:
        A dict with ``success`` and the ``stats`` returned by the service.

    Raises:
        HTTPException: 401/404 from the access check; 500 on any other error.
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        stats = await memory_service.get_memory_stats(
            user_id=user_id,
            project_id=project_id
        )
        return {
            "success": True,
            "stats": stats
        }
    except HTTPException:
        # Keep the 401/404 from the access check instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ 获取统计失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
async def delete_chapter_memories(
    project_id: str,
    chapter_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Delete every memory of a chapter.

    Removes the chapter's ``StoryMemory`` rows and asks ``memory_service``
    to delete the chapter's memories on its side as well.

    Args:
        project_id: Project the chapter belongs to.
        chapter_id: Chapter whose memories are deleted.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A dict with ``success`` and a ``message`` stating how many rows were deleted.

    Raises:
        HTTPException: 401/404 from the access check; 500 on any other error
            (the database transaction is rolled back first).
    """
    try:
        user_id = getattr(request.state, 'user_id', None)

        # Verify the caller owns the project.
        await verify_project_access(project_id, user_id, db)

        # Delete the relational rows.
        result = await db.execute(
            select(StoryMemory).where(
                and_(
                    StoryMemory.project_id == project_id,
                    StoryMemory.chapter_id == chapter_id
                )
            )
        )
        memories = result.scalars().all()
        for memory in memories:
            await db.delete(memory)

        # Delete the matching entries through the memory service.
        await memory_service.delete_chapter_memories(
            user_id=user_id,
            project_id=project_id,
            chapter_id=chapter_id
        )

        await db.commit()
        return {
            "success": True,
            "message": f"已删除{len(memories)}条记忆"
        }
    except HTTPException:
        # Keep the 401/404 from the access check instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ 删除记忆失败: {str(e)}")
        await db.rollback()
        raise HTTPException(status_code=500, detail=str(e))