383 lines
13 KiB
Python
383 lines
13 KiB
Python
"""记忆管理API - 提供记忆的查询、分析等接口"""
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, desc
|
|
from typing import List, Optional
|
|
from app.database import get_db
|
|
from app.models.memory import StoryMemory, PlotAnalysis
|
|
from app.models.chapter import Chapter
|
|
from app.services.memory_service import memory_service
|
|
from app.services.plot_analyzer import get_plot_analyzer
|
|
from app.services.ai_service import create_user_ai_service
|
|
from app.models.settings import Settings
|
|
from app.logger import get_logger
|
|
import uuid
|
|
|
|
# Logger for this module, named after the module path.
logger = get_logger(__name__)

# Every endpoint below is mounted under /api/memories and tagged "memories".
router = APIRouter(tags=["memories"], prefix="/api/memories")
|
|
|
|
|
|
def _section(data, key):
    """Return ``data[key]`` when it is a dict, otherwise an empty dict.

    The analysis result comes from the AI-backed analyzer, so a section may be
    missing or explicitly ``null``; ``data.get(key, {})`` only covers the
    missing case and would crash on ``None``.
    """
    value = data.get(key)
    return value if isinstance(value, dict) else {}


def _items(data, key):
    """Return ``data[key]`` when it is a list, otherwise an empty list (see ``_section``)."""
    value = data.get(key)
    return value if isinstance(value, list) else []


def _build_plot_analysis(analyzer, analysis_result, project_id, chapter_id, chapter):
    """Map the analyzer's result dict onto a new, unsaved PlotAnalysis row.

    Raw sections (hooks, foreshadows, plot points, ...) are stored exactly as
    returned; the derived counts/averages tolerate missing or null sections
    instead of raising.
    """
    conflict = _section(analysis_result, 'conflict')
    emotional_arc = _section(analysis_result, 'emotional_arc')
    scores = _section(analysis_result, 'scores')
    hooks = _items(analysis_result, 'hooks')
    foreshadows = _items(analysis_result, 'foreshadows')
    plot_points = _items(analysis_result, 'plot_points')

    # Malformed (non-dict) entries count as strength 0 / untyped instead of crashing.
    hook_strength_total = sum(
        (hook.get('strength') or 0) for hook in hooks if isinstance(hook, dict)
    )
    foreshadow_types = [f.get('type') for f in foreshadows if isinstance(f, dict)]

    return PlotAnalysis(
        id=str(uuid.uuid4()),
        project_id=project_id,
        chapter_id=chapter_id,
        plot_stage=analysis_result.get('plot_stage'),
        conflict_level=conflict.get('level'),
        conflict_types=conflict.get('types'),
        emotional_tone=emotional_arc.get('primary_emotion'),
        # Divided by 10 to rescale — presumably a 0-10 score mapped to 0-1; TODO confirm.
        emotional_intensity=(emotional_arc.get('intensity') or 0) / 10,
        emotional_curve=analysis_result.get('emotional_arc'),
        hooks=analysis_result.get('hooks'),
        hooks_count=len(hooks),
        hooks_avg_strength=hook_strength_total / max(len(hooks), 1),
        foreshadows=analysis_result.get('foreshadows'),
        foreshadows_planted=foreshadow_types.count('planted'),
        foreshadows_resolved=foreshadow_types.count('resolved'),
        plot_points=analysis_result.get('plot_points'),
        plot_points_count=len(plot_points),
        character_states=analysis_result.get('character_states'),
        scenes=analysis_result.get('scenes'),
        pacing=analysis_result.get('pacing'),
        dialogue_ratio=analysis_result.get('dialogue_ratio'),
        description_ratio=analysis_result.get('description_ratio'),
        overall_quality_score=scores.get('overall'),
        pacing_score=scores.get('pacing'),
        engagement_score=scores.get('engagement'),
        coherence_score=scores.get('coherence'),
        analysis_report=analyzer.generate_analysis_summary(analysis_result),
        suggestions=analysis_result.get('suggestions'),
        word_count=chapter.word_count
    )


