update:1.切换数据库PostgreSQL
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""写作风格管理 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
@@ -16,8 +16,30 @@ from ..schemas.writing_style import (
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
from ..logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
@@ -42,6 +64,7 @@ async def get_preset_styles():
|
||||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||||
async def create_writing_style(
|
||||
style_data: WritingStyleCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -50,13 +73,9 @@ async def create_writing_style(
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == style_data.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style_data.project_id, user_id, db)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
if style_data.preset_id:
|
||||
@@ -120,6 +139,7 @@ async def create_writing_style(
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -128,13 +148,9 @@ async def get_project_styles(
|
||||
返回:全局预设风格 + 该项目的自定义风格
|
||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取该项目的默认风格ID
|
||||
result = await db.execute(
|
||||
@@ -222,6 +238,7 @@ async def get_writing_style(
|
||||
async def update_writing_style(
|
||||
style_id: int,
|
||||
style_data: WritingStyleUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -241,6 +258,10 @@ async def update_writing_style(
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = style_data.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -279,6 +300,7 @@ async def update_writing_style(
|
||||
@router.delete("/{style_id}", status_code=204)
|
||||
async def delete_writing_style(
|
||||
style_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -300,6 +322,10 @@ async def delete_writing_style(
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
@@ -321,6 +347,7 @@ async def delete_writing_style(
|
||||
async def set_default_style(
|
||||
style_id: int,
|
||||
request_data: SetDefaultStyleRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -335,13 +362,9 @@ async def set_default_style(
|
||||
"""
|
||||
project_id = request_data.project_id
|
||||
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 验证风格是否存在
|
||||
result = await db.execute(
|
||||
@@ -379,6 +402,7 @@ async def set_default_style(
|
||||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||||
async def initialize_default_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -387,13 +411,9 @@ async def initialize_default_styles(
|
||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||
return await get_project_styles(project_id, db)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
Reference in New Issue
Block a user