update:1.切换数据库PostgreSQL

This commit is contained in:
xiamuceer
2025-11-10 21:16:55 +08:00
parent dfea51cfa4
commit 20d9319a16
31 changed files with 2526 additions and 256 deletions
+50 -30
View File
@@ -1,5 +1,5 @@
"""写作风格管理 API"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List
@@ -16,8 +16,30 @@ from ..schemas.writing_style import (
SetDefaultStyleRequest
)
from ..services.prompt_service import WritingStyleManager
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("/presets/list", response_model=List[dict])
@@ -42,6 +64,7 @@ async def get_preset_styles():
@router.post("", response_model=WritingStyleResponse, status_code=201)
async def create_writing_style(
style_data: WritingStyleCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -50,13 +73,9 @@ async def create_writing_style(
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == style_data.project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(style_data.project_id, user_id, db)
# 如果基于预设创建,获取预设内容
if style_data.preset_id:
@@ -120,6 +139,7 @@ async def create_writing_style(
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
async def get_project_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -128,13 +148,9 @@ async def get_project_styles(
返回:全局预设风格 + 该项目的自定义风格
按 order_index 排序,并标记哪个是当前项目的默认风格
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 获取该项目的默认风格ID
result = await db.execute(
@@ -222,6 +238,7 @@ async def get_writing_style(
async def update_writing_style(
style_id: int,
style_data: WritingStyleUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -241,6 +258,10 @@ async def update_writing_style(
if style.project_id is None:
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(style.project_id, user_id, db)
# 更新字段
update_data = style_data.model_dump(exclude_unset=True)
@@ -279,6 +300,7 @@ async def update_writing_style(
@router.delete("/{style_id}", status_code=204)
async def delete_writing_style(
style_id: int,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -300,6 +322,10 @@ async def delete_writing_style(
if style.project_id is None:
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(style.project_id, user_id, db)
# 检查是否有项目将其设置为默认风格
result = await db.execute(
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
@@ -321,6 +347,7 @@ async def delete_writing_style(
async def set_default_style(
style_id: int,
request_data: SetDefaultStyleRequest,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -335,13 +362,9 @@ async def set_default_style(
"""
project_id = request_data.project_id
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 验证风格是否存在
result = await db.execute(
@@ -379,6 +402,7 @@ async def set_default_style(
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
async def initialize_default_styles(
project_id: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
@@ -387,13 +411,9 @@ async def initialize_default_styles(
新架构下,预设风格是全局的,不需要为每个项目单独初始化
该接口保留用于兼容性,直接返回项目可用的所有风格
"""
# 验证项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 验证用户权限
user_id = getattr(request.state, 'user_id', None)
await verify_project_access(project_id, user_id, db)
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
return await get_project_styles(project_id, db)
return await get_project_styles(project_id, request, db)