@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
    project_id: str,
    chapter_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Analyze a chapter and turn the result into story memories.

    Runs the plot analyzer on the chapter (hooks, foreshadowing, plot points,
    ...), stores the analysis as a PlotAnalysis row — replacing any earlier
    analysis of the same chapter — and saves each extracted memory both as a
    StoryMemory row and in the vector store via ``memory_service``.

    Raises:
        HTTPException: 404 if the chapter is not in this project; 400 if it has
            no content or no AI settings row exists; 500 if the analysis
            returns nothing or anything else fails (the session is rolled back).
    """
    try:
        user_id = request.state.user_id

        # Load the chapter, scoped to the project so a chapter of another project 404s.
        chapter_result = await db.execute(
            select(Chapter).where(
                and_(
                    Chapter.id == chapter_id,
                    Chapter.project_id == project_id
                )
            )
        )
        chapter = chapter_result.scalar_one_or_none()
        if not chapter:
            raise HTTPException(status_code=404, detail="章节不存在")
        if not chapter.content:
            raise HTTPException(status_code=400, detail="章节内容为空,无法分析")

        # NOTE(review): this loads a Settings row without filtering by user_id,
        # so it raises MultipleResultsFound (-> 500) once more than one Settings
        # row exists. Scope it to the current user once the Settings schema is
        # confirmed to carry an owner column.
        settings_result = await db.execute(select(Settings))
        settings = settings_result.scalar_one_or_none()
        if not settings:
            raise HTTPException(status_code=400, detail="请先配置AI设置")

        ai_service = create_user_ai_service(
            api_provider=settings.api_provider,
            api_key=settings.api_key,
            api_base_url=settings.api_base_url,
            model_name=settings.llm_model,
            temperature=settings.temperature,
            max_tokens=settings.max_tokens
        )

        # Run the plot analysis.
        analyzer = get_plot_analyzer(ai_service)
        analysis_result = await analyzer.analyze_chapter(
            chapter_number=chapter.chapter_number,
            title=chapter.title,
            content=chapter.content,
            word_count=chapter.word_count or len(chapter.content)
        )
        if not analysis_result:
            raise HTTPException(status_code=500, detail="剧情分析失败")

        plot_analysis = _build_plot_analysis(
            analyzer, analysis_result, project_id, chapter_id, chapter
        )

        # Drop every earlier analysis of this chapter. The Result is read only
        # once: a Result is closed after its first scalar accessor, so calling a
        # second accessor on it would raise instead of returning the row.
        previous_result = await db.execute(
            select(PlotAnalysis).where(PlotAnalysis.chapter_id == chapter_id)
        )
        previous_analyses = previous_result.scalars().all()
        for previous in previous_analyses:
            await db.delete(previous)
        if previous_analyses:
            # Emit the DELETEs before the INSERT so a uniqueness constraint on
            # chapter_id (if there is one) cannot trip over unit-of-work ordering.
            await db.flush()

        db.add(plot_analysis)
        await db.commit()

        # Extract memory fragments from the analysis result.
        memories_data = analyzer.extract_memories_from_analysis(
            analysis_result,
            chapter_id,
            chapter.chapter_number
        )

        # Persist each memory to the relational DB and to the vector store,
        # sharing one id so the two records can be correlated.
        saved_count = 0
        for mem_data in memories_data:
            memory_id = str(uuid.uuid4())

            memory = StoryMemory(
                id=memory_id,
                project_id=project_id,
                chapter_id=chapter_id,
                memory_type=mem_data['type'],
                title=mem_data.get('title', ''),
                content=mem_data['content'],
                story_timeline=chapter.chapter_number,
                vector_id=memory_id,
                **mem_data['metadata']
            )
            db.add(memory)

            await memory_service.add_memory(
                user_id=user_id,
                project_id=project_id,
                memory_id=memory_id,
                content=mem_data['content'],
                memory_type=mem_data['type'],
                metadata=mem_data['metadata']
            )
            saved_count += 1

        await db.commit()

        logger.info(f"✅ 章节分析完成: 保存{saved_count}条记忆")

        return {
            "success": True,
            "message": f"分析完成,提取了{saved_count}条记忆",
            "analysis": plot_analysis.to_dict(),
            "memories_count": saved_count
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ 章节分析失败: {str(e)}")
        await db.rollback()
        raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}") from e
|
|
|
|
|
|
@router.get("/projects/{project_id}/memories")
async def get_project_memories(
    project_id: str,
    request: Request,
    memory_type: Optional[str] = None,
    chapter_id: Optional[str] = None,
    limit: int = 50,
    db: AsyncSession = Depends(get_db)
):
    """List the stored memories of a project.

    Non-empty ``memory_type`` / ``chapter_id`` values narrow the result. Rows are
    ordered by importance score, then creation time (both descending), and at
    most ``limit`` rows are returned.
    """
    try:
        # Not used by the query; read so a request without it fails here as before.
        user_id = request.state.user_id

        filters = [StoryMemory.project_id == project_id]
        if memory_type:
            filters.append(StoryMemory.memory_type == memory_type)
        if chapter_id:
            filters.append(StoryMemory.chapter_id == chapter_id)

        statement = (
            select(StoryMemory)
            .where(*filters)
            .order_by(desc(StoryMemory.importance_score), desc(StoryMemory.created_at))
            .limit(limit)
        )
        rows = (await db.execute(statement)).scalars().all()

        return {
            "success": True,
            "memories": [row.to_dict() for row in rows],
            "total": len(rows)
        }
    except Exception as e:
        logger.error(f"❌ 获取记忆失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/projects/{project_id}/analysis/{chapter_id}")
async def get_chapter_analysis(
    project_id: str,
    chapter_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Return the stored plot analysis of one chapter in a project.

    Raises:
        HTTPException: 404 when the chapter has not been analyzed yet,
            500 for any other failure while loading it.
    """
    try:
        lookup = select(PlotAnalysis).where(
            PlotAnalysis.project_id == project_id,
            PlotAnalysis.chapter_id == chapter_id,
        )
        found = (await db.execute(lookup)).scalar_one_or_none()

        if not found:
            raise HTTPException(status_code=404, detail="该章节还未进行分析")

        return {"success": True, "analysis": found.to_dict()}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ 获取分析失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/projects/{project_id}/search")
