Files
MuMuAINovel/backend/app/api/writing_styles.py
T

504 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""写作风格管理 API"""
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
from ..database import get_db
from ..models.writing_style import WritingStyle
from ..models.project import Project
from ..models.project_default_style import ProjectDefaultStyle
from ..schemas.writing_style import (
WritingStyleCreate,
WritingStyleUpdate,
WritingStyleResponse,
WritingStyleListResponse,
SetDefaultStyleRequest
)
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
logger = get_logger(__name__)
def get_current_user_id(request: Request) -> str:
"""获取当前登录用户ID"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
return user_id
@router.get("/presets/list", response_model=List[dict])
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
"""
获取所有预设风格列表(从数据库读取)
返回格式:数组形式的预设风格列表
[
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
]
"""
# 从数据库获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = result.scalars().all()
# 转换为响应格式
return [
{
"id": style.id,
"preset_id": style.preset_id,
"name": style.name,
"description": style.description,
"prompt_content": style.prompt_content,
"style_type": style.style_type,
"order_index": style.order_index
}
for style in preset_styles
]
@router.post("", response_model=WritingStyleResponse, status_code=201)
async def create_writing_style(
style_data: WritingStyleCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
创建新的写作风格(用户级别)
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 如果基于预设创建,从数据库获取预设内容
if style_data.preset_id:
result = await db.execute(
select(WritingStyle)
.where(
WritingStyle.user_id.is_(None),
WritingStyle.preset_id == style_data.preset_id
)
)
preset = result.scalar_one_or_none()
if not preset:
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
# 使用预设内容填充(如果用户未提供)
if not style_data.name:
style_data.name = preset.name
if not style_data.description:
style_data.description = preset.description
if not style_data.prompt_content:
style_data.prompt_content = preset.prompt_content
# 验证必填字段
if not style_data.name or not style_data.prompt_content:
raise HTTPException(
status_code=400,
detail="name 和 prompt_content 是必填字段"
)
# 获取当前用户的最大 order_index
count_result = await db.execute(
select(func.count(WritingStyle.id))
.where(WritingStyle.user_id == user_id)
)
max_order = count_result.scalar_one()
# 创建风格记录
new_style = WritingStyle(
user_id=user_id,
name=style_data.name,
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
preset_id=style_data.preset_id,
description=style_data.description,
prompt_content=style_data.prompt_content,
order_index=max_order + 1
)
db.add(new_style)
await db.commit()
await db.refresh(new_style)
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
return {
"id": new_style.id,
"user_id": new_style.user_id,
"name": new_style.name,
"style_type": new_style.style_type,
"preset_id": new_style.preset_id,
"description": new_style.description,
"prompt_content": new_style.prompt_content,
"order_index": new_style.order_index,
"created_at": new_style.created_at,
"updated_at": new_style.updated_at,
"is_default": False
}
@router.get("/user", response_model=WritingStyleListResponse)
async def get_user_styles(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取用户的所有可用写作风格
返回:全局预设风格 + 该用户的自定义风格
按 order_index 排序
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = list(result.scalars().all())
# 获取用户自定义风格
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id == user_id)
.order_by(WritingStyle.order_index)
)
custom_styles = list(result.scalars().all())
# 合并:预设风格 + 自定义风格
all_styles = preset_styles + custom_styles
# 转换为响应格式
styles_with_default = []
for style in all_styles:
style_dict = {
"id": style.id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": False # 用户级别不再需要默认风格标记
}
styles_with_default.append(style_dict)
return {"styles": styles_with_default, "total": len(styles_with_default)}
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取项目可用的所有写作风格(保留用于向后兼容)
返回:全局预设风格 + 该用户的自定义风格
按 order_index 排序,并标记哪个是当前项目的默认风格
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 验证项目访问权限
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
# 获取该项目的默认风格ID
result = await db.execute(
select(ProjectDefaultStyle.style_id)
.where(ProjectDefaultStyle.project_id == project_id)
)
default_style_id = result.scalar_one_or_none()
# 获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = list(result.scalars().all())
# 获取用户自定义风格
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id == user_id)
.order_by(WritingStyle.order_index)
)
custom_styles = list(result.scalars().all())
# 合并:预设风格 + 自定义风格
all_styles = preset_styles + custom_styles
# 为每个风格添加 is_default 标记(用于前端显示)
styles_with_default = []
for style in all_styles:
style_dict = {
"id": style.id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": style.id == default_style_id
}
styles_with_default.append(style_dict)
return {"styles": styles_with_default, "total": len(styles_with_default)}
@router.get("/{style_id}", response_model=WritingStyleResponse)
async def get_writing_style(
style_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取单个写作风格详情"""
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": is_default
}
@router.put("/{style_id}", response_model=WritingStyleResponse)
async def update_writing_style(
style_id: int,
style_data: WritingStyleUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
更新写作风格
- 只能修改自定义风格
- 不能修改全局预设风格
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许修改)
if style.user_id is None:
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
# 验证用户权限(只能修改自己的风格)
if style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权修改其他用户的风格")
# 更新字段
update_data = style_data.model_dump(exclude_unset=True)
# 如果修改了内容,将 style_type 改为 custom
if any(key in update_data for key in ["name", "description", "prompt_content"]):
update_data["style_type"] = "custom"
for key, value in update_data.items():
setattr(style, key, value)
await db.commit()
await db.refresh(style)
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
is_default = result.scalar_one_or_none() is not None
# 返回包含 is_default 字段的字典
return {
"id": style.id,
"user_id": style.user_id,
"name": style.name,
"style_type": style.style_type,
"preset_id": style.preset_id,
"description": style.description,
"prompt_content": style.prompt_content,
"order_index": style.order_index,
"created_at": style.created_at,
"updated_at": style.updated_at,
"is_default": is_default
}
@router.delete("/{style_id}", status_code=204)
async def delete_writing_style(
style_id: int,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
删除写作风格
注意:
- 只能删除自定义风格,不能删除全局预设风格
- 不能删除默认风格(必须先设置其他风格为默认)
- 删除后无法恢复
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否为全局预设风格(不允许删除)
if style.user_id is None:
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
# 验证用户权限(只能删除自己的风格)
if style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权删除其他用户的风格")
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
)
default_relation = result.scalar_one_or_none()
if default_relation:
raise HTTPException(
status_code=400,
detail="不能删除默认风格,请先设置其他风格为默认"
)
await db.delete(style)
await db.commit()
return None
@router.post("/{style_id}/set-default", response_model=dict)
async def set_default_style(
style_id: int,
request_data: SetDefaultStyleRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
将指定风格设置为项目的默认风格
使用 project_default_styles 表记录项目的默认风格选择
每个项目只能有一个默认风格(通过 UniqueConstraint 保证)
参数:
- style_id: 要设置为默认的风格ID(路径参数)
- project_id: 项目ID(请求体),用于确定在哪个项目上下文中设置默认
"""
project_id = request_data.project_id
# 获取当前用户ID
user_id = get_current_user_id(request)
# 验证项目访问权限
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
# 验证风格是否存在
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 验证风格是否属于该用户(自定义风格)或是全局预设风格
if style.user_id is not None and style.user_id != user_id:
raise HTTPException(status_code=403, detail="无权操作其他用户的风格")
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
await db.execute(
delete(ProjectDefaultStyle).where(ProjectDefaultStyle.project_id == project_id)
)
# 插入新的默认风格记录
new_default = ProjectDefaultStyle(
project_id=project_id,
style_id=style_id
)
db.add(new_default)
await db.commit()
return {
"message": "默认风格设置成功",
"project_id": project_id,
"style_id": style_id,
"style_name": style.name
}
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
async def initialize_default_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
【已废弃】为项目初始化默认风格
新架构下,预设风格是全局的,不需要为每个项目单独初始化
该接口保留用于兼容性,直接返回项目可用的所有风格
"""
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
return await get_project_styles(project_id, request, db)