update:1.更新导入导出功能 2.实现RAG记忆功能,引入剧情分析功能
This commit is contained in:
@@ -0,0 +1,383 @@
|
||||
"""记忆管理API - 提供记忆的查询、分析等接口"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, desc
|
||||
from typing import List, Optional
|
||||
from app.database import get_db
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.chapter import Chapter
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_analyzer import get_plot_analyzer
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.models.settings import Settings
|
||||
from app.logger import get_logger
|
||||
import uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||
async def analyze_chapter(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
分析章节并生成记忆
|
||||
|
||||
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
||||
"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 获取章节内容
|
||||
result = await db.execute(
|
||||
select(Chapter).where(
|
||||
and_(
|
||||
Chapter.id == chapter_id,
|
||||
Chapter.project_id == project_id
|
||||
)
|
||||
)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
if not chapter.content:
|
||||
raise HTTPException(status_code=400, detail="章节内容为空,无法分析")
|
||||
|
||||
# 获取用户AI设置
|
||||
settings_result = await db.execute(select(Settings))
|
||||
settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
raise HTTPException(status_code=400, detail="请先配置AI设置")
|
||||
|
||||
# 创建AI服务
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url,
|
||||
model_name=settings.model_name,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
)
|
||||
|
||||
# 执行剧情分析
|
||||
analyzer = get_plot_analyzer(ai_service)
|
||||
analysis_result = await analyzer.analyze_chapter(
|
||||
chapter_number=chapter.chapter_number,
|
||||
title=chapter.title,
|
||||
content=chapter.content,
|
||||
word_count=chapter.word_count or len(chapter.content)
|
||||
)
|
||||
|
||||
if not analysis_result:
|
||||
raise HTTPException(status_code=500, detail="剧情分析失败")
|
||||
|
||||
# 保存分析结果到数据库
|
||||
plot_analysis = PlotAnalysis(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id,
|
||||
plot_stage=analysis_result.get('plot_stage'),
|
||||
conflict_level=analysis_result.get('conflict', {}).get('level'),
|
||||
conflict_types=analysis_result.get('conflict', {}).get('types'),
|
||||
emotional_tone=analysis_result.get('emotional_arc', {}).get('primary_emotion'),
|
||||
emotional_intensity=analysis_result.get('emotional_arc', {}).get('intensity', 0) / 10,
|
||||
emotional_curve=analysis_result.get('emotional_arc'),
|
||||
hooks=analysis_result.get('hooks'),
|
||||
hooks_count=len(analysis_result.get('hooks', [])),
|
||||
hooks_avg_strength=sum(h.get('strength', 0) for h in analysis_result.get('hooks', [])) / max(len(analysis_result.get('hooks', [])), 1),
|
||||
foreshadows=analysis_result.get('foreshadows'),
|
||||
foreshadows_planted=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'planted'),
|
||||
foreshadows_resolved=sum(1 for f in analysis_result.get('foreshadows', []) if f.get('type') == 'resolved'),
|
||||
plot_points=analysis_result.get('plot_points'),
|
||||
plot_points_count=len(analysis_result.get('plot_points', [])),
|
||||
character_states=analysis_result.get('character_states'),
|
||||
scenes=analysis_result.get('scenes'),
|
||||
pacing=analysis_result.get('pacing'),
|
||||
dialogue_ratio=analysis_result.get('dialogue_ratio'),
|
||||
description_ratio=analysis_result.get('description_ratio'),
|
||||
overall_quality_score=analysis_result.get('scores', {}).get('overall'),
|
||||
pacing_score=analysis_result.get('scores', {}).get('pacing'),
|
||||
engagement_score=analysis_result.get('scores', {}).get('engagement'),
|
||||
coherence_score=analysis_result.get('scores', {}).get('coherence'),
|
||||
analysis_report=analyzer.generate_analysis_summary(analysis_result),
|
||||
suggestions=analysis_result.get('suggestions'),
|
||||
word_count=chapter.word_count
|
||||
)
|
||||
|
||||
# 检查是否已存在分析记录
|
||||
existing = await db.execute(
|
||||
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
# 删除旧记录
|
||||
await db.execute(
|
||||
select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
|
||||
)
|
||||
await db.delete(existing.scalar_one())
|
||||
|
||||
db.add(plot_analysis)
|
||||
await db.commit()
|
||||
|
||||
# 从分析结果中提取记忆片段
|
||||
memories_data = analyzer.extract_memories_from_analysis(
|
||||
analysis_result,
|
||||
chapter_id,
|
||||
chapter.chapter_number
|
||||
)
|
||||
|
||||
# 保存记忆到数据库和向量库
|
||||
saved_count = 0
|
||||
for mem_data in memories_data:
|
||||
memory_id = str(uuid.uuid4())
|
||||
|
||||
# 保存到关系数据库
|
||||
memory = StoryMemory(
|
||||
id=memory_id,
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id,
|
||||
memory_type=mem_data['type'],
|
||||
title=mem_data.get('title', ''),
|
||||
content=mem_data['content'],
|
||||
story_timeline=chapter.chapter_number,
|
||||
vector_id=memory_id,
|
||||
**mem_data['metadata']
|
||||
)
|
||||
db.add(memory)
|
||||
|
||||
# 保存到向量库
|
||||
await memory_service.add_memory(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
memory_id=memory_id,
|
||||
content=mem_data['content'],
|
||||
memory_type=mem_data['type'],
|
||||
metadata=mem_data['metadata']
|
||||
)
|
||||
saved_count += 1
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"分析完成,提取了{saved_count}条记忆",
|
||||
"analysis": plot_analysis.to_dict(),
|
||||
"memories_count": saved_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 章节分析失败: {str(e)}")
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/memories")
|
||||
async def get_project_memories(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
memory_type: Optional[str] = None,
|
||||
chapter_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取项目的记忆列表"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 构建查询
|
||||
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
||||
|
||||
if memory_type:
|
||||
query = query.where(StoryMemory.memory_type == memory_type)
|
||||
if chapter_id:
|
||||
query = query.where(StoryMemory.chapter_id == chapter_id)
|
||||
|
||||
query = query.order_by(desc(StoryMemory.importance_score), desc(StoryMemory.created_at)).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
memories = result.scalars().all()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"memories": [mem.to_dict() for mem in memories],
|
||||
"total": len(memories)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取记忆失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/analysis/{chapter_id}")
|
||||
async def get_chapter_analysis(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取章节的剧情分析"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(PlotAnalysis).where(
|
||||
and_(
|
||||
PlotAnalysis.project_id == project_id,
|
||||
PlotAnalysis.chapter_id == chapter_id
|
||||
)
|
||||
)
|
||||
)
|
||||
analysis = result.scalar_one_or_none()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节还未进行分析")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": analysis.to_dict()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取分析失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/search")
|
||||
async def search_memories(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
query: str,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.0
|
||||
):
|
||||
"""语义搜索项目记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
memories = await memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
memory_types=memory_types,
|
||||
limit=limit,
|
||||
min_importance=min_importance
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"query": query,
|
||||
"memories": memories,
|
||||
"total": len(memories)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索记忆失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/foreshadows")
|
||||
async def get_unresolved_foreshadows(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
current_chapter: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取未完结的伏笔"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 从向量库搜索
|
||||
foreshadows = await memory_service.find_unresolved_foreshadows(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
current_chapter=current_chapter
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"foreshadows": foreshadows,
|
||||
"total": len(foreshadows)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取伏笔失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/stats")
|
||||
async def get_memory_stats(
|
||||
project_id: str,
|
||||
request: Request
|
||||
):
|
||||
"""获取记忆统计信息"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
stats = await memory_service.get_memory_stats(
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"stats": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取统计失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
|
||||
async def delete_chapter_memories(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节的所有记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
|
||||
# 从数据库删除
|
||||
result = await db.execute(
|
||||
select(StoryMemory).where(
|
||||
and_(
|
||||
StoryMemory.project_id == project_id,
|
||||
StoryMemory.chapter_id == chapter_id
|
||||
)
|
||||
)
|
||||
)
|
||||
memories = result.scalars().all()
|
||||
|
||||
for memory in memories:
|
||||
await db.delete(memory)
|
||||
|
||||
# 从向量库删除
|
||||
await memory_service.delete_chapter_memories(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
chapter_id=chapter_id
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已删除{len(memories)}条记忆"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 删除记忆失败: {str(e)}")
|
||||
await db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
Reference in New Issue
Block a user