update:1.修复一对一模式修改大纲名称没有同步更新章节名称 2.修复一对一模式全新生成大纲,没有关联删除对应章节问题 3.优化根据分析建议重新生成章节内容时引用默认写作风格 5.将写作风格调整至用户级,在一个项目中添加全局可见(需要更新数据库)
This commit is contained in:
@@ -2543,7 +2543,7 @@ async def regenerate_chapter_stream(
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
|
||||
|
||||
# 预先获取项目上下文数据
|
||||
# 预先获取项目上下文数据和写作风格
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
# 获取项目信息
|
||||
@@ -2566,6 +2566,41 @@ async def regenerate_chapter_stream(
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 获取写作风格
|
||||
style_content = ""
|
||||
style_id = regenerate_request.style_id
|
||||
|
||||
# 如果没有指定风格,尝试使用项目的默认风格
|
||||
if not style_id:
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
default_style_result = await temp_db.execute(
|
||||
select(ProjectDefaultStyle.style_id)
|
||||
.where(ProjectDefaultStyle.project_id == chapter.project_id)
|
||||
)
|
||||
default_style_id = default_style_result.scalar_one_or_none()
|
||||
if default_style_id:
|
||||
style_id = default_style_id
|
||||
logger.info(f"📝 使用项目默认写作风格: {style_id}")
|
||||
|
||||
# 获取风格内容
|
||||
if style_id:
|
||||
style_result = await temp_db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = style_result.scalar_one_or_none()
|
||||
if style:
|
||||
# 验证风格是否可用:全局预设风格(project_id为NULL)或者当前项目的自定义风格
|
||||
if style.project_id is None or style.project_id == chapter.project_id:
|
||||
style_content = style.prompt_content or ""
|
||||
style_type = "全局预设" if style.project_id is None else "项目自定义"
|
||||
logger.info(f"✅ 使用写作风格: {style.name} ({style_type})")
|
||||
else:
|
||||
logger.warning(f"⚠️ 风格 {style_id} 不属于当前项目,跳过")
|
||||
else:
|
||||
logger.warning(f"⚠️ 未找到风格 {style_id}")
|
||||
else:
|
||||
logger.info("ℹ️ 未指定写作风格,使用默认提示词")
|
||||
|
||||
# 构建项目上下文
|
||||
project_context = {
|
||||
'project_title': project.title if project else '未知',
|
||||
@@ -2635,7 +2670,8 @@ async def regenerate_chapter_stream(
|
||||
chapter=chapter,
|
||||
analysis=analysis,
|
||||
regenerate_request=regenerate_request,
|
||||
project_context=project_context
|
||||
project_context=project_context,
|
||||
style_content=style_content
|
||||
):
|
||||
# 处理不同类型的事件
|
||||
if event['type'] == 'chunk':
|
||||
|
||||
@@ -174,7 +174,7 @@ async def update_outline(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新大纲信息并同步更新structure字段"""
|
||||
"""更新大纲信息并同步更新structure字段和关联章节"""
|
||||
result = await db.execute(
|
||||
select(Outline).where(Outline.id == outline_id)
|
||||
)
|
||||
@@ -185,7 +185,7 @@ async def update_outline(
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(outline.project_id, user_id, db)
|
||||
project = await verify_project_access(outline.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = outline_update.model_dump(exclude_unset=True)
|
||||
@@ -214,6 +214,28 @@ async def update_outline(
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"大纲 {outline_id} 的structure字段格式错误,跳过更新")
|
||||
|
||||
# 🔧 传统模式(one-to-one):同步更新关联章节的标题
|
||||
if 'title' in update_data and project.outline_mode == 'one-to-one':
|
||||
try:
|
||||
# 查找对应的章节(通过chapter_number匹配order_index)
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(
|
||||
Chapter.project_id == outline.project_id,
|
||||
Chapter.chapter_number == outline.order_index
|
||||
)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if chapter:
|
||||
# 同步更新章节标题
|
||||
chapter.title = outline.title
|
||||
logger.info(f"一对一模式:同步更新章节 {chapter.id} 的标题为 '{outline.title}'")
|
||||
else:
|
||||
logger.debug(f"一对一模式:未找到对应的章节(chapter_number={outline.order_index})")
|
||||
except Exception as e:
|
||||
logger.error(f"同步更新章节标题失败: {str(e)}")
|
||||
# 不阻断大纲更新流程,仅记录错误
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(outline)
|
||||
return outline
|
||||
@@ -485,9 +507,21 @@ async def _generate_new_outline(
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 全新生成模式:必须删除旧大纲(章节不自动删除,由用户手动管理)
|
||||
# 注意:这是"new"模式的核心逻辑,应该始终删除旧数据
|
||||
logger.info(f"删除项目 {project.id} 的旧大纲")
|
||||
# 全新生成模式:删除旧大纲和关联的所有章节
|
||||
logger.info(f"全新生成:删除项目 {project.id} 的旧大纲和章节")
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
|
||||
# 无论是一对一还是一对多模式,都删除所有项目的章节
|
||||
# 一对一模式:通过 chapter_number 关联
|
||||
# 一对多模式:通过 outline_id 关联
|
||||
delete_result = await db.execute(
|
||||
sql_delete(Chapter).where(Chapter.project_id == project.id)
|
||||
)
|
||||
deleted_chapters_count = delete_result.rowcount
|
||||
logger.info(f"全新生成:删除了 {deleted_chapters_count} 个旧章节")
|
||||
|
||||
# 删除旧大纲
|
||||
await db.execute(
|
||||
delete(Outline).where(Outline.project_id == project.id)
|
||||
)
|
||||
|
||||
@@ -22,24 +22,12 @@ router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
def get_current_user_id(request: Request) -> str:
|
||||
"""获取当前登录用户ID"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
return user_id
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
@@ -68,14 +56,13 @@ async def create_writing_style(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建新的写作风格
|
||||
创建新的写作风格(用户级别)
|
||||
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style_data.project_id, user_id, db)
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
if style_data.preset_id:
|
||||
@@ -98,16 +85,16 @@ async def create_writing_style(
|
||||
detail="name 和 prompt_content 是必填字段"
|
||||
)
|
||||
|
||||
# 获取当前最大 order_index
|
||||
# 获取当前用户的最大 order_index
|
||||
count_result = await db.execute(
|
||||
select(func.count(WritingStyle.id))
|
||||
.where(WritingStyle.project_id == style_data.project_id)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
)
|
||||
max_order = count_result.scalar_one()
|
||||
|
||||
# 创建风格记录
|
||||
new_style = WritingStyle(
|
||||
project_id=style_data.project_id,
|
||||
user_id=user_id,
|
||||
name=style_data.name,
|
||||
style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"),
|
||||
preset_id=style_data.preset_id,
|
||||
@@ -123,7 +110,7 @@ async def create_writing_style(
|
||||
# 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格)
|
||||
return {
|
||||
"id": new_style.id,
|
||||
"project_id": new_style.project_id,
|
||||
"user_id": new_style.user_id,
|
||||
"name": new_style.name,
|
||||
"style_type": new_style.style_type,
|
||||
"preset_id": new_style.preset_id,
|
||||
@@ -136,6 +123,60 @@ async def create_writing_style(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user", response_model=WritingStyleListResponse)
|
||||
async def get_user_styles(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用户的所有可用写作风格
|
||||
|
||||
返回:全局预设风格 + 该用户的自定义风格
|
||||
按 order_index 排序
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = list(result.scalars().all())
|
||||
|
||||
# 获取用户自定义风格
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
custom_styles = list(result.scalars().all())
|
||||
|
||||
# 合并:预设风格 + 自定义风格
|
||||
all_styles = preset_styles + custom_styles
|
||||
|
||||
# 转换为响应格式
|
||||
styles_with_default = []
|
||||
for style in all_styles:
|
||||
style_dict = {
|
||||
"id": style.id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"order_index": style.order_index,
|
||||
"created_at": style.created_at,
|
||||
"updated_at": style.updated_at,
|
||||
"is_default": False # 用户级别不再需要默认风格标记
|
||||
}
|
||||
styles_with_default.append(style_dict)
|
||||
|
||||
return {"styles": styles_with_default, "total": len(styles_with_default)}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
@@ -143,14 +184,24 @@ async def get_project_styles(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目的所有可用写作风格
|
||||
获取项目可用的所有写作风格(保留用于向后兼容)
|
||||
|
||||
返回:全局预设风格 + 该项目的自定义风格
|
||||
返回:全局预设风格 + 该用户的自定义风格
|
||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 验证项目访问权限
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
# 获取该项目的默认风格ID
|
||||
result = await db.execute(
|
||||
@@ -159,18 +210,18 @@ async def get_project_styles(
|
||||
)
|
||||
default_style_id = result.scalar_one_or_none()
|
||||
|
||||
# 获取全局预设风格(project_id 为 NULL)
|
||||
# 获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id.is_(None))
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = list(result.scalars().all())
|
||||
|
||||
# 获取项目自定义风格
|
||||
# 获取用户自定义风格
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.project_id == project_id)
|
||||
.where(WritingStyle.user_id == user_id)
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
custom_styles = list(result.scalars().all())
|
||||
@@ -183,7 +234,7 @@ async def get_project_styles(
|
||||
for style in all_styles:
|
||||
style_dict = {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -221,7 +272,7 @@ async def get_writing_style(
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -247,6 +298,9 @@ async def update_writing_style(
|
||||
- 只能修改自定义风格
|
||||
- 不能修改全局预设风格
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
@@ -255,12 +309,12 @@ async def update_writing_style(
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许修改)
|
||||
if style.project_id is None:
|
||||
if style.user_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
# 验证用户权限(只能修改自己的风格)
|
||||
if style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权修改其他用户的风格")
|
||||
|
||||
# 更新字段
|
||||
update_data = style_data.model_dump(exclude_unset=True)
|
||||
@@ -284,7 +338,7 @@ async def update_writing_style(
|
||||
# 返回包含 is_default 字段的字典
|
||||
return {
|
||||
"id": style.id,
|
||||
"project_id": style.project_id,
|
||||
"user_id": style.user_id,
|
||||
"name": style.name,
|
||||
"style_type": style.style_type,
|
||||
"preset_id": style.preset_id,
|
||||
@@ -311,6 +365,9 @@ async def delete_writing_style(
|
||||
- 不能删除默认风格(必须先设置其他风格为默认)
|
||||
- 删除后无法恢复
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
@@ -319,12 +376,12 @@ async def delete_writing_style(
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否为全局预设风格(不允许删除)
|
||||
if style.project_id is None:
|
||||
if style.user_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
# 验证用户权限(只能删除自己的风格)
|
||||
if style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权删除其他用户的风格")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
@@ -362,9 +419,19 @@ async def set_default_style(
|
||||
"""
|
||||
project_id = request_data.project_id
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 验证项目访问权限
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
# 验证风格是否存在
|
||||
result = await db.execute(
|
||||
@@ -374,9 +441,9 @@ async def set_default_style(
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 验证风格是否属于该项目(自定义风格)或是全局预设风格
|
||||
if style.project_id is not None and style.project_id != project_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作其他项目的风格")
|
||||
# 验证风格是否属于该用户(自定义风格)或是全局预设风格
|
||||
if style.user_id is not None and style.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作其他用户的风格")
|
||||
|
||||
# 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的
|
||||
await db.execute(
|
||||
@@ -411,9 +478,5 @@ async def initialize_default_styles(
|
||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
Reference in New Issue
Block a user