update:更新自定义写作风格模块

This commit is contained in:
xiamuceer
2025-10-31 17:23:25 +08:00
parent b5be954112
commit e94e81c5f4
21 changed files with 1550 additions and 326 deletions
+36 -169
View File
@@ -1,10 +1,11 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
import asyncio
from typing import Optional
from app.database import get_db
from app.models.chapter import Chapter
@@ -12,11 +13,13 @@ from app.models.project import Project
from app.models.outline import Outline
from app.models.character import Character
from app.models.generation_history import GenerationHistory
from app.models.writing_style import WritingStyle
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
ChapterResponse,
ChapterListResponse
ChapterListResponse,
ChapterGenerateRequest
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
@@ -245,183 +248,24 @@ async def check_can_generate(
}
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
async def generate_chapter_content(
chapter_id: str,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容
要求:必须按顺序生成,确保前置章节都已完成
"""
# 获取章节
result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
# 检查前置条件
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
if not can_generate:
raise HTTPException(status_code=400, detail=error_msg)
try:
# 获取项目信息
project_result = await db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取对应的大纲(使用新的查询确保获取最新数据)
outline_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.where(Outline.order_index == chapter.chapter_number)
.execution_options(populate_existing=True)
)
outline = outline_result.scalar_one_or_none()
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
all_outlines_result = await db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.order_by(Outline.order_index)
.execution_options(populate_existing=True)
)
all_outlines = all_outlines_result.scalars().all()
outlines_context = "\n".join([
f"{o.order_index}{o.title}: {o.content[:100]}..."
for o in all_outlines
])
# 获取角色信息
characters_result = await db.execute(
select(Character).where(Character.project_id == chapter.project_id)
)
characters = characters_result.scalars().all()
characters_info = "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
])
# 构建前置章节内容上下文(如果有前置章节)
previous_content = ""
if previous_chapters:
# Token控制:保留最近3章的完整内容,早期章节使用摘要
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
# 早期章节摘要
if early_chapters:
early_summary = "【前期剧情概要】\n" + "\n".join([
f"{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
for ch in early_chapters
])
previous_content += early_summary + "\n\n"
# 最近章节完整内容
if recent_chapters:
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
for ch in recent_chapters
])
previous_content += recent_content
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
# 根据是否有前置内容选择不同的提示词
if previous_content:
# 使用带上下文的提示词
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
previous_content=previous_content,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
else:
# 第一章,使用原有提示词
prompt = prompt_service.get_chapter_generation_prompt(
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
narrative_perspective=project.narrative_perspective or '第三人称',
time_period=project.world_time_period or '未设定',
location=project.world_location or '未设定',
atmosphere=project.world_atmosphere or '未设定',
rules=project.world_rules or '未设定',
characters_info=characters_info or '暂无角色信息',
outlines_context=outlines_context,
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
)
logger.info(f"开始AI创作章节 {chapter_id}")
# 调用AI生成
ai_content = await user_ai_service.generate_text(
prompt=prompt
)
# 更新章节内容
old_word_count = chapter.word_count or 0
chapter.content = ai_content
new_word_count = len(ai_content)
chapter.word_count = new_word_count
chapter.status = "completed"
# 更新项目字数
project.current_words = project.current_words - old_word_count + new_word_count
# 记录生成历史
history = GenerationHistory(
project_id=chapter.project_id,
chapter_id=chapter.id,
prompt=f"创作章节: 第{chapter.chapter_number}{chapter.title}",
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
model="default"
)
db.add(history)
await db.commit()
await db.refresh(chapter)
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count}")
return {"content": ai_content}
except Exception as e:
logger.error(f"创作章节失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
async def generate_chapter_content_stream(
chapter_id: str,
request: Request,
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
要求:必须按顺序生成,确保前置章节都已完成
请求体参数:
- style_id: 可选,指定使用的写作风格ID。不提供则不使用任何风格
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
以避免流式响应期间的连接泄漏问题
"""
style_id = generate_request.style_id
# 预先验证章节存在性(使用临时会话)
async for temp_db in get_db(request):
try:
@@ -508,6 +352,27 @@ async def generate_chapter_content_stream(
for c in characters
])
# 获取写作风格
style_content = ""
if style_id:
# 使用指定的风格
style_result = await db_session.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = style_result.scalar_one_or_none()
if style:
# 验证风格是否可用:全局预设风格(project_id为NULL)或者当前项目的自定义风格
if style.project_id is None or style.project_id == current_chapter.project_id:
style_content = style.prompt_content or ""
style_type = "全局预设" if style.project_id is None else "项目自定义"
logger.info(f"使用指定风格: {style.name} ({style_type})")
else:
logger.warning(f"风格 {style_id} 不属于当前项目,无法使用")
else:
logger.warning(f"未找到风格 {style_id}")
else:
logger.info("未指定写作风格,使用原始提示词")
# 构建前置章节内容上下文(使用之前保存的数据)
previous_content = ""
if previous_chapters_data:
@@ -533,7 +398,7 @@ async def generate_chapter_content_stream(
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 根据是否有前置内容选择不同的提示词
# 根据是否有前置内容选择不同的提示词,并应用写作风格
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
title=project.title,
@@ -549,7 +414,8 @@ async def generate_chapter_content_stream(
previous_content=previous_content,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content
)
else:
prompt = prompt_service.get_chapter_generation_prompt(
@@ -565,7 +431,8 @@ async def generate_chapter_content_stream(
outlines_context=outlines_context,
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
style_content=style_content
)
logger.info(f"开始AI流式创作章节 {chapter_id}")
+1
View File
@@ -40,6 +40,7 @@ async def create_project(
await db.commit()
await db.refresh(db_project)
logger.info(f"项目创建成功: {db_project.id}")
return db_project
except Exception as e:
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
+27 -1
View File
@@ -12,6 +12,8 @@ from app.models.character import Character
from app.models.outline import Outline
from app.models.chapter import Chapter
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
from app.models.writing_style import WritingStyle
from app.models.project_default_style import ProjectDefaultStyle
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.logger import get_logger
@@ -132,9 +134,33 @@ async def world_building_generator(
)
db.add(project)
await db.commit()
db_committed = True
await db.refresh(project)
# 自动设置默认写作风格为第一个全局预设风格
try:
result = await db.execute(
select(WritingStyle).where(
WritingStyle.project_id.is_(None),
WritingStyle.order_index == 1
).limit(1)
)
first_style = result.scalar_one_or_none()
if first_style:
default_style = ProjectDefaultStyle(
project_id=project.id,
style_id=first_style.id
)
db.add(default_style)
await db.commit()
logger.info(f"为项目 {project.id} 自动设置默认风格: {first_style.name}")
else:
logger.warning(f"未找到order_index=1的全局预设风格,项目 {project.id} 未设置默认风格")
except Exception as e:
logger.warning(f"设置默认写作风格失败: {e},不影响项目创建")
db_committed = True
# 发送最终结果
yield await SSEResponse.send_result({
"project_id": project.id,
+399
View File
@@ -0,0 +1,399 @@
"""写作风格管理 API"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
from ..database import get_db
from ..models.writing_style import WritingStyle
from ..models.project import Project
from ..models.project_default_style import ProjectDefaultStyle
from ..schemas.writing_style import (
WritingStyleCreate,
WritingStyleUpdate,
WritingStyleResponse,
WritingStyleListResponse,
SetDefaultStyleRequest
)
from ..services.prompt_service import WritingStyleManager
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
@router.get("/presets/list", response_model=List[dict])
async def get_preset_styles():
"""
获取所有预设风格列表
返回格式:数组形式的预设风格列表
[
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": "classical", "name": "古典优雅", ...}
]
"""
presets = WritingStyleManager.get_all_presets()
# 将字典转换为数组,添加 id 字段
return [
{"id": preset_id, **preset_data}
for preset_id, preset_data in presets.items()
]
@router.post("", response_model=WritingStyleResponse, status_code=201)
async def create_writing_style(
style_data: WritingStyleCreate,
db: AsyncSession = Depends(get_db)
):
"""
创建新的写作风格
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == style_data.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 如果基于预设创建,获取预设内容
if style_data.preset_id:
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
if not preset:
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
# 使用预设内容填充(如果用户未提供)
if not style_data.name:
style_data.name = preset["name"]
if not style_data.description:
style_data.description = preset["description"]
if not style_data.prompt_content:
style_data.prompt_content = preset["prompt_content"]
# 验证必填字段
if not style_data.name or not style_data.prompt_content:
raise HTTPException(
status_code=400,
detail="name 和 prompt_content 是必填字段"
)
# 获取当前最大 order_index
count_result = await db.execute(
select(func.count(WritingStyle.id))
.where(WritingStyle.project_id == style_data.project_id)
)
max_order = count_result.scalar_one()
# 创建风格记录
new_style = WritingStyle(
project_id=style_data.project_id,
name=style_data.name,
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
preset_id=style_data.preset_id,
description=style_data.description,
prompt_content=style_data.prompt_content,
order_index=max_order + 1
)
db.add(new_style)
await db.commit()
await db.refresh(new_style)
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
return {
"id": new_style.id,
"project_id": new_style.project_id,
"name": new_style.name,
"style_type": new_style.style_type,
"preset_id": new_style.preset_id,
"description": new_style.description,
"prompt_content": new_style.prompt_content,
"order_index": new_style.order_index,
"created_at": new_style.created_at,
"updated_at": new_style.updated_at,
"is_default": False
}
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
获取项目的所有可用写作风格
返回:全局预设风格 + 该项目的自定义风格
按 order_index 排序,并标记哪个是当前项目的默认风格
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 获取该项目的默认风格ID
result = await db.execute(
select(ProjectDefaultStyle.style_id)
.where(ProjectDefaultStyle.project_id == project_id)
)
default_style_id = result.scalar_one_or_none()
# 获取全局预设风格(project_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = list(result.scalars().all())
# 获取项目自定义风格
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.project_id == project_id)
.order_by(WritingStyle.order_index)
)
custom_styles = list(result.scalars().all())
# 合并:预设风格 + 自定义风格
all_styles = preset_styles + custom_styles
# 为每个风格添加 is_default 标记(用于前端显示)
styles_with_default = []
for style in all_styles:
style_dict = {
"id": style.id,
"project_id": style.project_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": style.id == default_style_id
}
styles_with_default.append(style_dict)
return {"styles": styles_with_default, "total": len(styles_with_default)}
@router.get("/{style_id}", response_model=WritingStyleResponse)
async def get_writing_style(
style_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取单个写作风格详情"""
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"project_id": style.project_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": is_default
}
@router.put("/{style_id}", response_model=WritingStyleResponse)
async def update_writing_style(
style_id: int,
style_data: WritingStyleUpdate,
db: AsyncSession = Depends(get_db)
):
"""
更新写作风格
- 只能修改自定义风格
- 不能修改全局预设风格
"""
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许修改)
if style.project_id is None:
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
# 更新字段
update_data = style_data.model_dump(exclude_unset=True)
# 如果修改了内容,将 style_type 改为 custom
if any(key in update_data for key in ["name", "description", "prompt_content"]):
update_data["style_type"] = "custom"
for key, value in update_data.items():
setattr(style, key, value)
await db.commit()
await db.refresh(style)
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"project_id": style.project_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": is_default
}
@router.delete("/{style_id}", status_code=204)
async def delete_writing_style(
style_id: int,
db: AsyncSession = Depends(get_db)
):
"""
删除写作风格
注意:
- 只能删除自定义风格,不能删除全局预设风格
- 不能删除默认风格(必须先设置其他风格为默认)
- 删除后无法恢复
"""
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许删除)
if style.project_id is None:
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
default_relation = result.scalar_one_or_none()
if default_relation:
raise HTTPException(
status_code=400,
detail="不能删除默认风格,请先设置其他风格为默认"
)
await db.delete(style)
await db.commit()
return None
@router.post("/{style_id}/set-default", response_model=dict)
async def set_default_style(
style_id: int,
request_data: SetDefaultStyleRequest,
db: AsyncSession = Depends(get_db)
):
"""
将指定风格设置为项目的默认风格
使用 project_default_styles 表记录项目的默认风格选择
每个项目只能有一个默认风格(通过 UniqueConstraint 保证)
参数:
- style_id: 要设置为默认的风格ID(路径参数)
- project_id: 项目ID(请求体),用于确定在哪个项目上下文中设置默认
"""
project_id = request_data.project_id
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证风格是否存在
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 验证风格是否属于该项目(自定义风格)或是全局预设风格
if style.project_id is not None and style.project_id != project_id:
raise HTTPException(status_code=403, detail="无权操作其他项目的风格")
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
await db.execute(
delete(ProjectDefaultStyle).where(ProjectDefaultStyle.project_id == project_id)
)
# 插入新的默认风格记录
new_default = ProjectDefaultStyle(
project_id=project_id,
style_id=style_id
)
db.add(new_default)
await db.commit()
return {
"message": "默认风格设置成功",
"project_id": project_id,
"style_id": style_id,
"style_name": style.name
}
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
async def initialize_default_styles(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
【已废弃】为项目初始化默认风格
新架构下,预设风格是全局的,不需要为每个项目单独初始化
该接口保留用于兼容性,直接返回项目可用的所有风格
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
return await get_project_styles(project_id, db)