update:1.开放系统内置提示词,支持用户自定义模板
This commit is contained in:
+58
-21
@@ -37,7 +37,7 @@ from app.schemas.regeneration import (
|
||||
RegenerationTaskStatus
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService, WritingStyleManager
|
||||
from app.services.plot_analyzer import PlotAnalyzer
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.chapter_regenerator import ChapterRegenerator
|
||||
@@ -1236,9 +1236,11 @@ async def generate_chapter_content_stream(
|
||||
chapter_outline_content = outline.content if outline else current_chapter.summary or '暂无大纲'
|
||||
logger.warning(f"⚠️ 一对多模式但无expansion_plan,使用大纲内容")
|
||||
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格、记忆增强和MCP参考资料
|
||||
# 根据是否有前置内容选择不同的提示词,并应用写作风格、记忆增强和MCP参考资料(支持自定义)
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_WITH_CONTEXT", current_user_id, db_session)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
@@ -1253,14 +1255,25 @@ async def generate_chapter_content_stream(
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_outline_content,
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials,
|
||||
outline_mode=outline_mode
|
||||
max_word_count=target_word_count + 1000,
|
||||
memory_context=memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '') if memory_context else "暂无相关记忆"
|
||||
)
|
||||
# 插入模式说明和MCP参考
|
||||
mode_instruction = "\n\n【创作模式说明】\n本章采用细纲模式:本章是大纲节点的细化展开之一。请严格遵循上述详细规划(expansion_plan)中的剧情点、角色焦点、情感基调和叙事目标,确保与整体规划保持一致,同时自然衔接前文内容。\n" if outline_mode == 'one-to-many' else "\n\n【创作模式说明】\n本章采用一对一模式:一个大纲节点对应一个章节。请在承接前文的基础上,充分展开大纲中的情节,保持叙事的完整性。\n"
|
||||
mcp_text = ""
|
||||
if mcp_reference_materials:
|
||||
mcp_text = "\n【📚 MCP工具搜索 - 参考资料】\n以下是通过MCP工具搜索到的相关参考资料,可用于丰富情节和细节:\n\n" + mcp_reference_materials + "\n"
|
||||
base_prompt = base_prompt.replace("本章信息:", mcp_text + mode_instruction + "\n本章信息:")
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION", current_user_id, db_session)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
@@ -1274,12 +1287,23 @@ async def generate_chapter_content_stream(
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=chapter_outline_content,
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context,
|
||||
mcp_references=mcp_reference_materials,
|
||||
outline_mode=outline_mode
|
||||
max_word_count=target_word_count + 1000
|
||||
)
|
||||
# 插入模式说明和记忆、MCP参考
|
||||
mode_instruction = "\n\n【创作模式说明】\n本章采用细纲模式:本章是大纲节点的细化展开之一。请严格遵循上述详细规划中的剧情点、角色焦点和情感基调,确保与整体规划保持一致。\n" if outline_mode == 'one-to-many' else "\n\n【创作模式说明】\n本章采用一对一模式:一个大纲节点对应一个章节。请充分展开大纲中的情节,注重叙事的完整性和丰满度。\n"
|
||||
memory_text = ""
|
||||
if memory_context:
|
||||
memory_text = "\n【🧠 智能记忆系统 - 重要参考】\n" + memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '')
|
||||
mcp_text = ""
|
||||
if mcp_reference_materials:
|
||||
mcp_text = "\n【📚 MCP工具搜索 - 参考资料】\n以下是通过MCP工具搜索到的相关参考资料,可用于丰富情节和细节:\n\n" + mcp_reference_materials + "\n"
|
||||
base_prompt = base_prompt.replace("本章信息:", memory_text + mcp_text + mode_instruction + "\n\n本章信息:")
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
if mcp_reference_materials:
|
||||
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
|
||||
@@ -2412,9 +2436,12 @@ async def generate_single_chapter_for_batch(
|
||||
chapter_outline_content = outline.content if outline else chapter.summary or '暂无大纲'
|
||||
logger.warning(f"⚠️ 批量生成 - 一对多模式但无expansion_plan,使用大纲内容")
|
||||
|
||||
# 生成提示词
|
||||
# 生成提示词(支持自定义)
|
||||
if previous_content:
|
||||
prompt = prompt_service.get_chapter_generation_with_context_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION_WITH_CONTEXT", user_id, db_session)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
@@ -2429,13 +2456,20 @@ async def generate_single_chapter_for_batch(
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=chapter_outline_content,
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context,
|
||||
outline_mode=outline_mode
|
||||
max_word_count=target_word_count + 1000,
|
||||
memory_context=memory_context.get('recent_context', '') + "\n" + memory_context.get('relevant_memories', '') + "\n" + memory_context.get('foreshadows', '') + "\n" + memory_context.get('character_states', '') + "\n" + memory_context.get('plot_points', '') if memory_context else "暂无相关记忆"
|
||||
)
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("CHAPTER_GENERATION", user_id, db_session)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or '',
|
||||
genre=project.genre or '',
|
||||
@@ -2449,11 +2483,14 @@ async def generate_single_chapter_for_batch(
|
||||
chapter_number=chapter.chapter_number,
|
||||
chapter_title=chapter.title,
|
||||
chapter_outline=chapter_outline_content,
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count,
|
||||
memory_context=memory_context,
|
||||
outline_mode=outline_mode
|
||||
max_word_count=target_word_count + 1000
|
||||
)
|
||||
# 应用写作风格
|
||||
if style_content:
|
||||
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 非流式生成内容
|
||||
full_content = ""
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.schemas.character import (
|
||||
CharacterGenerateRequest
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
@@ -419,8 +419,11 @@ async def generate_character(
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
# 使用统一的提示词服务
|
||||
prompt = prompt_service.get_single_character_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_CHARACTER", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
@@ -825,7 +828,11 @@ async def generate_character_stream(
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
|
||||
prompt = prompt_service.get_single_character_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_CHARACTER", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
+52
-120
@@ -1,5 +1,5 @@
|
||||
"""灵感模式API - 通过对话引导创建项目"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
@@ -7,97 +7,13 @@ import json
|
||||
from app.database import get_db
|
||||
from app.services.ai_service import AIService
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# 灵感模式提示词模板
|
||||
INSPIRATION_PROMPTS = {
|
||||
"title": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
|
||||
请根据用户的想法,生成6个吸引人的书名建议,要求:
|
||||
1. 紧扣用户的原始想法和核心故事构思
|
||||
2. 富有创意和吸引力
|
||||
3. 涵盖不同的风格倾向
|
||||
|
||||
返回JSON格式:
|
||||
{{
|
||||
"prompt": "根据你的想法,我为你准备了几个书名建议:",
|
||||
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
|
||||
}}
|
||||
|
||||
只返回纯JSON,不要有其他文字。""",
|
||||
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||
},
|
||||
|
||||
"description": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
已确定的书名:{title}
|
||||
|
||||
请生成6个精彩的小说简介,要求:
|
||||
1. 必须紧扣用户的原始想法,确保简介是原始想法的具体展开
|
||||
2. 符合已确定的书名风格
|
||||
3. 简洁有力,每个50-100字
|
||||
4. 包含核心冲突
|
||||
5. 涵盖不同的故事走向,但都基于用户的原始构思
|
||||
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||
},
|
||||
|
||||
"theme": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
- 简介:{description}
|
||||
|
||||
请生成6个深刻的主题选项,要求:
|
||||
1. 必须与用户的原始想法保持高度一致
|
||||
2. 符合书名和简介的风格
|
||||
3. 有深度和思想性
|
||||
4. 每个50-150字
|
||||
5. 涵盖不同角度(如:成长、复仇、救赎、探索等),但都围绕用户的核心构思
|
||||
|
||||
返回JSON格式:
|
||||
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||
},
|
||||
|
||||
"genre": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
- 简介:{description}
|
||||
- 主题:{theme}
|
||||
|
||||
请生成6个合适的类型标签(每个2-4字),要求:
|
||||
1. 必须符合用户原始想法中暗示的类型倾向
|
||||
2. 符合小说整体风格
|
||||
3. 可以多选组合
|
||||
|
||||
常见类型:玄幻、都市、科幻、武侠、仙侠、历史、言情、悬疑、奇幻、修仙等
|
||||
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
|
||||
|
||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 不同阶段的temperature设置(递减以保持一致性)
|
||||
TEMPERATURE_SETTINGS = {
|
||||
"title": 0.8, # 书名阶段可以更有创意
|
||||
@@ -153,6 +69,8 @@ def validate_options_response(result: Dict[str, Any], step: str, max_retries: in
|
||||
@router.post("/generate-options")
|
||||
async def generate_options(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -183,28 +101,49 @@ async def generate_options(
|
||||
|
||||
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||
|
||||
# 获取对应的提示词模板
|
||||
if step not in INSPIRATION_PROMPTS:
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板(根据step确定模板key)
|
||||
template_key_map = {
|
||||
"title": "INSPIRATION_TITLE",
|
||||
"description": "INSPIRATION_DESCRIPTION",
|
||||
"theme": "INSPIRATION_THEME",
|
||||
"genre": "INSPIRATION_GENRE"
|
||||
}
|
||||
template_key = template_key_map.get(step)
|
||||
|
||||
if not template_key:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
prompt_template = INSPIRATION_PROMPTS[step]
|
||||
# 获取自定义提示词模板
|
||||
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
|
||||
|
||||
# 准备格式化参数(提供默认值避免KeyError)
|
||||
# 关键改进:保持initial_idea在所有阶段传递,确保内容关联性
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")), # 优先使用initial_idea,兼容旧数据
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||
"title": context.get("title", ""),
|
||||
"description": context.get("description", ""),
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化系统提示词
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
|
||||
# 尝试解析为JSON格式的字典
|
||||
try:
|
||||
prompt_template = json.loads(prompt_template_str)
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 如果不是JSON格式,降级使用原有方法
|
||||
prompt_template = prompt_service.get_inspiration_prompt(step)
|
||||
if not prompt_template:
|
||||
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
|
||||
# 如果是重试,在提示词中强调格式要求
|
||||
if attempt > 0:
|
||||
@@ -302,6 +241,8 @@ async def generate_options(
|
||||
@router.post("/quick-generate")
|
||||
async def quick_generate(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -326,6 +267,9 @@ async def quick_generate(
|
||||
try:
|
||||
logger.info("灵感模式:智能补全")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 构建补全提示词
|
||||
existing_info = []
|
||||
if data.get("title"):
|
||||
@@ -339,35 +283,23 @@ async def quick_generate(
|
||||
|
||||
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
|
||||
|
||||
system_prompt = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
|
||||
|
||||
用户已提供的信息:
|
||||
{existing}
|
||||
|
||||
请生成完整的小说方案,包含:
|
||||
1. title: 书名(3-6字,如果用户已提供则保持原样)
|
||||
2. description: 简介(50-100字,必须基于用户提供的信息,不要偏离原意)
|
||||
3. theme: 核心主题(30-50字,必须与用户提供的信息保持一致)
|
||||
4. genre: 类型标签数组(2-3个)
|
||||
|
||||
重要:所有补全的内容都必须与用户提供的信息保持高度关联,确保前后一致性。
|
||||
|
||||
返回JSON格式:
|
||||
{{
|
||||
"title": "书名",
|
||||
"description": "简介内容...",
|
||||
"theme": "主题内容...",
|
||||
"genre": ["类型1", "类型2"]
|
||||
}}
|
||||
|
||||
只返回纯JSON,不要有其他文字。"""
|
||||
# 获取自定义提示词模板
|
||||
prompt_template_str = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
|
||||
|
||||
user_prompt = "请补全小说信息"
|
||||
# 格式化提示词
|
||||
try:
|
||||
prompts = json.loads(prompt_template_str)
|
||||
# 格式化参数
|
||||
prompts["system"] = prompts["system"].replace("{existing}", existing_text)
|
||||
prompts["user"] = prompts["user"].replace("{existing}", existing_text)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 降级使用原有方法
|
||||
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
|
||||
|
||||
# 调用AI
|
||||
response = await ai_service.generate_text(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt.format(existing=existing_text),
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.schemas.relationship import (
|
||||
)
|
||||
from app.schemas.character import CharacterResponse
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
@@ -496,8 +496,11 @@ async def generate_organization(
|
||||
- 其他要求:{gen_request.requirements or '无'}
|
||||
"""
|
||||
|
||||
# 使用统一的提示词服务
|
||||
prompt = prompt_service.get_single_organization_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_ORGANIZATION", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
@@ -689,7 +692,11 @@ async def generate_organization_stream(
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
|
||||
prompt = prompt_service.get_single_organization_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_ORGANIZATION", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
project_context=project_context,
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ from app.schemas.outline import (
|
||||
CreateChaptersFromPlansResponse
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_expansion_service import PlotExpansionService
|
||||
from app.logger import get_logger
|
||||
@@ -477,8 +477,10 @@ async def _generate_new_outline(
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用完整提示词(插入MCP参考资料)
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
# 使用完整提示词(插入MCP参考资料,支持自定义)
|
||||
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
@@ -797,8 +799,10 @@ async def _continue_outline(
|
||||
logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,降级为基础模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强+自定义)
|
||||
template = await PromptService.get_template("OUTLINE_CONTINUE_GENERATION", user_id, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=request.theme or project.theme or "未设定",
|
||||
genre=request.genre or project.genre or "通用",
|
||||
@@ -814,6 +818,7 @@ async def _continue_outline(
|
||||
recent_plot=recent_plot,
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=current_start_chapter,
|
||||
end_chapter=current_start_chapter + current_batch_size - 1,
|
||||
story_direction=request.story_direction or "自然延续",
|
||||
requirements=request.requirements or "",
|
||||
memory_context=memory_context,
|
||||
@@ -1084,9 +1089,11 @@ async def new_outline_generator(
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
|
||||
mcp_reference_materials = ""
|
||||
|
||||
# 使用完整提示词(插入MCP参考资料)
|
||||
# 使用完整提示词(插入MCP参考资料,支持自定义)
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 20)
|
||||
prompt = prompt_service.get_complete_outline_prompt(
|
||||
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id_for_mcp, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=data.get("theme") or project.theme or "未设定",
|
||||
genre=data.get("genre") or project.genre or "通用",
|
||||
@@ -1412,8 +1419,10 @@ async def continue_outline_generator(
|
||||
batch_progress + 5
|
||||
)
|
||||
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强)
|
||||
prompt = prompt_service.get_outline_continue_prompt(
|
||||
# 使用标准续写提示词模板(支持记忆+MCP增强+自定义)
|
||||
template = await PromptService.get_template("OUTLINE_CONTINUE_GENERATION", user_id, db)
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=data.get("theme") or project.theme or "未设定",
|
||||
genre=data.get("genre") or project.genre or "通用",
|
||||
@@ -1429,6 +1438,7 @@ async def continue_outline_generator(
|
||||
recent_plot=recent_plot,
|
||||
plot_stage_instruction=stage_instruction,
|
||||
start_chapter=current_start_chapter,
|
||||
end_chapter=current_start_chapter + current_batch_size - 1,
|
||||
story_direction=data.get("story_direction", "自然延续"),
|
||||
requirements=data.get("requirements", ""),
|
||||
memory_context=memory_context,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""AI去味API - 核心特色功能"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.generation_history import GenerationHistory
|
||||
from app.schemas.polish import PolishRequest, PolishResponse
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
|
||||
@@ -17,6 +17,7 @@ logger = get_logger(__name__)
|
||||
@router.post("", response_model=PolishResponse, summary="AI去味")
|
||||
async def polish_text(
|
||||
request: PolishRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -32,8 +33,14 @@ async def polish_text(
|
||||
这是本项目的核心特色功能!
|
||||
"""
|
||||
try:
|
||||
# 构建AI去味提示词
|
||||
prompt = prompt_service.get_denoising_prompt(
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("DENOISING", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(
|
||||
template,
|
||||
original_text=request.original_text
|
||||
)
|
||||
|
||||
@@ -85,6 +92,7 @@ async def polish_batch(
|
||||
project_id: int = None,
|
||||
provider: str = None,
|
||||
model: str = None,
|
||||
http_request: Request = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -94,12 +102,18 @@ async def polish_batch(
|
||||
适用于一次性处理多个章节或段落
|
||||
"""
|
||||
try:
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None) if http_request else None
|
||||
|
||||
results = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
logger.info(f"处理第 {idx+1}/{len(texts)} 个文本")
|
||||
|
||||
prompt = prompt_service.get_denoising_prompt(original_text=text)
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("DENOISING", user_id, db)
|
||||
# 格式化提示词
|
||||
prompt = PromptService.format_prompt(template, original_text=text)
|
||||
|
||||
polished_text = await user_ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
|
||||
@@ -0,0 +1,478 @@
|
||||
"""提示词模板管理 API"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.prompt_template import PromptTemplate
|
||||
from app.schemas.prompt_template import (
|
||||
PromptTemplateCreate,
|
||||
PromptTemplateUpdate,
|
||||
PromptTemplateResponse,
|
||||
PromptTemplateListResponse,
|
||||
PromptTemplateCategoryResponse,
|
||||
PromptTemplateExport,
|
||||
PromptTemplatePreviewRequest
|
||||
)
|
||||
from app.services.prompt_service import PromptService
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
|
||||
|
||||
|
||||
@router.get("", response_model=PromptTemplateListResponse)
|
||||
async def get_all_templates(
|
||||
request: Request,
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
is_active: Optional[bool] = Query(None, description="按启用状态筛选"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取用户所有提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
query = select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
||||
|
||||
if category:
|
||||
query = query.where(PromptTemplate.category == category)
|
||||
if is_active is not None:
|
||||
query = query.where(PromptTemplate.is_active == is_active)
|
||||
|
||||
query = query.order_by(PromptTemplate.category, PromptTemplate.template_key)
|
||||
|
||||
result = await db.execute(query)
|
||||
templates = result.scalars().all()
|
||||
|
||||
# 获取所有分类
|
||||
categories_result = await db.execute(
|
||||
select(PromptTemplate.category)
|
||||
.where(PromptTemplate.user_id == user_id)
|
||||
.distinct()
|
||||
)
|
||||
categories = [c for c in categories_result.scalars().all() if c]
|
||||
|
||||
return PromptTemplateListResponse(
|
||||
templates=templates,
|
||||
total=len(templates),
|
||||
categories=sorted(categories)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[PromptTemplateCategoryResponse])
|
||||
async def get_templates_by_category(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
按分类获取提示词模板(合并用户自定义和系统默认)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 1. 查询用户自定义模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate)
|
||||
.where(PromptTemplate.user_id == user_id)
|
||||
.order_by(PromptTemplate.category, PromptTemplate.template_key)
|
||||
)
|
||||
user_templates = result.scalars().all()
|
||||
|
||||
# 2. 获取所有系统默认模板
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
|
||||
# 3. 构建用户自定义模板的键集合
|
||||
user_template_keys = {t.template_key for t in user_templates}
|
||||
|
||||
# 4. 合并模板:用户自定义的 + 未自定义的系统默认
|
||||
all_templates = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# 添加用户自定义的模板
|
||||
for user_template in user_templates:
|
||||
user_template.is_system_default = False # 标记为已自定义
|
||||
all_templates.append(user_template)
|
||||
|
||||
# 添加未自定义的系统默认模板
|
||||
for sys_template in system_templates:
|
||||
if sys_template['template_key'] not in user_template_keys:
|
||||
# 这个系统模板用户还没有自定义,创建临时对象
|
||||
template_obj = PromptTemplate(
|
||||
id=sys_template['template_key'], # 使用template_key作为临时ID
|
||||
user_id=user_id,
|
||||
template_key=sys_template['template_key'],
|
||||
template_name=sys_template['template_name'],
|
||||
template_content=sys_template['content'],
|
||||
description=sys_template['description'],
|
||||
category=sys_template['category'],
|
||||
parameters=json.dumps(sys_template['parameters']),
|
||||
is_active=True,
|
||||
is_system_default=True,
|
||||
created_at=current_time,
|
||||
updated_at=current_time
|
||||
)
|
||||
all_templates.append(template_obj)
|
||||
|
||||
# 5. 按分类分组
|
||||
category_dict = {}
|
||||
for template in all_templates:
|
||||
cat = template.category or "未分类"
|
||||
if cat not in category_dict:
|
||||
category_dict[cat] = []
|
||||
category_dict[cat].append(template)
|
||||
|
||||
# 6. 构建响应
|
||||
response = []
|
||||
for category, temps in sorted(category_dict.items()):
|
||||
# 按template_key排序,确保顺序一致
|
||||
temps.sort(key=lambda t: t.template_key)
|
||||
response.append(PromptTemplateCategoryResponse(
|
||||
category=category,
|
||||
count=len(temps),
|
||||
templates=temps
|
||||
))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/system-defaults")
|
||||
async def get_system_defaults(
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
获取所有系统默认提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 从PromptService获取所有系统默认模板
|
||||
system_templates = PromptService.get_all_system_templates()
|
||||
|
||||
return {
|
||||
"templates": system_templates,
|
||||
"total": len(system_templates)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{template_key}", response_model=PromptTemplateResponse)
|
||||
async def get_template(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取指定的提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.post("", response_model=PromptTemplateResponse)
|
||||
async def create_or_update_template(
|
||||
data: PromptTemplateCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建或更新提示词模板(Upsert)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 查找现有模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == data.template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if template:
|
||||
# 更新现有模板
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
setattr(template, key, value)
|
||||
logger.info(f"用户 {user_id} 更新模板 {data.template_key}")
|
||||
else:
|
||||
# 创建新模板
|
||||
template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
**data.model_dump()
|
||||
)
|
||||
db.add(template)
|
||||
logger.info(f"用户 {user_id} 创建模板 {data.template_key}")
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(template)
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.put("/{template_key}", response_model=PromptTemplateResponse)
|
||||
async def update_template(
|
||||
template_key: str,
|
||||
data: PromptTemplateUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
# 更新模板
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(template, key, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(template)
|
||||
logger.info(f"用户 {user_id} 更新模板 {template_key}")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
@router.delete("/{template_key}")
|
||||
async def delete_template(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除自定义提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
||||
|
||||
await db.delete(template)
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user_id} 删除模板 {template_key}")
|
||||
|
||||
return {"message": "模板已删除", "template_key": template_key}
|
||||
|
||||
|
||||
@router.post("/{template_key}/reset")
|
||||
async def reset_to_default(
|
||||
template_key: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
重置为系统默认模板(删除用户自定义版本)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证系统默认模板是否存在
|
||||
system_template = PromptService.get_system_template_info(template_key)
|
||||
if not system_template:
|
||||
raise HTTPException(status_code=404, detail=f"系统默认模板 {template_key} 不存在")
|
||||
|
||||
# 查找并删除用户的自定义模板
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_key
|
||||
)
|
||||
)
|
||||
template = result.scalar_one_or_none()
|
||||
|
||||
if template:
|
||||
await db.delete(template)
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user_id} 删除自定义模板 {template_key},恢复为系统默认")
|
||||
return {"message": "已重置为系统默认", "template_key": template_key}
|
||||
else:
|
||||
# 用户本来就没有自定义,已经是系统默认状态
|
||||
logger.info(f"用户 {user_id} 的模板 {template_key} 本来就是系统默认")
|
||||
return {"message": "已是系统默认状态", "template_key": template_key}
|
||||
|
||||
|
||||
@router.post("/export", response_model=PromptTemplateExport)
|
||||
async def export_templates(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导出用户所有自定义模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
||||
)
|
||||
templates = result.scalars().all()
|
||||
|
||||
# 转换为导出格式
|
||||
export_data = [
|
||||
{
|
||||
"template_key": t.template_key,
|
||||
"template_name": t.template_name,
|
||||
"template_content": t.template_content,
|
||||
"description": t.description,
|
||||
"category": t.category,
|
||||
"parameters": t.parameters,
|
||||
"is_active": t.is_active
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
|
||||
logger.info(f"用户 {user_id} 导出了 {len(export_data)} 个模板")
|
||||
|
||||
return PromptTemplateExport(
|
||||
templates=export_data,
|
||||
export_time=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
async def import_templates(
|
||||
data: PromptTemplateExport,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
导入提示词模板
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
imported_count = 0
|
||||
updated_count = 0
|
||||
|
||||
for template_data in data.templates:
|
||||
# 查找是否已存在
|
||||
result = await db.execute(
|
||||
select(PromptTemplate).where(
|
||||
PromptTemplate.user_id == user_id,
|
||||
PromptTemplate.template_key == template_data.template_key
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有模板
|
||||
for key, value in template_data.model_dump().items():
|
||||
setattr(existing, key, value)
|
||||
updated_count += 1
|
||||
else:
|
||||
# 创建新模板
|
||||
new_template = PromptTemplate(
|
||||
user_id=user_id,
|
||||
**template_data.model_dump()
|
||||
)
|
||||
db.add(new_template)
|
||||
imported_count += 1
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"用户 {user_id} 导入了 {imported_count} 个新模板,更新了 {updated_count} 个模板")
|
||||
|
||||
return {
|
||||
"message": "导入成功",
|
||||
"imported": imported_count,
|
||||
"updated": updated_count,
|
||||
"total": imported_count + updated_count
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{template_key}/preview")
|
||||
async def preview_template(
|
||||
template_key: str,
|
||||
data: PromptTemplatePreviewRequest,
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
预览提示词模板(渲染变量)
|
||||
"""
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
try:
|
||||
# 使用PromptService的format_prompt方法
|
||||
rendered = PromptService.format_prompt(
|
||||
data.template_content,
|
||||
**data.parameters
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"rendered_content": rendered,
|
||||
"parameters_used": list(data.parameters.keys())
|
||||
}
|
||||
except KeyError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"缺少必需的参数: {str(e)}",
|
||||
"rendered_content": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"渲染失败: {str(e)}",
|
||||
"rendered_content": None
|
||||
}
|
||||
@@ -16,7 +16,7 @@ from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.mcp_tool_service import MCPToolService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.services.plot_expansion_service import PlotExpansionService
|
||||
from app.logger import get_logger
|
||||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||||
@@ -57,12 +57,14 @@ async def world_building_generator(
|
||||
yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400)
|
||||
return
|
||||
|
||||
# 获取基础提示词
|
||||
# 获取基础提示词(支持自定义)
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=title,
|
||||
theme=theme,
|
||||
genre=genre
|
||||
genre=genre or "通用类型"
|
||||
)
|
||||
|
||||
# MCP工具增强:收集参考资料
|
||||
@@ -455,8 +457,11 @@ async def characters_generator(
|
||||
else:
|
||||
batch_requirements += "\n主要是配角(supporting)和反派(antagonist)"
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("CHARACTERS_BATCH", user_id, db)
|
||||
# 构建基础提示词
|
||||
base_prompt = prompt_service.get_characters_batch_prompt(
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
count=current_batch_size, # 传递精确数量
|
||||
time_period=world_context.get("time_period", ""),
|
||||
location=world_context.get("location", ""),
|
||||
@@ -954,7 +959,10 @@ async def outline_generator(
|
||||
outline_requirements += "4. 不要试图完结故事,这只是开始部分\n"
|
||||
outline_requirements += "5. 不要在JSON字符串值中使用中文引号(""''),请使用【】或《》标记\n"
|
||||
|
||||
outline_prompt = prompt_service.get_complete_outline_prompt(
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("COMPLETE_OUTLINE_GENERATION", user_id, db)
|
||||
outline_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or "未设定",
|
||||
genre=project.genre or "通用",
|
||||
@@ -966,6 +974,7 @@ async def outline_generator(
|
||||
atmosphere=project.world_atmosphere or "未设定",
|
||||
rules=project.world_rules or "未设定",
|
||||
characters_info=characters_info or "暂无角色信息",
|
||||
mcp_references="",
|
||||
requirements=outline_requirements
|
||||
)
|
||||
|
||||
@@ -1150,9 +1159,11 @@ async def world_building_regenerate_generator(
|
||||
enable_mcp = data.get("enable_mcp", True)
|
||||
user_id = data.get("user_id")
|
||||
|
||||
# 获取基础提示词
|
||||
# 获取基础提示词(支持自定义)
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
|
||||
base_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
theme=project.theme or "未设定",
|
||||
genre=project.genre or "通用"
|
||||
|
||||
Reference in New Issue
Block a user