update:1.开放系统内置提示词,支持用户自定义模板

This commit is contained in:
xiamuceer
2025-11-29 22:01:02 +08:00
parent e772676621
commit d102328b75
23 changed files with 2325 additions and 746 deletions
+58 -21
View File
@@ -37,7 +37,7 @@ from app.schemas.regeneration import (
RegenerationTaskStatus
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService, WritingStyleManager
from app.services.plot_analyzer import PlotAnalyzer
from app.services.memory_service import memory_service
from app.services.chapter_regenerator import ChapterRegenerator
@@ -1236,9 +1236,11 @@ async def generate_chapter_content_stream(
chapter_outline_content = outline.content if outline else current_chapter.summary or '暂无大纲'
logger.warning(f"⚠️ 一对多模式但无expansion_plan,使用大纲内容")
# 根据是否有前置内容选择不同的提示词,并应用写作风格、记忆增强和MCP参考资料
# 根据是否有前置内容选择不同的提示词,并应用写作风格、记忆增强和MCP参考资料(支持自定义)
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
template = await PromptService.get_template("CHAPTER_GENERATION_WITH_CONTEXT", current_user_id, db_session)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
@@ -1253,14 +1255,25 @@ async def generate_chapter_content_stream(
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=chapter_outline_content,
style_content=style_content,
target_word_count=target_word_count,
memory_context=memory_context,
mcp_references=mcp_reference_materials,
outline_mode=outline_mode
max_word_count=target_word_count + 1000,
memory_context=memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '') if memory_context else "暂无相关记忆"
)
# 插入模式说明和MCP参考
mode_instruction = "\n\n【创作模式说明】\n本章采用细纲模式:本章是大纲节点的细化展开之一。请严格遵循上述详细规划(expansion_plan)中的剧情点、角色焦点、情感基调和叙事目标,确保与整体规划保持一致,同时自然衔接前文内容。\n" if outline_mode == 'one-to-many' else "\n\n【创作模式说明】\n本章采用一对一模式:一个大纲节点对应一个章节。请在承接前文的基础上,充分展开大纲中的情节,保持叙事的完整性。\n"
mcp_text = ""
if mcp_reference_materials:
mcp_text = "\n【📚 MCP工具搜索 - 参考资料】\n以下是通过MCP工具搜索到的相关参考资料,可用于丰富情节和细节:\n\n" + mcp_reference_materials + "\n"
base_prompt = base_prompt.replace("本章信息:", mcp_text + mode_instruction + "\n本章信息:")
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
else:
prompt = prompt_service.get_chapter_generation_prompt(
template = await PromptService.get_template("CHAPTER_GENERATION", current_user_id, db_session)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
@@ -1274,12 +1287,23 @@ async def generate_chapter_content_stream(
chapter_number=current_chapter.chapter_number,
chapter_title=current_chapter.title,
chapter_outline=chapter_outline_content,
style_content=style_content,
target_word_count=target_word_count,
memory_context=memory_context,
mcp_references=mcp_reference_materials,
outline_mode=outline_mode
max_word_count=target_word_count + 1000
)
# 插入模式说明和记忆、MCP参考
mode_instruction = "\n\n【创作模式说明】\n本章采用细纲模式:本章是大纲节点的细化展开之一。请严格遵循上述详细规划中的剧情点、角色焦点和情感基调,确保与整体规划保持一致。\n" if outline_mode == 'one-to-many' else "\n\n【创作模式说明】\n本章采用一对一模式:一个大纲节点对应一个章节。请充分展开大纲中的情节,注重叙事的完整性和丰满度。\n"
memory_text = ""
if memory_context:
memory_text = "\n【🧠 智能记忆系统 - 重要参考】\n" + memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '')
mcp_text = ""
if mcp_reference_materials:
mcp_text = "\n【📚 MCP工具搜索 - 参考资料】\n以下是通过MCP工具搜索到的相关参考资料,可用于丰富情节和细节:\n\n" + mcp_reference_materials + "\n"
base_prompt = base_prompt.replace("本章信息:", memory_text + mcp_text + mode_instruction + "\n\n本章信息:")
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
if mcp_reference_materials:
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
@@ -2412,9 +2436,12 @@ async def generate_single_chapter_for_batch(
chapter_outline_content = outline.content if outline else chapter.summary or '暂无大纲'
logger.warning(f"⚠️ 批量生成 - 一对多模式但无expansion_plan,使用大纲内容")
# 生成提示词
# 生成提示词(支持自定义)
if previous_content:
prompt = prompt_service.get_chapter_generation_with_context_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("CHAPTER_GENERATION_WITH_CONTEXT", user_id, db_session)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
@@ -2429,13 +2456,20 @@ async def generate_single_chapter_for_batch(
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=chapter_outline_content,
style_content=style_content,
target_word_count=target_word_count,
memory_context=memory_context,
outline_mode=outline_mode
max_word_count=target_word_count + 1000,
memory_context=memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '') if memory_context else "暂无相关记忆"
)
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
else:
prompt = prompt_service.get_chapter_generation_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("CHAPTER_GENERATION", user_id, db_session)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or '',
genre=project.genre or '',
@@ -2449,11 +2483,14 @@ async def generate_single_chapter_for_batch(
chapter_number=chapter.chapter_number,
chapter_title=chapter.title,
chapter_outline=chapter_outline_content,
style_content=style_content,
target_word_count=target_word_count,
memory_context=memory_context,
outline_mode=outline_mode
max_word_count=target_word_count + 1000
)
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
# 非流式生成内容
full_content = ""
+11 -4
View File
@@ -19,7 +19,7 @@ from app.schemas.character import (
CharacterGenerateRequest
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -419,8 +419,11 @@ async def generate_character(
- 其他要求:{request.requirements or ''}
"""
# 使用统一的提示词服务
prompt = prompt_service.get_single_character_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_CHARACTER", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_context=project_context,
user_input=user_input
)
@@ -825,7 +828,11 @@ async def generate_character_stream(
yield await SSEResponse.send_progress("构建AI提示词...", 20)
prompt = prompt_service.get_single_character_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_CHARACTER", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_context=project_context,
user_input=user_input
)
+52 -120
View File
@@ -1,5 +1,5 @@
"""灵感模式API - 通过对话引导创建项目"""
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import json
@@ -7,97 +7,13 @@ import json
from app.database import get_db
from app.services.ai_service import AIService
from app.api.settings import get_user_ai_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
logger = get_logger(__name__)
# 灵感模式提示词模板
INSPIRATION_PROMPTS = {
"title": {
"system": """你是一位专业的小说创作顾问。
用户的原始想法:{initial_idea}
请根据用户的想法,生成6个吸引人的书名建议,要求:
1. 紧扣用户的原始想法和核心故事构思
2. 富有创意和吸引力
3. 涵盖不同的风格倾向
返回JSON格式:
{{
"prompt": "根据你的想法,我为你准备了几个书名建议:",
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
}}
只返回纯JSON,不要有其他文字。""",
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
},
"description": {
"system": """你是一位专业的小说创作顾问。
用户的原始想法:{initial_idea}
已确定的书名:{title}
请生成6个精彩的小说简介,要求:
1. 必须紧扣用户的原始想法,确保简介是原始想法的具体展开
2. 符合已确定的书名风格
3. 简洁有力,每个50-100字
4. 包含核心冲突
5. 涵盖不同的故事走向,但都基于用户的原始构思
返回JSON格式:
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
只返回纯JSON,不要有其他文字,不要换行。""",
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
},
"theme": {
"system": """你是一位专业的小说创作顾问。
用户的原始想法:{initial_idea}
小说信息:
- 书名:{title}
- 简介:{description}
请生成6个深刻的主题选项,要求:
1. 必须与用户的原始想法保持高度一致
2. 符合书名和简介的风格
3. 有深度和思想性
4. 每个50-150字
5. 涵盖不同角度(如:成长、复仇、救赎、探索等),但都围绕用户的核心构思
返回JSON格式:
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
只返回纯JSON,不要有其他文字,不要换行。""",
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
},
"genre": {
"system": """你是一位专业的小说创作顾问。
用户的原始想法:{initial_idea}
小说信息:
- 书名:{title}
- 简介:{description}
- 主题:{theme}
请生成6个合适的类型标签(每个2-4字),要求:
1. 必须符合用户原始想法中暗示的类型倾向
2. 符合小说整体风格
3. 可以多选组合
常见类型:玄幻、都市、科幻、武侠、仙侠、历史、言情、悬疑、奇幻、修仙等
返回JSON格式:
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
只返回紧凑的纯JSON,不要换行,不要有其他文字。""",
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
}
}
# 不同阶段的temperature设置(递减以保持一致性)
TEMPERATURE_SETTINGS = {
"title": 0.8, # 书名阶段可以更有创意
@@ -153,6 +69,8 @@ def validate_options_response(result: Dict[str, Any], step: str, max_retries: in
@router.post("/generate-options")
async def generate_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
@@ -183,28 +101,49 @@ async def generate_options(
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
# 获取对应的提示词模板
if step not in INSPIRATION_PROMPTS:
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
template_key_map = {
"title": "INSPIRATION_TITLE",
"description": "INSPIRATION_DESCRIPTION",
"theme": "INSPIRATION_THEME",
"genre": "INSPIRATION_GENRE"
}
template_key = template_key_map.get(step)
if not template_key:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
prompt_template = INSPIRATION_PROMPTS[step]
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
# 准备格式化参数(提供默认值避免KeyError
# 关键改进:保持initial_idea在所有阶段传递,确保内容关联性
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")), # 优先使用initial_idea,兼容旧数据
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化系统提示词
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
# 尝试解析为JSON格式的字典
try:
prompt_template = json.loads(prompt_template_str)
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
except (json.JSONDecodeError, KeyError):
# 如果不是JSON格式,降级使用原有方法
prompt_template = prompt_service.get_inspiration_prompt(step)
if not prompt_template:
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
@@ -302,6 +241,8 @@ async def generate_options(
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
@@ -326,6 +267,9 @@ async def quick_generate(
try:
logger.info("灵感模式:智能补全")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 构建补全提示词
existing_info = []
if data.get("title"):
@@ -339,35 +283,23 @@ async def quick_generate(
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
system_prompt = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
用户已提供的信息:
{existing}
请生成完整的小说方案,包含:
1. title: 书名(3-6字,如果用户已提供则保持原样)
2. description: 简介(50-100字,必须基于用户提供的信息,不要偏离原意)
3. theme: 核心主题(30-50字,必须与用户提供的信息保持一致)
4. genre: 类型标签数组(2-3个)
重要:所有补全的内容都必须与用户提供的信息保持高度关联,确保前后一致性。
返回JSON格式:
{{
"title": "书名",
"description": "简介内容...",
"theme": "主题内容...",
"genre": ["类型1", "类型2"]
}}
只返回纯JSON,不要有其他文字。"""
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
user_prompt = "请补全小说信息"
# 格式化提示词
try:
prompts = json.loads(prompt_template_str)
# 格式化参数
prompts["system"] = prompts["system"].replace("{existing}", existing_text)
prompts["user"] = prompts["user"].replace("{existing}", existing_text)
except (json.JSONDecodeError, KeyError):
# 降级使用原有方法
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
# 调用AI
response = await ai_service.generate_text(
prompt=user_prompt,
system_prompt=system_prompt.format(existing=existing_text),
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
)
+11 -4
View File
@@ -24,7 +24,7 @@ from app.schemas.relationship import (
)
from app.schemas.character import CharacterResponse
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -496,8 +496,11 @@ async def generate_organization(
- 其他要求:{gen_request.requirements or ''}
"""
# 使用统一的提示词服务
prompt = prompt_service.get_single_organization_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_ORGANIZATION", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_context=project_context,
user_input=user_input
)
@@ -689,7 +692,11 @@ async def generate_organization_stream(
yield await SSEResponse.send_progress("构建AI提示词...", 20)
prompt = prompt_service.get_single_organization_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_ORGANIZATION", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
project_context=project_context,
user_input=user_input
)
+19 -9
View File
@@ -25,7 +25,7 @@ from app.schemas.outline import (
CreateChaptersFromPlansResponse
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.services.memory_service import memory_service
from app.services.plot_expansion_service import PlotExpansionService
from app.logger import get_logger
@@ -477,8 +477,10 @@ async def _generate_new_outline(
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
mcp_reference_materials = ""
# 使用完整提示词(插入MCP参考资料)
prompt = prompt_service.get_complete_outline_prompt(
# 使用完整提示词(插入MCP参考资料,支持自定义
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id, db)
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
@@ -797,8 +799,10 @@ async def _continue_outline(
logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,降级为基础模式: {str(e)}")
mcp_reference_materials = ""
# 使用标准续写提示词模板(支持记忆+MCP增强)
prompt = prompt_service.get_outline_continue_prompt(
# 使用标准续写提示词模板(支持记忆+MCP增强+自定义
template = await PromptService.get_template("OUTLINE_CONTINUE_GENERATION", user_id, db)
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=request.theme or project.theme or "未设定",
genre=request.genre or project.genre or "通用",
@@ -814,6 +818,7 @@ async def _continue_outline(
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
end_chapter=current_start_chapter + current_batch_size - 1,
story_direction=request.story_direction or "自然延续",
requirements=request.requirements or "",
memory_context=memory_context,
@@ -1084,9 +1089,11 @@ async def new_outline_generator(
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
mcp_reference_materials = ""
# 使用完整提示词(插入MCP参考资料)
# 使用完整提示词(插入MCP参考资料,支持自定义
yield await SSEResponse.send_progress("准备AI提示词...", 20)
prompt = prompt_service.get_complete_outline_prompt(
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id_for_mcp, db)
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
genre=data.get("genre") or project.genre or "通用",
@@ -1412,8 +1419,10 @@ async def continue_outline_generator(
batch_progress + 5
)
# 使用标准续写提示词模板(支持记忆+MCP增强)
prompt = prompt_service.get_outline_continue_prompt(
# 使用标准续写提示词模板(支持记忆+MCP增强+自定义
template = await PromptService.get_template("OUTLINE_CONTINUE_GENERATION", user_id, db)
prompt = PromptService.format_prompt(
template,
title=project.title,
theme=data.get("theme") or project.theme or "未设定",
genre=data.get("genre") or project.genre or "通用",
@@ -1429,6 +1438,7 @@ async def continue_outline_generator(
recent_plot=recent_plot,
plot_stage_instruction=stage_instruction,
start_chapter=current_start_chapter,
end_chapter=current_start_chapter + current_batch_size - 1,
story_direction=data.get("story_direction", "自然延续"),
requirements=data.get("requirements", ""),
memory_context=memory_context,
+19 -5
View File
@@ -1,12 +1,12 @@
"""AI去味API - 核心特色功能"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.generation_history import GenerationHistory
from app.schemas.polish import PolishRequest, PolishResponse
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
@@ -17,6 +17,7 @@ logger = get_logger(__name__)
@router.post("", response_model=PolishResponse, summary="AI去味")
async def polish_text(
request: PolishRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -32,8 +33,14 @@ async def polish_text(
这是本项目的核心特色功能!
"""
try:
# 构建AI去味提示词
prompt = prompt_service.get_denoising_prompt(
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取自定义提示词模板
template = await PromptService.get_template("DENOISING", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(
template,
original_text=request.original_text
)
@@ -85,6 +92,7 @@ async def polish_batch(
project_id: int = None,
provider: str = None,
model: str = None,
http_request: Request = None,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
@@ -94,12 +102,18 @@ async def polish_batch(
适用于一次性处理多个章节或段落
"""
try:
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None) if http_request else None
results = []
for idx, text in enumerate(texts):
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
prompt = prompt_service.get_denoising_prompt(original_text=text)
# 获取自定义提示词模板
template = await PromptService.get_template("DENOISING", user_id, db)
# 格式化提示词
prompt = PromptService.format_prompt(template, original_text=text)
polished_text = await user_ai_service.generate_text(
prompt=prompt,
+478
View File
@@ -0,0 +1,478 @@
"""提示词模板管理 API"""
from fastapi import APIRouter, HTTPException, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, delete
from typing import List, Optional
from datetime import datetime
import json
from app.database import get_db
from app.models.prompt_template import PromptTemplate
from app.schemas.prompt_template import (
PromptTemplateCreate,
PromptTemplateUpdate,
PromptTemplateResponse,
PromptTemplateListResponse,
PromptTemplateCategoryResponse,
PromptTemplateExport,
PromptTemplatePreviewRequest
)
from app.services.prompt_service import PromptService
from app.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
@router.get("", response_model=PromptTemplateListResponse)
async def get_all_templates(
request: Request,
category: Optional[str] = Query(None, description="按分类筛选"),
is_active: Optional[bool] = Query(None, description="按启用状态筛选"),
db: AsyncSession = Depends(get_db)
):
"""
获取用户所有提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
query = select(PromptTemplate).where(PromptTemplate.user_id == user_id)
if category:
query = query.where(PromptTemplate.category == category)
if is_active is not None:
query = query.where(PromptTemplate.is_active == is_active)
query = query.order_by(PromptTemplate.category, PromptTemplate.template_key)
result = await db.execute(query)
templates = result.scalars().all()
# 获取所有分类
categories_result = await db.execute(
select(PromptTemplate.category)
.where(PromptTemplate.user_id == user_id)
.distinct()
)
categories = [c for c in categories_result.scalars().all() if c]
return PromptTemplateListResponse(
templates=templates,
total=len(templates),
categories=sorted(categories)
)
@router.get("/categories", response_model=List[PromptTemplateCategoryResponse])
async def get_templates_by_category(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
按分类获取提示词模板(合并用户自定义和系统默认)
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 1. 查询用户自定义模板
result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.user_id == user_id)
.order_by(PromptTemplate.category, PromptTemplate.template_key)
)
user_templates = result.scalars().all()
# 2. 获取所有系统默认模板
system_templates = PromptService.get_all_system_templates()
# 3. 构建用户自定义模板的键集合
user_template_keys = {t.template_key for t in user_templates}
# 4. 合并模板:用户自定义的 + 未自定义的系统默认
all_templates = []
current_time = datetime.now()
# 添加用户自定义的模板
for user_template in user_templates:
user_template.is_system_default = False # 标记为已自定义
all_templates.append(user_template)
# 添加未自定义的系统默认模板
for sys_template in system_templates:
if sys_template['template_key'] not in user_template_keys:
# 这个系统模板用户还没有自定义,创建临时对象
template_obj = PromptTemplate(
id=sys_template['template_key'], # 使用template_key作为临时ID
user_id=user_id,
template_key=sys_template['template_key'],
template_name=sys_template['template_name'],
template_content=sys_template['content'],
description=sys_template['description'],
category=sys_template['category'],
parameters=json.dumps(sys_template['parameters']),
is_active=True,
is_system_default=True,
created_at=current_time,
updated_at=current_time
)
all_templates.append(template_obj)
# 5. 按分类分组
category_dict = {}
for template in all_templates:
cat = template.category or "未分类"
if cat not in category_dict:
category_dict[cat] = []
category_dict[cat].append(template)
# 6. 构建响应
response = []
for category, temps in sorted(category_dict.items()):
# 按template_key排序,确保顺序一致
temps.sort(key=lambda t: t.template_key)
response.append(PromptTemplateCategoryResponse(
category=category,
count=len(temps),
templates=temps
))
return response
@router.get("/system-defaults")
async def get_system_defaults(
request: Request
):
"""
获取所有系统默认提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 从PromptService获取所有系统默认模板
system_templates = PromptService.get_all_system_templates()
return {
"templates": system_templates,
"total": len(system_templates)
}
@router.get("/{template_key}", response_model=PromptTemplateResponse)
async def get_template(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
获取指定的提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
return template
@router.post("", response_model=PromptTemplateResponse)
async def create_or_update_template(
data: PromptTemplateCreate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
创建或更新提示词模板(Upsert)
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 查找现有模板
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == data.template_key
)
)
template = result.scalar_one_or_none()
if template:
# 更新现有模板
for key, value in data.model_dump(exclude_unset=True).items():
setattr(template, key, value)
logger.info(f"用户 {user_id} 更新模板 {data.template_key}")
else:
# 创建新模板
template = PromptTemplate(
user_id=user_id,
**data.model_dump()
)
db.add(template)
logger.info(f"用户 {user_id} 创建模板 {data.template_key}")
await db.commit()
await db.refresh(template)
return template
@router.put("/{template_key}", response_model=PromptTemplateResponse)
async def update_template(
template_key: str,
data: PromptTemplateUpdate,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
更新提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
# 更新模板
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(template, key, value)
await db.commit()
await db.refresh(template)
logger.info(f"用户 {user_id} 更新模板 {template_key}")
return template
@router.delete("/{template_key}")
async def delete_template(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
删除自定义提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if not template:
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
await db.delete(template)
await db.commit()
logger.info(f"用户 {user_id} 删除模板 {template_key}")
return {"message": "模板已删除", "template_key": template_key}
@router.post("/{template_key}/reset")
async def reset_to_default(
template_key: str,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
重置为系统默认模板(删除用户自定义版本)
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证系统默认模板是否存在
system_template = PromptService.get_system_template_info(template_key)
if not system_template:
raise HTTPException(status_code=404, detail=f"系统默认模板 {template_key} 不存在")
# 查找并删除用户的自定义模板
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_key
)
)
template = result.scalar_one_or_none()
if template:
await db.delete(template)
await db.commit()
logger.info(f"用户 {user_id} 删除自定义模板 {template_key},恢复为系统默认")
return {"message": "已重置为系统默认", "template_key": template_key}
else:
# 用户本来就没有自定义,已经是系统默认状态
logger.info(f"用户 {user_id} 的模板 {template_key} 本来就是系统默认")
return {"message": "已是系统默认状态", "template_key": template_key}
@router.post("/export", response_model=PromptTemplateExport)
async def export_templates(
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
导出用户所有自定义模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
)
templates = result.scalars().all()
# 转换为导出格式
export_data = [
{
"template_key": t.template_key,
"template_name": t.template_name,
"template_content": t.template_content,
"description": t.description,
"category": t.category,
"parameters": t.parameters,
"is_active": t.is_active
}
for t in templates
]
logger.info(f"用户 {user_id} 导出了 {len(export_data)} 个模板")
return PromptTemplateExport(
templates=export_data,
export_time=datetime.now()
)
@router.post("/import")
async def import_templates(
data: PromptTemplateExport,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""
导入提示词模板
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
imported_count = 0
updated_count = 0
for template_data in data.templates:
# 查找是否已存在
result = await db.execute(
select(PromptTemplate).where(
PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_data.template_key
)
)
existing = result.scalar_one_or_none()
if existing:
# 更新现有模板
for key, value in template_data.model_dump().items():
setattr(existing, key, value)
updated_count += 1
else:
# 创建新模板
new_template = PromptTemplate(
user_id=user_id,
**template_data.model_dump()
)
db.add(new_template)
imported_count += 1
await db.commit()
logger.info(f"用户 {user_id} 导入了 {imported_count} 个新模板,更新了 {updated_count} 个模板")
return {
"message": "导入成功",
"imported": imported_count,
"updated": updated_count,
"total": imported_count + updated_count
}
@router.post("/{template_key}/preview")
async def preview_template(
template_key: str,
data: PromptTemplatePreviewRequest,
request: Request
):
"""
预览提示词模板(渲染变量)
"""
# 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
try:
# 使用PromptService的format_prompt方法
rendered = PromptService.format_prompt(
data.template_content,
**data.parameters
)
return {
"success": True,
"rendered_content": rendered,
"parameters_used": list(data.parameters.keys())
}
except KeyError as e:
return {
"success": False,
"error": f"缺少必需的参数: {str(e)}",
"rendered_content": None
}
except Exception as e:
return {
"success": False,
"error": f"渲染失败: {str(e)}",
"rendered_content": None
}
+19 -8
View File
@@ -16,7 +16,7 @@ from app.models.writing_style import WritingStyle
from app.models.project_default_style import ProjectDefaultStyle
from app.services.ai_service import AIService
from app.services.mcp_tool_service import MCPToolService
from app.services.prompt_service import prompt_service
from app.services.prompt_service import prompt_service, PromptService
from app.services.plot_expansion_service import PlotExpansionService
from app.logger import get_logger
from app.utils.sse_response import SSEResponse, create_sse_response
@@ -57,12 +57,14 @@ async def world_building_generator(
yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400)
return
# 获取基础提示词
# 获取基础提示词(支持自定义)
yield await SSEResponse.send_progress("准备AI提示词...", 15)
base_prompt = prompt_service.get_world_building_prompt(
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
base_prompt = PromptService.format_prompt(
template,
title=title,
theme=theme,
genre=genre
genre=genre or "通用类型"
)
# MCP工具增强:收集参考资料
@@ -455,8 +457,11 @@ async def characters_generator(
else:
batch_requirements += "\n主要是配角(supporting)和反派(antagonist)"
# 获取自定义提示词模板
template = await PromptService.get_template("CHARACTERS_BATCH", user_id, db)
# 构建基础提示词
base_prompt = prompt_service.get_characters_batch_prompt(
base_prompt = PromptService.format_prompt(
template,
count=current_batch_size, # 传递精确数量
time_period=world_context.get("time_period", ""),
location=world_context.get("location", ""),
@@ -954,7 +959,10 @@ async def outline_generator(
outline_requirements += "4. 不要试图完结故事,这只是开始部分\n"
outline_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
outline_prompt = prompt_service.get_complete_outline_prompt(
# 获取自定义提示词模板
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id, db)
outline_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "通用",
@@ -966,6 +974,7 @@ async def outline_generator(
atmosphere=project.world_atmosphere or "未设定",
rules=project.world_rules or "未设定",
characters_info=characters_info or "暂无角色信息",
mcp_references="",
requirements=outline_requirements
)
@@ -1150,9 +1159,11 @@ async def world_building_regenerate_generator(
enable_mcp = data.get("enable_mcp", True)
user_id = data.get("user_id")
# 获取基础提示词
# 获取基础提示词(支持自定义)
yield await SSEResponse.send_progress("准备AI提示词...", 15)
base_prompt = prompt_service.get_world_building_prompt(
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
base_prompt = PromptService.format_prompt(
template,
title=project.title,
theme=project.theme or "未设定",
genre=project.genre or "通用"