async def search_memories(
    project_id: str,
    request: Request,
    query: str,
    memory_types: Optional[List[str]] = None,
    limit: int = 10,
    min_importance: float = 0.0
):
    """Semantic search over a project's memories via ``memory_service``.

    NOTE(review): FastAPI reads a ``List[str]`` parameter that is not wrapped in
    ``Query(...)`` from the request body, so ``memory_types`` is expected in the
    JSON body while the other filters come from the query string — confirm
    that this split is intended.
    """
    try:
        search_args = dict(
            user_id=request.state.user_id,
            project_id=project_id,
            query=query,
            memory_types=memory_types,
            limit=limit,
            min_importance=min_importance,
        )
        hits = await memory_service.search_memories(**search_args)

        return {"success": True, "query": query, "memories": hits, "total": len(hits)}
    except Exception as e:
        logger.error(f"❌ 搜索记忆失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/projects/{project_id}/foreshadows")
async def get_unresolved_foreshadows(
    project_id: str,
    request: Request,
    current_chapter: int,
    db: AsyncSession = Depends(get_db)
):
    """List foreshadowing that is still unresolved as of ``current_chapter``.

    The lookup is delegated to ``memory_service`` (the vector store); the
    ``db`` session is injected but not used here.
    """
    try:
        open_threads = await memory_service.find_unresolved_foreshadows(
            user_id=request.state.user_id,
            project_id=project_id,
            current_chapter=current_chapter,
        )
        return {"success": True, "foreshadows": open_threads, "total": len(open_threads)}
    except Exception as e:
        logger.error(f"❌ 获取伏笔失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/projects/{project_id}/stats")
async def get_memory_stats(
    project_id: str,
    request: Request
):
    """Return memory statistics for a project, as computed by ``memory_service``."""
    try:
        payload = {
            "success": True,
            "stats": await memory_service.get_memory_stats(
                user_id=request.state.user_id,
                project_id=project_id,
            ),
        }
        return payload
    except Exception as e:
        logger.error(f"❌ 获取统计失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.delete("/projects/{project_id}/chapters/{chapter_id}/memories")
async def delete_chapter_memories(
    project_id: str,
    chapter_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """Delete every memory of one chapter from the database and the vector store.

    The relational deletes are only committed after the vector-store deletion
    succeeds; on any failure the session is rolled back and a 500 is raised.
    """
    try:
        user_id = request.state.user_id

        # Step 1: mark the chapter's rows for deletion (not committed yet).
        matched = await db.execute(
            select(StoryMemory).where(
                StoryMemory.project_id == project_id,
                StoryMemory.chapter_id == chapter_id,
            )
        )
        doomed = matched.scalars().all()
        for row in doomed:
            await db.delete(row)

        # Step 2: remove the matching entries from the vector store.
        await memory_service.delete_chapter_memories(
            user_id=user_id,
            project_id=project_id,
            chapter_id=chapter_id,
        )

        # Step 3: persist the relational deletes.
        await db.commit()

        return {"success": True, "message": f"已删除{len(doomed)}条记忆"}
    except Exception as e:
        logger.error(f"❌ 删除记忆失败: {str(e)}")
        await db.rollback()
        raise HTTPException(status_code=500, detail=str(e))