509 lines
17 KiB
Python
509 lines
17 KiB
Python
"""写作风格管理 API"""
|
||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, func, delete
|
||
from typing import List
|
||
|
||
from ..database import get_db
|
||
from ..models.writing_style import WritingStyle
|
||
from ..models.project import Project
|
||
from ..models.project_default_style import ProjectDefaultStyle
|
||
from ..schemas.writing_style import (
|
||
WritingStyleCreate,
|
||
WritingStyleUpdate,
|
||
WritingStyleResponse,
|
||
WritingStyleListResponse,
|
||
SetDefaultStyleRequest
|
||
)
|
||
from ..logger import get_logger
|
||
|
||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
def get_current_user_id(request: Request) -> str:
|
||
"""获取当前登录用户ID"""
|
||
user_id = getattr(request.state, 'user_id', None)
|
||
if not user_id:
|
||
raise HTTPException(status_code=401, detail="未登录")
|
||
return user_id
|
||
|
||
|
||
@router.get("/presets/list", response_model=List[dict])
|
||
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
|
||
"""
|
||
获取所有预设风格列表(从数据库读取)
|
||
|
||
返回格式:数组形式的预设风格列表
|
||
[
|
||
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
|
||
]
|
||
"""
|
||
# 从数据库获取全局预设风格(user_id 为 NULL)
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(WritingStyle.user_id.is_(None))
|
||
.order_by(WritingStyle.order_index)
|
||
)
|
||
preset_styles = result.scalars().all()
|
||
|
||
# 转换为响应格式
|
||
return [
|
||
{
|
||
"id": style.id,
|
||
"preset_id": style.preset_id,
|
||
"name": style.name,
|
||
"description": style.description,
|
||
"prompt_content": style.prompt_content,
|
||
"style_type": style.style_type,
|
||
"order_index": style.order_index
|
||
}
|
||
for style in preset_styles
|
||
]
|
||
|
||
|
||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||
async def create_writing_style(
|
||
style_data: WritingStyleCreate,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
创建新的写作风格(用户级别)
|
||
|
||
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
|
||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||
"""
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
# 如果基于预设创建,从数据库获取预设内容
|
||
if style_data.preset_id:
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(
|
||
WritingStyle.user_id.is_(None),
|
||
WritingStyle.preset_id == style_data.preset_id
|
||
)
|
||
)
|
||
preset = result.scalar_one_or_none()
|
||
|
||
if not preset:
|
||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||
|
||
# 使用预设内容填充(如果用户未提供)
|
||
if not style_data.name:
|
||
style_data.name = preset.name
|
||
if not style_data.description:
|
||
style_data.description = preset.description
|
||
if not style_data.prompt_content:
|
||
style_data.prompt_content = preset.prompt_content
|
||
|
||
# 验证必填字段
|
||
if not style_data.name or not style_data.prompt_content:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail="name 和 prompt_content 是必填字段"
|
||
)
|
||
|
||
# 获取当前用户的最大 order_index
|
||
count_result = await db.execute(
|
||
select(func.count(WritingStyle.id))
|
||
.where(WritingStyle.user_id == user_id)
|
||
)
|
||
max_order = count_result.scalar_one()
|
||
|
||
# 创建风格记录
|
||
new_style = WritingStyle(
|
||
user_id=user_id,
|
||
name=style_data.name,
|
||
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
|
||
preset_id=style_data.preset_id,
|
||
description=style_data.description,
|
||
prompt_content=style_data.prompt_content,
|
||
order_index=max_order + 1
|
||
)
|
||
|
||
db.add(new_style)
|
||
await db.commit()
|
||
await db.refresh(new_style)
|
||
|
||
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
|
||
return {
|
||
"id": new_style.id,
|
||
"user_id": new_style.user_id,
|
||
"name": new_style.name,
|
||
"style_type": new_style.style_type,
|
||
"preset_id": new_style.preset_id,
|
||
"description": new_style.description,
|
||
"prompt_content": new_style.prompt_content,
|
||
"order_index": new_style.order_index,
|
||
"created_at": new_style.created_at,
|
||
"updated_at": new_style.updated_at,
|
||
"is_default": False
|
||
}
|
||
|
||
|
||
@router.get("/user", response_model=WritingStyleListResponse)
|
||
async def get_user_styles(
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
获取用户的所有可用写作风格
|
||
|
||
返回:全局预设风格 + 该用户的自定义风格
|
||
按 order_index 排序
|
||
"""
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
# 获取全局预设风格(user_id 为 NULL)
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(WritingStyle.user_id.is_(None))
|
||
.order_by(WritingStyle.order_index)
|
||
)
|
||
preset_styles = list(result.scalars().all())
|
||
|
||
# 获取用户自定义风格
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(WritingStyle.user_id == user_id)
|
||
.order_by(WritingStyle.order_index)
|
||
)
|
||
custom_styles = list(result.scalars().all())
|
||
|
||
# 合并:预设风格 + 自定义风格
|
||
all_styles = preset_styles + custom_styles
|
||
|
||
# 转换为响应格式
|
||
styles_with_default = []
|
||
for style in all_styles:
|
||
style_dict = {
|
||
"id": style.id,
|
||
"user_id": style.user_id,
|
||
"name": style.name,
|
||
"style_type": style.style_type,
|
||
"preset_id": style.preset_id,
|
||
"description": style.description,
|
||
"prompt_content": style.prompt_content,
|
||
"order_index": style.order_index,
|
||
"created_at": style.created_at,
|
||
"updated_at": style.updated_at,
|
||
"is_default": False # 用户级别不再需要默认风格标记
|
||
}
|
||
styles_with_default.append(style_dict)
|
||
|
||
return {"styles": styles_with_default, "total": len(styles_with_default)}
|
||
|
||
|
||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||
async def get_project_styles(
|
||
project_id: str,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
获取项目可用的所有写作风格(保留用于向后兼容)
|
||
|
||
返回:全局预设风格 + 该用户的自定义风格
|
||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||
"""
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
# 验证项目访问权限
|
||
result = await db.execute(
|
||
select(Project).where(
|
||
Project.id == project_id,
|
||
Project.user_id == user_id
|
||
)
|
||
)
|
||
project = result.scalar_one_or_none()
|
||
if not project:
|
||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||
|
||
# 获取该项目的默认风格ID
|
||
result = await db.execute(
|
||
select(ProjectDefaultStyle.style_id)
|
||
.where(ProjectDefaultStyle.project_id == project_id)
|
||
)
|
||
default_style_id = result.scalar_one_or_none()
|
||
|
||
# 获取全局预设风格(user_id 为 NULL)
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(WritingStyle.user_id.is_(None))
|
||
.order_by(WritingStyle.order_index)
|
||
)
|
||
preset_styles = list(result.scalars().all())
|
||
|
||
# 获取用户自定义风格
|
||
result = await db.execute(
|
||
select(WritingStyle)
|
||
.where(WritingStyle.user_id == user_id)
|
||
.order_by(WritingStyle.order_index)
|
||
)
|
||
custom_styles = list(result.scalars().all())
|
||
|
||
# 合并:预设风格 + 自定义风格
|
||
all_styles = preset_styles + custom_styles
|
||
|
||
# 为每个风格添加 is_default 标记(用于前端显示)
|
||
styles_with_default = []
|
||
for style in all_styles:
|
||
style_dict = {
|
||
"id": style.id,
|
||
"user_id": style.user_id,
|
||
"name": style.name,
|
||
"style_type": style.style_type,
|
||
"preset_id": style.preset_id,
|
||
"description": style.description,
|
||
"prompt_content": style.prompt_content,
|
||
"order_index": style.order_index,
|
||
"created_at": style.created_at,
|
||
"updated_at": style.updated_at,
|
||
"is_default": style.id == default_style_id
|
||
}
|
||
styles_with_default.append(style_dict)
|
||
|
||
return {"styles": styles_with_default, "total": len(styles_with_default)}
|
||
|
||
|
||
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
||
async def get_writing_style(
|
||
style_id: int,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""获取单个写作风格详情"""
|
||
user_id = get_current_user_id(request)
|
||
result = await db.execute(
|
||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||
)
|
||
style = result.scalar_one_or_none()
|
||
if not style:
|
||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||
if style.user_id is not None and style.user_id != user_id:
|
||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||
|
||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||
result = await db.execute(
|
||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||
)
|
||
is_default = result.scalars().first() is not None
|
||
|
||
# 返回包含 is_default 字段的字典
|
||
return {
|
||
"id": style.id,
|
||
"user_id": style.user_id,
|
||
"name": style.name,
|
||
"style_type": style.style_type,
|
||
"preset_id": style.preset_id,
|
||
"description": style.description,
|
||
"prompt_content": style.prompt_content,
|
||
"order_index": style.order_index,
|
||
"created_at": style.created_at,
|
||
"updated_at": style.updated_at,
|
||
"is_default": is_default
|
||
}
|
||
|
||
|
||
@router.put("/{style_id}", response_model=WritingStyleResponse)
|
||
async def update_writing_style(
|
||
style_id: int,
|
||
style_data: WritingStyleUpdate,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
更新写作风格
|
||
|
||
- 只能修改自定义风格
|
||
- 不能修改全局预设风格
|
||
"""
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
result = await db.execute(
|
||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||
)
|
||
style = result.scalar_one_or_none()
|
||
if not style:
|
||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||
|
||
# 检查是否为全局预设风格(不允许修改)
|
||
if style.user_id is None:
|
||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||
|
||
# 验证用户权限(只能修改自己的风格)
|
||
if style.user_id != user_id:
|
||
raise HTTPException(status_code=403, detail="无权修改其他用户的风格")
|
||
|
||
# 更新字段
|
||
update_data = style_data.model_dump(exclude_unset=True)
|
||
|
||
# 如果修改了内容,将 style_type 改为 custom
|
||
if any(key in update_data for key in ["name", "description", "prompt_content"]):
|
||
update_data["style_type"] = "custom"
|
||
|
||
for key, value in update_data.items():
|
||
setattr(style, key, value)
|
||
|
||
await db.commit()
|
||
await db.refresh(style)
|
||
|
||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||
result = await db.execute(
|
||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||
)
|
||
is_default = result.scalars().first() is not None
|
||
|
||
# 返回包含 is_default 字段的字典
|
||
return {
|
||
"id": style.id,
|
||
"user_id": style.user_id,
|
||
"name": style.name,
|
||
"style_type": style.style_type,
|
||
"preset_id": style.preset_id,
|
||
"description": style.description,
|
||
"prompt_content": style.prompt_content,
|
||
"order_index": style.order_index,
|
||
"created_at": style.created_at,
|
||
"updated_at": style.updated_at,
|
||
"is_default": is_default
|
||
}
|
||
|
||
|
||
@router.delete("/{style_id}", status_code=204)
|
||
async def delete_writing_style(
|
||
style_id: int,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
删除写作风格
|
||
|
||
注意:
|
||
- 只能删除自定义风格,不能删除全局预设风格
|
||
- 不能删除默认风格(必须先设置其他风格为默认)
|
||
- 删除后无法恢复
|
||
"""
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
result = await db.execute(
|
||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||
)
|
||
style = result.scalar_one_or_none()
|
||
if not style:
|
||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||
|
||
# 检查是否为全局预设风格(不允许删除)
|
||
if style.user_id is None:
|
||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||
|
||
# 验证用户权限(只能删除自己的风格)
|
||
if style.user_id != user_id:
|
||
raise HTTPException(status_code=403, detail="无权删除其他用户的风格")
|
||
|
||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||
result = await db.execute(
|
||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||
)
|
||
default_relation = result.scalars().first()
|
||
if default_relation:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail="不能删除默认风格,请先设置其他风格为默认"
|
||
)
|
||
|
||
await db.delete(style)
|
||
await db.commit()
|
||
|
||
return None
|
||
|
||
|
||
@router.post("/{style_id}/set-default", response_model=dict)
|
||
async def set_default_style(
|
||
style_id: int,
|
||
request_data: SetDefaultStyleRequest,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
将指定风格设置为项目的默认风格
|
||
|
||
使用 project_default_styles 表记录项目的默认风格选择
|
||
每个项目只能有一个默认风格(通过 UniqueConstraint 保证)
|
||
|
||
参数:
|
||
- style_id: 要设置为默认的风格ID(路径参数)
|
||
- project_id: 项目ID(请求体),用于确定在哪个项目上下文中设置默认
|
||
"""
|
||
project_id = request_data.project_id
|
||
|
||
# 获取当前用户ID
|
||
user_id = get_current_user_id(request)
|
||
|
||
# 验证项目访问权限
|
||
result = await db.execute(
|
||
select(Project).where(
|
||
Project.id == project_id,
|
||
Project.user_id == user_id
|
||
)
|
||
)
|
||
project = result.scalar_one_or_none()
|
||
if not project:
|
||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||
|
||
# 验证风格是否存在
|
||
result = await db.execute(
|
||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||
)
|
||
style = result.scalar_one_or_none()
|
||
if not style:
|
||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||
|
||
# 验证风格是否属于该用户(自定义风格)或是全局预设风格
|
||
if style.user_id is not None and style.user_id != user_id:
|
||
raise HTTPException(status_code=403, detail="无权操作其他用户的风格")
|
||
|
||
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
|
||
await db.execute(
|
||
delete(ProjectDefaultStyle).where(ProjectDefaultStyle.project_id == project_id)
|
||
)
|
||
|
||
# 插入新的默认风格记录
|
||
new_default = ProjectDefaultStyle(
|
||
project_id=project_id,
|
||
style_id=style_id
|
||
)
|
||
db.add(new_default)
|
||
await db.commit()
|
||
|
||
return {
|
||
"message": "默认风格设置成功",
|
||
"project_id": project_id,
|
||
"style_id": style_id,
|
||
"style_name": style.name
|
||
}
|
||
|
||
|
||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||
async def initialize_default_styles(
|
||
project_id: str,
|
||
request: Request,
|
||
db: AsyncSession = Depends(get_db)
|
||
):
|
||
"""
|
||
【已废弃】为项目初始化默认风格
|
||
|
||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||
"""
|
||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||
return await get_project_styles(project_id, request, db)
|