diff --git a/.gitignore b/.gitignore index d36db68..aa7f9ff 100644 --- a/.gitignore +++ b/.gitignore @@ -102,4 +102,7 @@ dmypy.json # Jupyter Notebook .ipynb_checkpoints -data/ \ No newline at end of file +data/ +docs/ +data_old/ +backend/migrate_all_databases.py \ No newline at end of file diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index c625ed3..fda9bf1 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -1,10 +1,11 @@ """章节管理API""" -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Query from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, func import json import asyncio +from typing import Optional from app.database import get_db from app.models.chapter import Chapter @@ -12,11 +13,13 @@ from app.models.project import Project from app.models.outline import Outline from app.models.character import Character from app.models.generation_history import GenerationHistory +from app.models.writing_style import WritingStyle from app.schemas.chapter import ( ChapterCreate, ChapterUpdate, ChapterResponse, - ChapterListResponse + ChapterListResponse, + ChapterGenerateRequest ) from app.services.ai_service import AIService from app.services.prompt_service import prompt_service @@ -245,183 +248,24 @@ async def check_can_generate( } -@router.post("/{chapter_id}/generate", summary="AI创作章节内容") -async def generate_chapter_content( - chapter_id: str, - db: AsyncSession = Depends(get_db), - user_ai_service: AIService = Depends(get_user_ai_service) -): - """ - 根据大纲、前置章节内容和项目信息AI创作章节完整内容 - 要求:必须按顺序生成,确保前置章节都已完成 - """ - # 获取章节 - result = await db.execute( - select(Chapter).where(Chapter.id == chapter_id) - ) - chapter = result.scalar_one_or_none() - if not chapter: - raise HTTPException(status_code=404, detail="章节不存在") - - # 检查前置条件 - can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter) - if not can_generate: - raise HTTPException(status_code=400, detail=error_msg) - - try: - # 获取项目信息 - project_result = await db.execute( - select(Project).where(Project.id == chapter.project_id) - ) - project = project_result.scalar_one_or_none() - if not project: - raise HTTPException(status_code=404, detail="项目不存在") - - # 获取对应的大纲(使用新的查询确保获取最新数据) - outline_result = await db.execute( - select(Outline) - .where(Outline.project_id == chapter.project_id) - .where(Outline.order_index == chapter.chapter_number) - .execution_options(populate_existing=True) - ) - outline = outline_result.scalar_one_or_none() - - # 获取所有大纲用于上下文(使用新的查询确保获取最新数据) - all_outlines_result = await db.execute( - select(Outline) - .where(Outline.project_id == chapter.project_id) - .order_by(Outline.order_index) - .execution_options(populate_existing=True) - ) - all_outlines = all_outlines_result.scalars().all() - outlines_context = "\n".join([ - f"第{o.order_index}章 {o.title}: {o.content[:100]}..." - for o in all_outlines - ]) - - # 获取角色信息 - characters_result = await db.execute( - select(Character).where(Character.project_id == chapter.project_id) - ) - characters = characters_result.scalars().all() - characters_info = "\n".join([ - f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}" - for c in characters - ]) - - # 构建前置章节内容上下文(如果有前置章节) - previous_content = "" - if previous_chapters: - # Token控制:保留最近3章的完整内容,早期章节使用摘要 - recent_chapters = previous_chapters[-3:] if len(previous_chapters) > 3 else previous_chapters - early_chapters = previous_chapters[:-3] if len(previous_chapters) > 3 else [] - - # 早期章节摘要 - if early_chapters: - early_summary = "【前期剧情概要】\n" + "\n".join([ - f"第{ch.chapter_number}章《{ch.title}》:{ch.content[:200] if ch.content else ''}..." - for ch in early_chapters - ]) - previous_content += early_summary + "\n\n" - - # 最近章节完整内容 - if recent_chapters: - recent_content = "【最近章节完整内容】\n" + "\n\n".join([ - f"=== 第{ch.chapter_number}章:{ch.title} ===\n{ch.content}" - for ch in recent_chapters - ]) - previous_content += recent_content - - logger.info(f"构建前置上下文:{len(early_chapters)}章摘要 + {len(recent_chapters)}章完整内容") - - # 根据是否有前置内容选择不同的提示词 - if previous_content: - # 使用带上下文的提示词 - prompt = prompt_service.get_chapter_generation_with_context_prompt( - title=project.title, - theme=project.theme or '', - genre=project.genre or '', - narrative_perspective=project.narrative_perspective or '第三人称', - time_period=project.world_time_period or '未设定', - location=project.world_location or '未设定', - atmosphere=project.world_atmosphere or '未设定', - rules=project.world_rules or '未设定', - characters_info=characters_info or '暂无角色信息', - outlines_context=outlines_context, - previous_content=previous_content, - chapter_number=chapter.chapter_number, - chapter_title=chapter.title, - chapter_outline=outline.content if outline else chapter.summary or '暂无大纲' - ) - else: - # 第一章,使用原有提示词 - prompt = prompt_service.get_chapter_generation_prompt( - title=project.title, - theme=project.theme or '', - genre=project.genre or '', - narrative_perspective=project.narrative_perspective or '第三人称', - time_period=project.world_time_period or '未设定', - location=project.world_location or '未设定', - atmosphere=project.world_atmosphere or '未设定', - rules=project.world_rules or '未设定', - characters_info=characters_info or '暂无角色信息', - outlines_context=outlines_context, - chapter_number=chapter.chapter_number, - chapter_title=chapter.title, - chapter_outline=outline.content if outline else chapter.summary or '暂无大纲' - ) - - logger.info(f"开始AI创作章节 {chapter_id}") - - # 调用AI生成 - ai_content = await user_ai_service.generate_text( - prompt=prompt - ) - - # 更新章节内容 - old_word_count = chapter.word_count or 0 - chapter.content = ai_content - new_word_count = len(ai_content) - chapter.word_count = new_word_count - chapter.status = "completed" - - # 更新项目字数 - project.current_words = project.current_words - old_word_count + new_word_count - - # 记录生成历史 - history = GenerationHistory( - project_id=chapter.project_id, - chapter_id=chapter.id, - prompt=f"创作章节: 第{chapter.chapter_number}章 {chapter.title}", - generated_content=ai_content[:500] if len(ai_content) > 500 else ai_content, - model="default" - ) - db.add(history) - - await db.commit() - await db.refresh(chapter) - - logger.info(f"成功创作章节 {chapter_id},共 {new_word_count} 字") - - return {"content": ai_content} - - except Exception as e: - logger.error(f"创作章节失败: {str(e)}") - raise HTTPException(status_code=500, detail=f"创作章节失败: {str(e)}") - @router.post("/{chapter_id}/generate-stream", summary="AI创作章节内容(流式)") async def generate_chapter_content_stream( chapter_id: str, request: Request, + generate_request: ChapterGenerateRequest = ChapterGenerateRequest(), user_ai_service: AIService = Depends(get_user_ai_service) ): """ 根据大纲、前置章节内容和项目信息AI创作章节完整内容(流式返回) 要求:必须按顺序生成,确保前置章节都已完成 + 请求体参数: + - style_id: 可选,指定使用的写作风格ID。不提供则不使用任何风格 + 注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话 以避免流式响应期间的连接泄漏问题 """ + style_id = generate_request.style_id # 预先验证章节存在性(使用临时会话) async for temp_db in get_db(request): try: @@ -508,6 +352,27 @@ async def generate_chapter_content_stream( for c in characters ]) + # 获取写作风格 + style_content = "" + if style_id: + # 使用指定的风格 + style_result = await db_session.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = style_result.scalar_one_or_none() + if style: + # 验证风格是否可用:全局预设风格(project_id为NULL)或者当前项目的自定义风格 + if style.project_id is None or style.project_id == current_chapter.project_id: + style_content = style.prompt_content or "" + style_type = "全局预设" if style.project_id is None else "项目自定义" + logger.info(f"使用指定风格: {style.name} ({style_type})") + else: + logger.warning(f"风格 {style_id} 不属于当前项目,无法使用") + else: + logger.warning(f"未找到风格 {style_id}") + else: + logger.info("未指定写作风格,使用原始提示词") + # 构建前置章节内容上下文(使用之前保存的数据) previous_content = "" if previous_chapters_data: @@ -533,7 +398,7 @@ async def generate_chapter_content_stream( # 发送开始事件 yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n" - # 根据是否有前置内容选择不同的提示词 + # 根据是否有前置内容选择不同的提示词,并应用写作风格 if previous_content: prompt = prompt_service.get_chapter_generation_with_context_prompt( title=project.title, @@ -549,7 +414,8 @@ async def generate_chapter_content_stream( previous_content=previous_content, chapter_number=current_chapter.chapter_number, chapter_title=current_chapter.title, - chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲' + chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲', + style_content=style_content ) else: prompt = prompt_service.get_chapter_generation_prompt( @@ -565,7 +431,8 @@ async def generate_chapter_content_stream( outlines_context=outlines_context, chapter_number=current_chapter.chapter_number, chapter_title=current_chapter.title, - chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲' + chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲', + style_content=style_content ) logger.info(f"开始AI流式创作章节 {chapter_id}") diff --git a/backend/app/api/projects.py b/backend/app/api/projects.py index 4b747a8..f824c19 100644 --- a/backend/app/api/projects.py +++ b/backend/app/api/projects.py @@ -40,6 +40,7 @@ async def create_project( await db.commit() await db.refresh(db_project) logger.info(f"项目创建成功: {db_project.id}") + return db_project except Exception as e: logger.error(f"创建项目失败: {str(e)}", exc_info=True) diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 0d304aa..c7400db 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -12,6 +12,8 @@ from app.models.character import Character from app.models.outline import Outline from app.models.chapter import Chapter from app.models.relationship import CharacterRelationship, Organization, OrganizationMember, RelationshipType +from app.models.writing_style import WritingStyle +from app.models.project_default_style import ProjectDefaultStyle from app.services.ai_service import AIService from app.services.prompt_service import prompt_service from app.logger import get_logger @@ -132,9 +134,33 @@ async def world_building_generator( ) db.add(project) await db.commit() - db_committed = True await db.refresh(project) + # 自动设置默认写作风格为第一个全局预设风格 + try: + result = await db.execute( + select(WritingStyle).where( + WritingStyle.project_id.is_(None), + WritingStyle.order_index == 1 + ).limit(1) + ) + first_style = result.scalar_one_or_none() + + if first_style: + default_style = ProjectDefaultStyle( + project_id=project.id, + style_id=first_style.id + ) + db.add(default_style) + await db.commit() + logger.info(f"为项目 {project.id} 自动设置默认风格: {first_style.name}") + else: + logger.warning(f"未找到order_index=1的全局预设风格,项目 {project.id} 未设置默认风格") + except Exception as e: + logger.warning(f"设置默认写作风格失败: {e},不影响项目创建") + + db_committed = True + # 发送最终结果 yield await SSEResponse.send_result({ "project_id": project.id, diff --git a/backend/app/api/writing_styles.py b/backend/app/api/writing_styles.py new file mode 100644 index 0000000..af503bd --- /dev/null +++ b/backend/app/api/writing_styles.py @@ -0,0 +1,399 @@ +"""写作风格管理 API""" +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, func, delete +from typing import List + +from ..database import get_db +from ..models.writing_style import WritingStyle +from ..models.project import Project +from ..models.project_default_style import ProjectDefaultStyle +from ..schemas.writing_style import ( + WritingStyleCreate, + WritingStyleUpdate, + WritingStyleResponse, + WritingStyleListResponse, + SetDefaultStyleRequest +) +from ..services.prompt_service import WritingStyleManager + +router = APIRouter(prefix="/writing-styles", tags=["writing-styles"]) + + +@router.get("/presets/list", response_model=List[dict]) +async def get_preset_styles(): + """ + 获取所有预设风格列表 + + 返回格式:数组形式的预设风格列表 + [ + {"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."}, + {"id": "classical", "name": "古典优雅", ...} + ] + """ + presets = WritingStyleManager.get_all_presets() + # 将字典转换为数组,添加 id 字段 + return [ + {"id": preset_id, **preset_data} + for preset_id, preset_data in presets.items() + ] + + +@router.post("", response_model=WritingStyleResponse, status_code=201) +async def create_writing_style( + style_data: WritingStyleCreate, + db: AsyncSession = Depends(get_db) +): + """ + 创建新的写作风格 + + - **基于预设创建**:提供 preset_id,系统会自动填充预设内容 + - **完全自定义**:不提供 preset_id,需要手动填写所有字段 + """ + # 验证项目是否存在 + result = await db.execute( + select(Project).where(Project.id == style_data.project_id) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + # 如果基于预设创建,获取预设内容 + if style_data.preset_id: + preset = WritingStyleManager.get_preset_style(style_data.preset_id) + if not preset: + raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在") + + # 使用预设内容填充(如果用户未提供) + if not style_data.name: + style_data.name = preset["name"] + if not style_data.description: + style_data.description = preset["description"] + if not style_data.prompt_content: + style_data.prompt_content = preset["prompt_content"] + + # 验证必填字段 + if not style_data.name or not style_data.prompt_content: + raise HTTPException( + status_code=400, + detail="name 和 prompt_content 是必填字段" + ) + + # 获取当前最大 order_index + count_result = await db.execute( + select(func.count(WritingStyle.id)) + .where(WritingStyle.project_id == style_data.project_id) + ) + max_order = count_result.scalar_one() + + # 创建风格记录 + new_style = WritingStyle( + project_id=style_data.project_id, + name=style_data.name, + style_type=style_data.style_type or ("preset" if style_data.preset_id else "custom"), + preset_id=style_data.preset_id, + description=style_data.description, + prompt_content=style_data.prompt_content, + order_index=max_order + 1 + ) + + db.add(new_style) + await db.commit() + await db.refresh(new_style) + + # 返回包含 is_default 字段的字典(新创建的风格默认不是默认风格) + return { + "id": new_style.id, + "project_id": new_style.project_id, + "name": new_style.name, + "style_type": new_style.style_type, + "preset_id": new_style.preset_id, + "description": new_style.description, + "prompt_content": new_style.prompt_content, + "order_index": new_style.order_index, + "created_at": new_style.created_at, + "updated_at": new_style.updated_at, + "is_default": False + } + + +@router.get("/project/{project_id}", response_model=WritingStyleListResponse) +async def get_project_styles( + project_id: str, + db: AsyncSession = Depends(get_db) +): + """ + 获取项目的所有可用写作风格 + + 返回:全局预设风格 + 该项目的自定义风格 + 按 order_index 排序,并标记哪个是当前项目的默认风格 + """ + # 验证项目是否存在 + result = await db.execute( + select(Project).where(Project.id == project_id) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + # 获取该项目的默认风格ID + result = await db.execute( + select(ProjectDefaultStyle.style_id) + .where(ProjectDefaultStyle.project_id == project_id) + ) + default_style_id = result.scalar_one_or_none() + + # 获取全局预设风格(project_id 为 NULL) + result = await db.execute( + select(WritingStyle) + .where(WritingStyle.project_id.is_(None)) + .order_by(WritingStyle.order_index) + ) + preset_styles = list(result.scalars().all()) + + # 获取项目自定义风格 + result = await db.execute( + select(WritingStyle) + .where(WritingStyle.project_id == project_id) + .order_by(WritingStyle.order_index) + ) + custom_styles = list(result.scalars().all()) + + # 合并:预设风格 + 自定义风格 + all_styles = preset_styles + custom_styles + + # 为每个风格添加 is_default 标记(用于前端显示) + styles_with_default = [] + for style in all_styles: + style_dict = { + "id": style.id, + "project_id": style.project_id, + "name": style.name, + "style_type": style.style_type, + "preset_id": style.preset_id, + "description": style.description, + "prompt_content": style.prompt_content, + "order_index": style.order_index, + "created_at": style.created_at, + "updated_at": style.updated_at, + "is_default": style.id == default_style_id + } + styles_with_default.append(style_dict) + + return {"styles": styles_with_default, "total": len(styles_with_default)} + + +@router.get("/{style_id}", response_model=WritingStyleResponse) +async def get_writing_style( + style_id: int, + db: AsyncSession = Depends(get_db) +): + """获取单个写作风格详情""" + result = await db.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = result.scalar_one_or_none() + if not style: + raise HTTPException(status_code=404, detail="写作风格不存在") + + # 检查是否有项目将其设置为默认风格 + result = await db.execute( + select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id) + ) + is_default = result.scalar_one_or_none() is not None + + # 返回包含 is_default 字段的字典 + return { + "id": style.id, + "project_id": style.project_id, + "name": style.name, + "style_type": style.style_type, + "preset_id": style.preset_id, + "description": style.description, + "prompt_content": style.prompt_content, + "order_index": style.order_index, + "created_at": style.created_at, + "updated_at": style.updated_at, + "is_default": is_default + } + + +@router.put("/{style_id}", response_model=WritingStyleResponse) +async def update_writing_style( + style_id: int, + style_data: WritingStyleUpdate, + db: AsyncSession = Depends(get_db) +): + """ + 更新写作风格 + + - 只能修改自定义风格 + - 不能修改全局预设风格 + """ + result = await db.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = result.scalar_one_or_none() + if not style: + raise HTTPException(status_code=404, detail="写作风格不存在") + + # 检查是否为全局预设风格(不允许修改) + if style.project_id is None: + raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格") + + # 更新字段 + update_data = style_data.model_dump(exclude_unset=True) + + # 如果修改了内容,将 style_type 改为 custom + if any(key in update_data for key in ["name", "description", "prompt_content"]): + update_data["style_type"] = "custom" + + for key, value in update_data.items(): + setattr(style, key, value) + + await db.commit() + await db.refresh(style) + + # 检查是否有项目将其设置为默认风格 + result = await db.execute( + select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id) + ) + is_default = result.scalar_one_or_none() is not None + + # 返回包含 is_default 字段的字典 + return { + "id": style.id, + "project_id": style.project_id, + "name": style.name, + "style_type": style.style_type, + "preset_id": style.preset_id, + "description": style.description, + "prompt_content": style.prompt_content, + "order_index": style.order_index, + "created_at": style.created_at, + "updated_at": style.updated_at, + "is_default": is_default + } + + +@router.delete("/{style_id}", status_code=204) +async def delete_writing_style( + style_id: int, + db: AsyncSession = Depends(get_db) +): + """ + 删除写作风格 + + 注意: + - 只能删除自定义风格,不能删除全局预设风格 + - 不能删除默认风格(必须先设置其他风格为默认) + - 删除后无法恢复 + """ + result = await db.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = result.scalar_one_or_none() + if not style: + raise HTTPException(status_code=404, detail="写作风格不存在") + + # 检查是否为全局预设风格(不允许删除) + if style.project_id is None: + raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格") + + # 检查是否有项目将其设置为默认风格 + result = await db.execute( + select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id) + ) + default_relation = result.scalar_one_or_none() + if default_relation: + raise HTTPException( + status_code=400, + detail="不能删除默认风格,请先设置其他风格为默认" + ) + + await db.delete(style) + await db.commit() + + return None + + +@router.post("/{style_id}/set-default", response_model=dict) +async def set_default_style( + style_id: int, + request_data: SetDefaultStyleRequest, + db: AsyncSession = Depends(get_db) +): + """ + 将指定风格设置为项目的默认风格 + + 使用 project_default_styles 表记录项目的默认风格选择 + 每个项目只能有一个默认风格(通过 UniqueConstraint 保证) + + 参数: + - style_id: 要设置为默认的风格ID(路径参数) + - project_id: 项目ID(请求体),用于确定在哪个项目上下文中设置默认 + """ + project_id = request_data.project_id + + # 验证项目是否存在 + result = await db.execute( + select(Project).where(Project.id == project_id) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + # 验证风格是否存在 + result = await db.execute( + select(WritingStyle).where(WritingStyle.id == style_id) + ) + style = result.scalar_one_or_none() + if not style: + raise HTTPException(status_code=404, detail="写作风格不存在") + + # 验证风格是否属于该项目(自定义风格)或是全局预设风格 + if style.project_id is not None and style.project_id != project_id: + raise HTTPException(status_code=403, detail="无权操作其他项目的风格") + + # 使用 UPSERT 逻辑:先删除该项目的旧默认风格记录,再插入新的 + await db.execute( + delete(ProjectDefaultStyle).where(ProjectDefaultStyle.project_id == project_id) + ) + + # 插入新的默认风格记录 + new_default = ProjectDefaultStyle( + project_id=project_id, + style_id=style_id + ) + db.add(new_default) + await db.commit() + + return { + "message": "默认风格设置成功", + "project_id": project_id, + "style_id": style_id, + "style_name": style.name + } + + +@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse) +async def initialize_default_styles( + project_id: str, + db: AsyncSession = Depends(get_db) +): + """ + 【已废弃】为项目初始化默认风格 + + 新架构下,预设风格是全局的,不需要为每个项目单独初始化 + 该接口保留用于兼容性,直接返回项目可用的所有风格 + """ + # 验证项目是否存在 + result = await db.execute( + select(Project).where(Project.id == project_id) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + # 直接返回项目可用的所有风格(全局预设 + 项目自定义) + return await get_project_styles(project_id, db) \ No newline at end of file diff --git a/backend/app/database.py b/backend/app/database.py index 03bd341..be098fa 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -226,6 +226,62 @@ async def _init_relationship_types(user_id: str): +async def _init_global_writing_styles(user_id: str): + """为指定用户初始化全局预设写作风格 + + 全局预设风格的 project_id 为 NULL,所有用户共享 + 只在第一次创建数据库时插入一次 + + Args: + user_id: 用户ID + """ + from app.models.writing_style import WritingStyle + from app.services.prompt_service import WritingStyleManager + + try: + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async with AsyncSessionLocal() as session: + # 检查是否已存在全局预设风格 + result = await session.execute( + select(WritingStyle).where(WritingStyle.project_id.is_(None)) + ) + existing = result.scalars().first() + + if existing: + logger.info(f"用户 {user_id} 的全局预设风格已存在,跳过初始化") + return + + logger.info(f"开始为用户 {user_id} 插入全局预设写作风格...") + + # 获取所有预设风格配置 + presets = WritingStyleManager.get_all_presets() + + for index, (preset_id, preset_data) in enumerate(presets.items(), start=1): + style = WritingStyle( + project_id=None, # NULL 表示全局预设 + name=preset_data["name"], + style_type="preset", + preset_id=preset_id, + description=preset_data["description"], + prompt_content=preset_data["prompt_content"], + order_index=index + ) + session.add(style) + + await session.commit() + logger.info(f"成功为用户 {user_id} 插入 {len(presets)} 个全局预设写作风格") + + except Exception as e: + logger.error(f"用户 {user_id} 初始化全局预设写作风格失败: {str(e)}", exc_info=True) + raise + + async def init_db(user_id: str): """初始化指定用户的数据库,创建所有表并插入预置数据 @@ -240,6 +296,7 @@ async def init_db(user_id: str): await conn.run_sync(Base.metadata.create_all) await _init_relationship_types(user_id) + await _init_global_writing_styles(user_id) logger.info(f"用户 {user_id} 的数据库初始化成功") except Exception as e: diff --git a/backend/app/main.py b/backend/app/main.py index fe65bad..756a23c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,18 +7,18 @@ from fastapi.exceptions import RequestValidationError from contextlib import asynccontextmanager from pathlib import Path -from app.config import settings +from app.config import settings as config_settings from app.database import close_db, _session_stats from app.logger import setup_logging, get_logger from app.middleware import RequestIDMiddleware from app.middleware.auth_middleware import AuthMiddleware setup_logging( - level=settings.log_level, - log_to_file=settings.log_to_file, - log_file_path=settings.log_file_path, - max_bytes=settings.log_max_bytes, - backup_count=settings.log_backup_count + level=config_settings.log_level, + log_to_file=config_settings.log_to_file, + log_file_path=config_settings.log_file_path, + max_bytes=config_settings.log_max_bytes, + backup_count=config_settings.log_backup_count ) logger = get_logger(__name__) @@ -34,8 +34,8 @@ async def lifespan(app: FastAPI): app = FastAPI( - title=settings.app_name, - version=settings.app_version, + title=config_settings.app_name, + version=config_settings.app_version, description="AI写小说工具 - 智能小说创作助手", lifespan=lifespan ) @@ -60,14 +60,14 @@ async def global_exception_handler(request: Request, exc: Exception): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={ "detail": "服务器内部错误", - "message": str(exc) if settings.debug else "请稍后重试" + "message": str(exc) if config_settings.debug else "请稍后重试" } ) app.add_middleware(RequestIDMiddleware) app.add_middleware(AuthMiddleware) -if settings.debug: +if config_settings.debug: app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -78,7 +78,7 @@ if settings.debug: else: app.add_middleware( CORSMiddleware, - allow_origins=settings.cors_origins, + allow_origins=config_settings.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -114,7 +114,7 @@ async def db_session_stats(): from app.api import ( projects, outlines, characters, chapters, wizard_stream, relationships, organizations, - auth, users, settings + auth, users, settings, writing_styles ) app.include_router(auth.router, prefix="/api") @@ -128,6 +128,7 @@ app.include_router(characters.router, prefix="/api") app.include_router(chapters.router, prefix="/api") app.include_router(relationships.router, prefix="/api") app.include_router(organizations.router, prefix="/api") +app.include_router(writing_styles.router, prefix="/api") static_dir = Path(__file__).parent.parent / "static" if static_dir.exists(): @@ -161,7 +162,7 @@ else: async def root(): return { "message": "欢迎使用AI Story Creator", - "version": settings.app_version, + "version": config_settings.app_version, "docs": "/docs", "notice": "请先构建前端: cd frontend && npm run build" } @@ -171,7 +172,7 @@ if __name__ == "__main__": import uvicorn uvicorn.run( "app.main:app", - host=settings.app_host, - port=settings.app_port, - reload=settings.debug + host=config_settings.app_host, + port=config_settings.app_port, + reload=config_settings.debug ) \ No newline at end of file diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 545c7a3..38ca856 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -5,6 +5,8 @@ from app.models.character import Character from app.models.chapter import Chapter from app.models.generation_history import GenerationHistory from app.models.settings import Settings +from app.models.writing_style import WritingStyle +from app.models.project_default_style import ProjectDefaultStyle from app.models.relationship import ( RelationshipType, CharacterRelationship, @@ -19,6 +21,8 @@ __all__ = [ "Chapter", "GenerationHistory", "Settings", + "WritingStyle", + "ProjectDefaultStyle", "RelationshipType", "CharacterRelationship", "Organization", diff --git a/backend/app/models/project_default_style.py b/backend/app/models/project_default_style.py new file mode 100644 index 0000000..21f17f6 --- /dev/null +++ b/backend/app/models/project_default_style.py @@ -0,0 +1,23 @@ +"""项目默认风格关联表""" +from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, UniqueConstraint +from sqlalchemy.sql import func +from app.database import Base + + +class ProjectDefaultStyle(Base): + """项目默认风格关联表 - 记录每个项目选择的默认风格""" + __tablename__ = "project_default_styles" + + id = Column(Integer, primary_key=True, autoincrement=True) + project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, comment="项目ID") + style_id = Column(Integer, ForeignKey("writing_styles.id", ondelete="CASCADE"), nullable=False, comment="风格ID") + created_at = Column(DateTime, server_default=func.now(), comment="创建时间") + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") + + # 确保每个项目只有一个默认风格 + __table_args__ = ( + UniqueConstraint('project_id', name='uix_project_default_style'), + ) + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/backend/app/models/writing_style.py b/backend/app/models/writing_style.py new file mode 100644 index 0000000..8aa109f --- /dev/null +++ b/backend/app/models/writing_style.py @@ -0,0 +1,23 @@ +"""写作风格数据模型""" +from sqlalchemy import Column, String, Text, Boolean, DateTime, ForeignKey, Integer +from sqlalchemy.sql import func +from app.database import Base + + +class WritingStyle(Base): + """写作风格表""" + __tablename__ = "writing_styles" + + id = Column(Integer, primary_key=True, autoincrement=True) + project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True, comment="所属项目ID(NULL表示全局预设风格)") + name = Column(String(100), nullable=False, comment="风格名称") + style_type = Column(String(50), nullable=False, comment="风格类型:preset/custom") + preset_id = Column(String(50), comment="预设风格ID:natural/classical/modern等") + description = Column(Text, comment="风格描述") + prompt_content = Column(Text, nullable=False, comment="风格提示词内容") + order_index = Column(Integer, default=0, comment="排序序号") + created_at = Column(DateTime, server_default=func.now(), comment="创建时间") + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/backend/app/schemas/chapter.py b/backend/app/schemas/chapter.py index 064ac1d..46f88e0 100644 --- a/backend/app/schemas/chapter.py +++ b/backend/app/schemas/chapter.py @@ -54,4 +54,9 @@ class ChapterResponse(BaseModel): class ChapterListResponse(BaseModel): """章节列表响应模型""" total: int - items: list[ChapterResponse] \ No newline at end of file + items: list[ChapterResponse] + + +class ChapterGenerateRequest(BaseModel): + """AI生成章节内容的请求模型""" + style_id: Optional[int] = Field(None, description="写作风格ID,不提供则不使用任何风格") \ No newline at end of file diff --git a/backend/app/schemas/writing_style.py b/backend/app/schemas/writing_style.py new file mode 100644 index 0000000..81ace45 --- /dev/null +++ b/backend/app/schemas/writing_style.py @@ -0,0 +1,54 @@ +"""写作风格 Schema""" +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + + +class WritingStyleBase(BaseModel): + """写作风格基础模型""" + name: str = Field(..., description="风格名称") + style_type: str = Field(..., description="风格类型:preset/custom") + preset_id: Optional[str] = Field(None, description="预设风格ID") + description: Optional[str] = Field(None, description="风格描述") + prompt_content: str = Field(..., description="风格提示词内容") + + +class WritingStyleCreate(WritingStyleBase): + """创建写作风格(仅用于创建项目自定义风格)""" + project_id: str = Field(..., description="所属项目ID") + + +class WritingStyleUpdate(BaseModel): + """更新写作风格""" + name: Optional[str] = None + description: Optional[str] = None + prompt_content: Optional[str] = None + + +class SetDefaultStyleRequest(BaseModel): + """设置默认风格请求""" + project_id: str = Field(..., description="项目ID") + + +class WritingStyleResponse(BaseModel): + """写作风格响应""" + id: int + project_id: Optional[str] = None # NULL 表示全局预设风格 + name: str + style_type: str + preset_id: Optional[str] = None + description: Optional[str] = None + prompt_content: str + is_default: bool + order_index: int + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class WritingStyleListResponse(BaseModel): + """写作风格列表响应""" + total: int + styles: list[WritingStyleResponse] \ No newline at end of file diff --git a/backend/app/services/prompt_service.py b/backend/app/services/prompt_service.py index 3d44af3..7bd391c 100644 --- a/backend/app/services/prompt_service.py +++ b/backend/app/services/prompt_service.py @@ -1,8 +1,113 @@ """提示词管理服务""" -from typing import Dict, Any +from typing import Dict, Any, Optional import json +class WritingStyleManager: + """写作风格管理器""" + + # 预设风格配置 + PRESET_STYLES = { + "natural": { + "name": "自然流畅", + "description": "像普通人讲故事一样自然,不刻意修饰,有生活气息", + "prompt_content": """ +**自然流畅风格要求:** +- 用简单朴实的语言叙述,避免华丽辞藻 +- 像在和朋友聊天一样讲故事 +- 保持轻松自然的节奏,不要刻意营造氛围 +- 多用短句,少用长句和排比 +- 让读者感觉舒服,不要让人觉得在"看文学作品" +""" + }, + "classical": { + "name": "古典优雅", + "description": "典雅精致的文学风格,注重意境和韵味", + "prompt_content": """ +**古典优雅风格要求:** +- 使用优美典雅的语言,注重文字的韵律感 +- 善用比喻、拟人等修辞手法 +- 注重意境营造,追求诗意美感 +- 可适当引用古诗词或典故(需符合世界观) +- 保持端庄雅致的叙述节奏 +""" + }, + "modern": { + "name": "现代简约", + "description": "简洁明快的现代风格,注重效率和直接表达", + "prompt_content": """ +**现代简约风格要求:** +- 语言简洁有力,直达重点 +- 多用短句和短段落,节奏明快 +- 避免冗长描写,注重信息密度 +- 使用现代口语化表达 +- 情节推进快速,少做环境渲染 +""" + }, + "poetic": { + "name": "诗意抒情", + "description": "富有诗意和情感张力的抒情风格", + "prompt_content": """ +**诗意抒情风格要求:** +- 注重情感表达和内心描写 +- 善用景物描写烘托情绪 +- 语言富有韵律和美感 +- 细腻刻画人物心理活动 +- 营造情感氛围,引发共鸣 +""" + }, + "concise": { + "name": "精炼利落", + "description": "惜字如金的简练风格,每个字都有意义", + "prompt_content": """ +**精炼利落风格要求:** +- 删除所有冗余描写,每句话都要有作用 +- 多用动词,少用形容词和副词 +- 对话干脆利落,不拖泥带水 +- 环境描写点到为止 +- 用最少的字数传达最多的信息 +""" + }, + "vivid": { + "name": "生动形象", + "description": "画面感强烈,让读者如临其境", + "prompt_content": """ +**生动形象风格要求:** +- 注重细节描写,让场景具体可感 +- 调动五感(视觉、听觉、触觉、嗅觉、味觉) +- 使用鲜明的比喻和形象化语言 +- 让读者能"看到"场景和动作 +- 人物表情、动作要具体生动 +""" + } + } + + @classmethod + def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]: + """获取预设风格配置""" + return cls.PRESET_STYLES.get(preset_id) + + @classmethod + def get_all_presets(cls) -> Dict[str, Dict[str, str]]: + """获取所有预设风格""" + return cls.PRESET_STYLES + + @staticmethod + def apply_style_to_prompt(base_prompt: str, style_content: str) -> str: + """ + 将写作风格应用到基础提示词中 + + Args: + base_prompt: 基础提示词 + style_content: 风格要求内容 + + Returns: + 组合后的提示词 + """ + # 在基础提示词末尾添加风格要求 + return f"{base_prompt}\n\n{style_content}\n\n请直接输出章节正文内容,不要包含章节标题和其他说明文字。" + + class PromptService: """提示词模板管理""" @@ -362,15 +467,6 @@ class PromptService: 6. 字数不得低于3000字 7. 语言自然流畅,避免AI痕迹 -**写作风格要求(重要):** -- 让故事自然流淌,写到哪算哪 -- 结尾处直接结束情节,不要加总结性段落 -- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句 -- 不要用"他/她陷入了沉思"作为结尾 -- 避免刻意的情感升华或哲理感悟收尾 -- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念 -- 就像在讲一个故事,讲完了就停,不需要画龙点睛 - 请直接输出章节正文内容,不要包含章节标题和其他说明文字。""" # 章节完整创作提示词(带前置章节上下文) @@ -429,15 +525,6 @@ class PromptService: - 开头自然衔接上一章结尾 - 结尾为下一章做好铺垫 -**写作风格要求(重要):** -- 让故事自然流淌,写到哪算哪 -- 结尾处直接结束情节,不要加总结性段落 -- 不要在章节末尾写"这一天/这一夜就这样过去了"之类的总结句 -- 不要用"他/她陷入了沉思"作为结尾 -- 避免刻意的情感升华或哲理感悟收尾 -- 章节结尾可以戛然而止,可以是对话,可以是动作,可以是悬念 -- 就像在讲一个故事,讲完了就停,不需要画龙点睛 - 请直接输出章节正文内容,不要包含章节标题和其他说明文字。""" # 大纲生成提示词 @@ -662,9 +749,14 @@ class PromptService: location: str, atmosphere: str, rules: str, characters_info: str, outlines_context: str, chapter_number: int, chapter_title: str, - chapter_outline: str) -> str: - """获取章节完整创作提示词""" - return cls.format_prompt( + chapter_outline: str, style_content: str = "") -> str: + """ + 获取章节完整创作提示词 + + Args: + style_content: 写作风格要求内容,如果提供则会追加到提示词中 + """ + base_prompt = cls.format_prompt( cls.CHAPTER_GENERATION, title=title, theme=theme, @@ -680,6 +772,12 @@ class PromptService: chapter_title=chapter_title, chapter_outline=chapter_outline ) + + # 如果有风格要求,应用到提示词中 + if style_content: + return WritingStyleManager.apply_style_to_prompt(base_prompt, style_content) + + return base_prompt @classmethod def get_chapter_generation_with_context_prompt(cls, title: str, theme: str, genre: str, @@ -687,9 +785,15 @@ class PromptService: location: str, atmosphere: str, rules: str, characters_info: str, outlines_context: str, previous_content: str, chapter_number: int, - chapter_title: str, chapter_outline: str) -> str: - """获取章节完整创作提示词(带前置章节上下文)""" - return cls.format_prompt( + chapter_title: str, chapter_outline: str, + style_content: str = "") -> str: + """ + 获取章节完整创作提示词(带前置章节上下文) + + Args: + style_content: 写作风格要求内容,如果提供则会追加到提示词中 + """ + base_prompt = cls.format_prompt( cls.CHAPTER_GENERATION_WITH_CONTEXT, title=title, theme=theme, @@ -706,6 +810,12 @@ class PromptService: chapter_title=chapter_title, chapter_outline=chapter_outline ) + + # 如果有风格要求,应用到提示词中 + if style_content: + return WritingStyleManager.apply_style_to_prompt(base_prompt, style_content) + + return base_prompt @classmethod def get_outline_prompt(cls, genre: str, theme: str, target_words: int, diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index cf26153..e5ada3d 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,6 +10,7 @@ import Characters from './pages/Characters'; import Relationships from './pages/Relationships'; import Organizations from './pages/Organizations'; import Chapters from './pages/Chapters'; +import WritingStyles from './pages/WritingStyles'; import Settings from './pages/Settings'; // import Polish from './pages/Polish'; import Login from './pages/Login'; @@ -41,6 +42,7 @@ function App() { } /> } /> } /> + } /> {/* } /> */} diff --git a/frontend/src/pages/Chapters.tsx b/frontend/src/pages/Chapters.tsx index cf31e8e..c8418be 100644 --- a/frontend/src/pages/Chapters.tsx +++ b/frontend/src/pages/Chapters.tsx @@ -3,8 +3,8 @@ import { List, Button, Modal, Form, Input, Select, message, Empty, Space, Badge, import { EditOutlined, FileTextOutlined, ThunderboltOutlined, LockOutlined, DownloadOutlined, SettingOutlined } from '@ant-design/icons'; import { useStore } from '../store'; import { useChapterSync } from '../store/hooks'; -import { projectApi } from '../services/api'; -import type { Chapter, ChapterUpdate, ApiError } from '../types'; +import { projectApi, writingStyleApi } from '../services/api'; +import type { Chapter, ChapterUpdate, ApiError, WritingStyle } from '../types'; import { cardStyles } from '../components/CardStyles'; const { TextArea } = Input; @@ -20,6 +20,8 @@ export default function Chapters() { const [editorForm] = Form.useForm(); const [isMobile, setIsMobile] = useState(window.innerWidth <= 768); const contentTextAreaRef = useRef(null); + const [writingStyles, setWritingStyles] = useState([]); + const [selectedStyleId, setSelectedStyleId] = useState(); useEffect(() => { const handleResize = () => { @@ -39,10 +41,29 @@ export default function Chapters() { useEffect(() => { if (currentProject?.id) { refreshChapters(); + loadWritingStyles(); } // eslint-disable-next-line react-hooks/exhaustive-deps }, [currentProject?.id]); + const loadWritingStyles = async () => { + if (!currentProject?.id) return; + + try { + const response = await writingStyleApi.getProjectStyles(currentProject.id); + setWritingStyles(response.styles); + + // 设置默认风格为初始选中 + const defaultStyle = response.styles.find(s => s.is_default); + if (defaultStyle) { + setSelectedStyleId(defaultStyle.id); + } + } catch (error) { + console.error('加载写作风格失败:', error); + message.error('加载写作风格失败'); + } + }; + if (!currentProject) return null; const canGenerateChapter = (chapter: Chapter): boolean => { @@ -146,7 +167,7 @@ export default function Chapters() { textArea.scrollTop = textArea.scrollHeight; } } - }); + }, selectedStyleId); message.success('AI创作成功'); } catch (error) { @@ -163,6 +184,8 @@ export default function Chapters() { c => c.chapter_number < chapter.chapter_number ).sort((a, b) => a.chapter_number - b.chapter_number); + const selectedStyle = writingStyles.find(s => s.id === selectedStyleId); + const modal = Modal.confirm({ title: 'AI创作章节内容', width: 700, @@ -175,6 +198,9 @@ export default function Chapters() {
  • 项目的世界观设定
  • 相关角色信息
  • 前面已完成章节的内容(确保剧情连贯)
  • + {selectedStyle && ( +
  • 写作风格:{selectedStyle.name}
  • + )} {previousChapters.length > 0 && ( @@ -219,6 +245,17 @@ export default function Chapters() { }); try { + if (!selectedStyleId) { + message.error('请先选择写作风格'); + modal.update({ + okButtonProps: { danger: true, loading: false }, + cancelButtonProps: { disabled: false }, + closable: true, + maskClosable: true, + keyboard: true, + }); + return; + } await handleGenerate(); modal.destroy(); } catch (error) { @@ -526,6 +563,35 @@ export default function Chapters() { + + + {!selectedStyleId && ( +
    + 请选择写作风格 +
    + )} +
    +