Files
MuMuAINovel/backend/app/api/inspiration.py
T

340 lines
13 KiB
Python

"""灵感模式API - 通过对话引导创建项目"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import json
from app.database import get_db
from app.services.ai_service import AIService
from app.api.settings import get_user_ai_service
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
router = APIRouter(prefix="/inspiration", tags=["灵感模式"])
logger = get_logger(__name__)
# 不同阶段的temperature设置(递减以保持一致性)
TEMPERATURE_SETTINGS = {
"title": 0.8, # 书名阶段可以更有创意
"description": 0.65, # 简介需要贴合书名和原始想法
"theme": 0.55, # 主题需要更加贴合
"genre": 0.45 # 类型应该很明确
}
def validate_options_response(result: Dict[str, Any], step: str, max_retries: int = 3) -> tuple[bool, str]:
"""
校验AI返回的选项格式是否正确
Returns:
(is_valid, error_message)
"""
# 检查必需字段
if "options" not in result:
return False, "缺少options字段"
options = result.get("options", [])
# 检查options是否为数组
if not isinstance(options, list):
return False, "options必须是数组"
# 检查数组长度
if len(options) < 3:
return False, f"选项数量不足,至少需要3个,当前只有{len(options)}"
if len(options) > 10:
return False, f"选项数量过多,最多10个,当前有{len(options)}"
# 检查每个选项是否为字符串且不为空
for i, option in enumerate(options):
if not isinstance(option, str):
return False, f"{i+1}个选项不是字符串类型"
if not option.strip():
return False, f"{i+1}个选项为空"
if len(option) > 500:
return False, f"{i+1}个选项过长(超过500字符)"
# 根据不同步骤进行特定校验
if step == "genre":
# 类型标签应该比较短
for i, option in enumerate(options):
if len(option) > 10:
return False, f"类型标签【{option}】过长,应该在2-10字之间"
return True, ""
@router.post("/generate-options")
async def generate_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
根据当前收集的信息生成下一步的选项建议(带自动重试)
Request:
{
"step": "title", // title/description/theme/genre
"context": {
"title": "...",
"description": "...",
"theme": "..."
}
}
Response:
{
"prompt": "引导语",
"options": ["选项1", "选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
logger.info(f"灵感模式:生成{step}阶段的选项(第{attempt + 1}次尝试)")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
template_key_map = {
"title": "INSPIRATION_TITLE",
"description": "INSPIRATION_DESCRIPTION",
"theme": "INSPIRATION_THEME",
"genre": "INSPIRATION_GENRE"
}
template_key = template_key_map.get(step)
if not template_key:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
# 尝试解析为JSON格式的字典
try:
prompt_template = json.loads(prompt_template_str)
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
except (json.JSONDecodeError, KeyError):
# 如果不是JSON格式,降级使用原有方法
prompt_template = prompt_service.get_inspiration_prompt(step)
if not prompt_template:
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回,确保options数组包含6个有效选项!"
# 调用AI生成选项
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
response = await ai_service.generate_text(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
)
content = response.get("content", "")
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON
try:
# 清理可能的markdown标记
cleaned_content = content.strip()
if cleaned_content.startswith('```json'):
cleaned_content = cleaned_content[7:].lstrip('\n\r')
elif cleaned_content.startswith('```'):
cleaned_content = cleaned_content[3:].lstrip('\n\r')
if cleaned_content.endswith('```'):
cleaned_content = cleaned_content[:-3].rstrip('\n\r')
cleaned_content = cleaned_content.strip()
# 检查JSON是否完整
if not cleaned_content.endswith('}'):
logger.warning(f"⚠️ JSON可能被截断,尝试补全...")
if '"options"' in cleaned_content:
if cleaned_content.count('[') > cleaned_content.count(']'):
cleaned_content += '"]}'
elif cleaned_content.count('{') > cleaned_content.count('}'):
cleaned_content += '}'
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次,请手动重试或自己输入"
}
logger.info(f"✅ 第{attempt + 1}次成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue # 重试
else:
# 最后一次尝试也失败了
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}次,请手动重试或自己输入"
}
except Exception as e:
logger.error(f"{attempt + 1}次生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
# 理论上不会到这里
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
智能补全:根据用户已提供的部分信息,AI自动补全缺失字段
Request:
{
"title": "书名(可选)",
"description": "简介(可选)",
"theme": "主题(可选)",
"genre": ["类型1", "类型2"](可选)
}
Response:
{
"title": "补全的书名",
"description": "补全的简介",
"theme": "补全的主题",
"genre": ["补全的类型"]
}
"""
try:
logger.info("灵感模式:智能补全")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 构建补全提示词
existing_info = []
if data.get("title"):
existing_info.append(f"- 书名:{data['title']}")
if data.get("description"):
existing_info.append(f"- 简介:{data['description']}")
if data.get("theme"):
existing_info.append(f"- 主题:{data['theme']}")
if data.get("genre"):
existing_info.append(f"- 类型:{', '.join(data['genre'])}")
existing_text = "\n".join(existing_info) if existing_info else "暂无信息"
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template("INSPIRATION_QUICK_COMPLETE", user_id, db)
# 格式化提示词
try:
prompts = json.loads(prompt_template_str)
# 格式化参数
prompts["system"] = prompts["system"].replace("{existing}", existing_text)
prompts["user"] = prompts["user"].replace("{existing}", existing_text)
except (json.JSONDecodeError, KeyError):
# 降级使用原有方法
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
# 调用AI
response = await ai_service.generate_text(
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
)
content = response.get("content", "")
# 解析JSON
try:
cleaned_content = content.strip()
if cleaned_content.startswith('```json'):
cleaned_content = cleaned_content[7:].lstrip('\n\r')
elif cleaned_content.startswith('```'):
cleaned_content = cleaned_content[3:].lstrip('\n\r')
if cleaned_content.endswith('```'):
cleaned_content = cleaned_content[:-3].rstrip('\n\r')
cleaned_content = cleaned_content.strip()
result = json.loads(cleaned_content)
# 合并用户已提供的信息(用户输入优先)
final_result = {
"title": data.get("title") or result.get("title", ""),
"description": data.get("description") or result.get("description", ""),
"theme": data.get("theme") or result.get("theme", ""),
"genre": data.get("genre") or result.get("genre", [])
}
logger.info(f"✅ 智能补全成功")
return final_result
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败: {e}")
raise Exception("AI返回格式错误,请重试")
except Exception as e:
logger.error(f"智能补全失败: {e}", exc_info=True)
return {
"error": str(e)
}