update:更新自定义写作风格模块
This commit is contained in:
+36
-169
@@ -1,10 +1,11 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.chapter import Chapter
|
||||
@@ -12,11 +13,13 @@ from app.models.project import Project
|
||||
from app.models.outline import Outline
|
||||
from app.models.character import Character
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
ChapterResponse,
|
||||
ChapterListResponse
|
||||
ChapterListResponse,
|
||||
ChapterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
@@ -245,183 +248,24 @@ async def check_can_generate(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{chapter_id}/generate", summary="AI创作章节内容")
|
||||
async def generate_chapter_content(
|
||||
chapter_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
"""
|
||||
# 获取章节
|
||||
result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
if not can_generate:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
try:
|
||||
# 获取项目信息
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 获取对应的大纲(使用新的查询确保获取最新数据)
|
||||
outline_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.where(Outline.order_index == chapter.chapter_number)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取所有大纲用于上下文(使用新的查询确保获取最新数据)
|
||||
all_outlines_result = await db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.order_by(Outline.order_index)
|
||||
.execution_options(populate_existing=True)
|
||||
)
|
||||
all_outlines = all_outlines_result.scalars().all()
|
||||
outlines_context = "\n".join([
|
||||
f"第{o.order_index}章 {o.title}: {o.content[:100]}..."
|
||||
for o in all_outlines
|
||||
])
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await db.execute(
|
||||
select(Character).where(Character.project_id == chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
characters_info = "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 构建前置章节内容上下文(如果有前置章节)
|
||||
previous_content = ""
|
||||
if previous_chapters:
|
||||
# Token控制:保留最近3章的完整内容,早期章节使用摘要
|
||||
recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters
|
||||
early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else []
|
||||
|
||||
# 早期章节摘要
|
||||
if early_chapters:
|
||||
early_summary = "【前期剧情概要】\n" + "\n".join([
|
||||
f"第{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..."
|
||||
for ch in early_chapters
|
||||
])
|
||||
previous_content += early_summary + "\n\n"
|
||||
|
||||
# 最近章节完整内容
|
||||
if recent_chapters:
|
||||
recent_content = "【最近章节完整内容】\n" + "\n\n".join([
|
||||
f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}"
|
||||
for ch in recent_chapters
|
||||
])
|
||||
previous_content += recent_content
|
||||
|
||||
logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容")
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
if previous_content:
|
||||
# 使用带上下文的提示词
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
previous_content=previous_content,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
else:
|
||||
# 第一章,使用原有提示词
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
narrative_perspective=project.narrative_perspective or '第三人称',
|
||||
time_period=project.world_time_period or '未设定',
|
||||
location=project.world_location or '未设定',
|
||||
atmosphere=project.world_atmosphere or '未设定',
|
||||
rules=project.world_rules or '未设定',
|
||||
characters_info=characters_info or '暂无角色信息',
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=outline.content if outline else chapter.summary or '暂无大纲'
|
||||
)
|
||||
|
||||
logger.info(f"开始AI创作章节 {chapter_id}")
|
||||
|
||||
# 调用AI生成
|
||||
ai_content = await user_ai_service.generate_text(
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# 更新章节内容
|
||||
old_word_count = chapter.word_count or 0
|
||||
chapter.content = ai_content
|
||||
new_word_count = len(ai_content)
|
||||
chapter.word_count = new_word_count
|
||||
chapter.status = "completed"
|
||||
|
||||
# 更新项目字数
|
||||
project.current_words = project.current_words - old_word_count + new_word_count
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=chapter.project_id,
|
||||
chapter_id=chapter.id,
|
||||
prompt=f"创作章节: 第{chapter.chapter_number}章 {chapter.title}",
|
||||
generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content,
|
||||
model="default"
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字")
|
||||
|
||||
return {"content": ai_content}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创作章节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}")
|
||||
|
||||
@router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)")
|
||||
async def generate_chapter_content_stream(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
generate_request: ChapterGenerateRequest = ChapterGenerateRequest(),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回)
|
||||
要求:必须按顺序生成,确保前置章节都已完成
|
||||
|
||||
请求体参数:
|
||||
- style_id: 可选,指定使用的写作风格ID。不提供则不使用任何风格
|
||||
|
||||
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
|
||||
以避免流式响应期间的连接泄漏问题
|
||||
"""
|
||||
style_id = generate_request.style_id
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
@@ -508,6 +352,27 @@ async def generate_chapter_content_stream(
|
||||
for c in characters
|
||||
])
|
||||
|
||||
# 获取写作风格
|
||||
style_content = ""
|
||||
if style_id:
|
||||
# 使用指定的风格
|
||||
style_result = await db_session.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = style_result.scalar_one_or_none()
|
||||
if style:
|
||||
# 验证风格是否可用:全局预设风格(project_id为NULL)或者当前项目的自定义风格
|
||||
if style.project_id is None or style.project_id == current_chapter.project_id:
|
||||
style_content = style.prompt_content or ""
|
||||
style_type = "全局预设" if style.project_id is None else "项目自定义"
|
||||
logger.info(f"使用指定风格: {style.name} ({style_type})")
|
||||
else:
|
||||
logger.warning(f"风格 {style_id} 不属于当前项目,无法使用")
|
||||
else:
|
||||
logger.warning(f"未找到风格 {style_id}")
|
||||
else:
|
||||
logger.info("未指定写作风格,使用原始提示词")
|
||||
|
||||
# 构建前置章节内容上下文(使用之前保存的数据)
|
||||
previous_content = ""
|
||||
if previous_chapters_data:
|
||||
@@ -533,7 +398,7 @@ async def generate_chapter_content_stream(
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
title=project.title,
|
||||
@@ -549,7 +414,8 @@ async def generate_chapter_content_stream(
|
||||
previous_content=previous_content,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
@@ -565,7 +431,8 @@ async def generate_chapter_content_stream(
|
||||
outlines_context=outlines_context,
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content
|
||||
)
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
@@ -40,6 +40,7 @@ async def create_project(
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
logger.info(f"项目创建成功: {db_project.id}")
|
||||
|
||||
return db_project
|
||||
except Exception as e:
|
||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.models.character import Character
|
||||
from app.models.outline import Outline
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.logger import get_logger
|
||||
@@ -132,9 +134,33 @@ async def world_building_generator(
|
||||
)
|
||||
db.add(project)
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 自动设置默认写作风格为第一个全局预设风格
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(
|
||||
WritingStyle.project_id.is_(None),
|
||||
WritingStyle.order_index == 1
|
||||
).limit(1)
|
||||
)
|
||||
first_style = result.scalar_one_or_none()
|
||||
|
||||
if first_style:
|
||||
default_style = ProjectDefaultStyle(
|
||||
project_id=project.id,
|
||||
style_id=first_style.id
|
||||
)
|
||||
db.add(default_style)
|
||||
await db.commit()
|
||||
logger.info(f"为项目 {project.id} 自动设置默认风格: {first_style.name}")
|
||||
else:
|
||||
logger.warning(f"未找到order_index=1的全局预设风格,项目 {project.id} 未设置默认风格")
|
||||
except Exception as e:
|
||||
logger.warning(f"设置默认写作风格失败: {e},不影响项目创建")
|
||||
|
||||
db_committed = True
|
||||
|
||||
# 发送最终结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
"""写作风格管理 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
|
||||
from ..database import get_db
|
||||
from ..models.writing_style import WritingStyle
|
||||
from ..models.project import Project
|
||||
from ..models.project_default_style import ProjectDefaultStyle
|
||||
from ..schemas.writing_style import (
|
||||
WritingStyleCreate,
|
||||
WritingStyleUpdate,
|
||||
WritingStyleResponse,
|
||||
WritingStyleListResponse,
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
async def get_preset_styles():
|
||||
"""
|
||||
获取所有预设风格列表
|
||||
|
||||
返回格式:数组形式的预设风格列表
|
||||
[
|
||||
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": "classical", "name": "古典优雅", ...}
|
||||
]
|
||||
"""
|
||||
presets = WritingStyleManager.get_all_presets()
|
||||
# 将字典转换为数组,添加 id 字段
|
||||
return [
|
||||
{"id": preset_id, **preset_data}
|
||||
for preset_id, preset_data in presets.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||||
async def create_writing_style(
|
||||
style_data: WritingStyleCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新的写作风格
|
||||
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == style_data.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
if style_data.preset_id:
|
||||
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
|
||||
if not preset:
|
||||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||||
|
||||
# 使用预设内容填充(如果用户未提供)
|
||||
if not style_data.name:
|
||||
style_data.name = preset["name"]
|
||||
if not style_data.description:
|
||||
style_data.description = preset["description"]
|
||||
if not style_data.prompt_content:
|
||||
style_data.prompt_content = preset["prompt_content"]
|
||||
|
||||
# 验证必填字段
|
||||
if not style_data.name or not style_data.prompt_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="name 和 prompt_content 是必填字段"
|
||||
)
|
||||
|
||||
# 获取当前最大 order_index
|
||||
count_result = await db.execute(
|
||||
select(func.count(WritingStyle.id))
|
||||
.where(WritingStyle.project_id == style_data.project_id)
|
||||
)
|
||||
max_order = count_result.scalar_one()
|
||||
|
||||
# 创建风格记录
|
||||
new_style = WritingStyle(
|
||||
project_id=style_data.project_id,
|
||||
name=style_data.name,
|
||||
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
|
||||
preset_id=style_data.preset_id,
|
||||
description=style_data.description,
|
||||
prompt_content=style_data.prompt_content,
|
||||
order_index=max_order + 1
|
||||
)
|
||||
|
||||
db.add(new_style)
|
||||
await db.commit()
|
||||
await db.refresh(new_style)
|
||||
|
||||
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
|
||||
return {
|
||||
"id": new_style.id,
|
||||
"project_id": new_style.project_id,
|
||||
"name": new_style.name,
|
||||
"style_type": new_style.style_type,
|
||||
"preset_id": new_style.preset_id,
|
||||
"description": new_style.description,
|
||||
"prompt_content": new_style.prompt_content,
|
||||
"order_index": new_style.order_index,
|
||||
"created_at": new_style.created_at,
|
||||
"updated_at": new_style.updated_at,
|
||||
"is_default": False
|
||||
}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目的所有可用写作风格
|
||||
|
||||
返回:全局预设风格 + 该项目的自定义风格
|
||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 获取该项目的默认风格ID
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle.style_id)
|
||||
.where(ProjectDefaultStyle.project_id == project_id)
|
||||
)
|
||||
default_style_id = result.scalar_one_or_none()
|
||||
|
||||
# 获取全局预设风格(project_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = list(result.scalars().all())
|
||||
|
||||
# 获取项目自定义风格
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id == project_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
custom_styles = list(result.scalars().all())
|
||||
|
||||
# 合并:预设风格 + 自定义风格
|
||||
all_styles = preset_styles + custom_styles
|
||||
|
||||
# 为每个风格添加 is_default 标记(用于前端显示)
|
||||
styles_with_default = []
|
||||
for style in all_styles:
|
||||
style_dict = {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"order_index": style.order_index,
|
||||
"created_at": style.created_at,
|
||||
"updated_at": style.updated_at,
|
||||
"is_default": style.id == default_style_id
|
||||
}
|
||||
styles_with_default.append(style_dict)
|
||||
|
||||
return {"styles": styles_with_default, "total": len(styles_with_default)}
|
||||
|
||||
|
||||
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
||||
async def get_writing_style(
|
||||
style_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取单个写作风格详情"""
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
is_default = result.scalar_one_or_none() is not None
|
||||
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"order_index": style.order_index,
|
||||
"created_at": style.created_at,
|
||||
"updated_at": style.updated_at,
|
||||
"is_default": is_default
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{style_id}", response_model=WritingStyleResponse)
|
||||
async def update_writing_style(
|
||||
style_id: int,
|
||||
style_data: WritingStyleUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新写作风格
|
||||
|
||||
- 只能修改自定义风格
|
||||
- 不能修改全局预设风格
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许修改)
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||
|
||||
# 更新字段
|
||||
update_data = style_data.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果修改了内容,将 style_type 改为 custom
|
||||
if any(key in update_data for key in ["name", "description", "prompt_content"]):
|
||||
update_data["style_type"] = "custom"
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(style, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(style)
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
is_default = result.scalar_one_or_none() is not None
|
||||
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"order_index": style.order_index,
|
||||
"created_at": style.created_at,
|
||||
"updated_at": style.updated_at,
|
||||
"is_default": is_default
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{style_id}", status_code=204)
|
||||
async def delete_writing_style(
|
||||
style_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除写作风格
|
||||
|
||||
注意:
|
||||
- 只能删除自定义风格,不能删除全局预设风格
|
||||
- 不能删除默认风格(必须先设置其他风格为默认)
|
||||
- 删除后无法恢复
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许删除)
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
)
|
||||
default_relation = result.scalar_one_or_none()
|
||||
if default_relation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="不能删除默认风格,请先设置其他风格为默认"
|
||||
)
|
||||
|
||||
await db.delete(style)
|
||||
await db.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{style_id}/set-default", response_model=dict)
|
||||
async def set_default_style(
|
||||
style_id: int,
|
||||
request_data: SetDefaultStyleRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
将指定风格设置为项目的默认风格
|
||||
|
||||
使用 project_default_styles 表记录项目的默认风格选择
|
||||
每个项目只能有一个默认风格(通过 UniqueConstraint 保证)
|
||||
|
||||
参数:
|
||||
- style_id: 要设置为默认的风格ID(路径参数)
|
||||
- project_id: 项目ID(请求体),用于确定在哪个项目上下文中设置默认
|
||||
"""
|
||||
project_id = request_data.project_id
|
||||
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 验证风格是否存在
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 验证风格是否属于该项目(自定义风格)或是全局预设风格
|
||||
if style.project_id is not None and style.project_id != project_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作其他项目的风格")
|
||||
|
||||
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
|
||||
await db.execute(
|
||||
delete(ProjectDefaultStyle).where(ProjectDefaultStyle.project_id == project_id)
|
||||
)
|
||||
|
||||
# 插入新的默认风格记录
|
||||
new_default = ProjectDefaultStyle(
|
||||
project_id=project_id,
|
||||
style_id=style_id
|
||||
)
|
||||
db.add(new_default)
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"message": "默认风格设置成功",
|
||||
"project_id": project_id,
|
||||
"style_id": style_id,
|
||||
"style_name": style.name
|
||||
}
|
||||
|
||||
|
||||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||||
async def initialize_default_styles(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
【已废弃】为项目初始化默认风格
|
||||
|
||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||
return await get_project_styles(project_id, db)
|
||||
Reference in New Issue
Block a user