Files
MuMuAINovel/backend/app/api/inspiration.py
T

494 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""灵感模式API - 通过对话引导创建项目"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import json
from app.database import get_db
from app.services.ai_service import AIService
from app.api.settings import get_user_ai_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
logger = get_logger(__name__)
# 不同阶段的temperature设置(递减以保持一致性)
TEMPERATURE_SETTINGS = {
"title": 0.8, # 书名阶段可以更有创意
"description": 0.65, # 简介需要贴合书名和原始想法
"theme": 0.55, # 主题需要更加贴合
"genre": 0.45 # 类型应该很明确
}
def validate_options_response(result: Dict[str, Any], step: str, max_retries: int = 3) -> tuple[bool, str]:
"""
校验AI返回的选项格式是否正确
Returns:
(is_valid, error_message)
"""
# 检查必需字段
if "options" not in result:
return False, "缺少options字段"
options = result.get("options", [])
# 检查options是否为数组
if not isinstance(options, list):
return False, "options必须是数组"
# 检查数组长度
if len(options) < 3:
return False, f"选项数量不足,至少需要3个,当前只有{len(options)}"
if len(options) > 10:
return False, f"选项数量过多,最多10个,当前有{len(options)}"
# 检查每个选项是否为字符串且不为空
for i, option in enumerate(options):
if not isinstance(option, str):
return False, f"{i+1}个选项不是字符串类型"
if not option.strip():
return False, f"{i+1}个选项为空"
if len(option) > 500:
return False, f"{i+1}个选项过长(超过500字符)"
# 根据不同步骤进行特定校验
if step == "genre":
# 类型标签应该比较短
for i, option in enumerate(options):
if len(option) > 10:
return False, f"类型标签【{option}】过长,应该在2-10字之间"
return True, ""
@router.post("/generate-options")
async def generate_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
根据当前收集的信息生成下一步的选项建议(带自动重试)
Request:
{
"step": "title", // title/description/theme/genre
"context": {
"title": "...",
"description": "...",
"theme": "..."
}
}
Response:
{
"prompt": "引导语",
"options": ["选项1", "选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板(分别获取 system 和 user
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回,确保options数组包含6个有效选项!"
# 调用AI生成选项
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
response = {"content": accumulated_text}
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON(使用统一的JSON清洗方法)
try:
# 使用统一的JSON清洗方法
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次,请手动重试或自己输入"
}
logger.info(f"✅ 第{attempt + 1}次成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}次,请手动重试或自己输入"
}
except Exception as e:
logger.error(f"{attempt + 1}次生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
# 理论上不会到这里
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/refine-options")
async def refine_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
基于用户反馈重新生成选项(支持多轮对话)
Request:
{
"step": "title", // 当前步骤
"context": {
"initial_idea": "...",
"title": "...",
"description": "...",
"theme": "..."
},
"feedback": "我想要更悲剧一些的主题", // 用户反馈
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
}
Response:
{
"prompt": "引导语",
"options": ["新选项1", "新选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
feedback = data.get("feedback", "")
previous_options = data.get("previous_options", [])
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
logger.info(f"用户反馈: {feedback}")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 添加反馈信息到提示词
feedback_instruction = f"""
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
{feedback}
之前生成的选项:
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
注意:
1. 仔细理解用户的反馈意图
2. 生成的新选项要明显体现用户要求的调整方向
3. 保持与已有上下文的一致性
4. 确保返回6个有效选项
"""
system_prompt += feedback_instruction
# 如果是重试,强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
# 调用AI生成选项
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
# 反馈生成时使用稍高的temperature以获得更多样化的结果
temperature = min(temperature + 0.1, 0.9)
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON
try:
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}"
}
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}"
}
except Exception as e:
logger.error(f"{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
智能补全:根据用户已提供的部分信息,AI自动补全缺失字段
Request:
{
"title": "书名(可选)",
"description": "简介(可选)",
"theme": "主题(可选)",
"genre": ["类型1", "类型2"](可选)
}
Response:
{
"title": "补全的书名",
"description": "补全的简介",
"theme": "补全的主题",
"genre": ["补全的类型"]
}
"""
try:
logger.info("灵感模式:智能补全")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 构建补全提示词
existing_info = []
if data.get("title"):
existing_info.append(f"- 书名:{data['title']}")
if data.get("description"):
existing_info.append(f"- 简介:{data['description']}")
if data.get("theme"):
existing_info.append(f"- 主题:{data['theme']}")
if data.get("genre"):
existing_info.append(f"- 类型:{', '.join(data['genre'])}")
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
# 格式化提示词
try:
prompts = json.loads(prompt_template_str)
# 格式化参数
prompts["system"] = prompts["system"].replace("{existing}", existing_text)
prompts["user"] = prompts["user"].replace("{existing}", existing_text)
except (json.JSONDecodeError, KeyError):
# 降级使用原有方法
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
# 调用AI - 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
):
accumulated_text += chunk
response = {"content": accumulated_text}
content = accumulated_text
# 解析JSON(使用统一的JSON清洗方法)
try:
# 使用统一的JSON清洗方法
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 合并用户已提供的信息(用户输入优先)
final_result = {
"title": data.get("title") or result.get("title", ""),
"description": data.get("description") or result.get("description", ""),
"theme": data.get("theme") or result.get("theme", ""),
"genre": data.get("genre") or result.get("genre", [])
}
logger.info(f"✅ 智能补全成功")
return final_result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
raise Exception("AI返回格式错误,请重试")
except Exception as e:
logger.error(f"智能补全失败: {e}", exc_info=True)
return {
"error": str(e)
}