update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词

This commit is contained in:
xiamuceer
2025-12-28 19:35:23 +08:00
parent f32e51b594
commit 89848e2258
40 changed files with 2752 additions and 1824 deletions
+31 -7
View File
@@ -309,13 +309,37 @@ async def generate_career_system(
7. 只返回纯JSON,不要添加任何解释文字
"""
yield await SSEResponse.send_progress("调用AI生成新职业...", 30)
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
try:
# 调用AI生成
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
# 使用流式生成替代非流式
ai_response = ""
chunk_count = 0
last_progress = 10
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 平滑更新进度(10-90%,AI生成占60%
# 每10个chunk增加约1%的进度,最多到90%
if chunk_count % 10 == 0:
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
current_progress = min(10 + (chunk_count // 10), 90)
if current_progress > last_progress:
last_progress = current_progress
yield await SSEResponse.send_progress(
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
current_progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
@@ -326,7 +350,7 @@ async def generate_career_system(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 50)
yield await SSEResponse.send_progress("解析AI响应...", 91)
# 清洗并解析JSON
try:
@@ -339,7 +363,7 @@ async def generate_career_system(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("保存主职业...", 60)
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
# 保存主职业
main_careers_created = []
@@ -371,7 +395,7 @@ async def generate_career_system(
logger.error(f" ❌ 创建主职业失败:{str(e)}")
continue
yield await SSEResponse.send_progress("保存副职业...", 80)
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
# 保存副职业
sub_careers_created = []
+47 -12
View File
@@ -1070,8 +1070,7 @@ async def analyze_chapter_background(
if career_update_result['updated_count'] > 0:
logger.info(
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: "
f"{', '.join(career_update_result['updated_characters'])}"
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息"
)
if career_update_result['changes']:
for change in career_update_result['changes']:
@@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream(
user_id=current_user_id,
db_session=db_session,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream(
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 发送开始生成的进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
system_prompt_with_style = None
if style_content:
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
{style_content}
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
确保在整个章节创作过程中始终保持风格的一致性。"""
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
# 准备生成参数
generate_kwargs = {"prompt": prompt}
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
}
if custom_model:
logger.info(f" 使用自定义模型: {custom_model}")
generate_kwargs["model"] = custom_model
@@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream(
# 发送内容块
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
# 每20个chunk发送一次进度更新(提高频率
if chunk_count % 20 == 0:
# 每5个chunk发送一次进度更新(10-95%,更平滑
if chunk_count % 5 == 0:
current_word_count = len(full_content)
# 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%)
estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55))
# 优化进度计算:使用更平滑的递增方式
# 基于chunk数量和字数的混合计算,避免大幅跳跃
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
estimated_progress = min(95, 10 + chunk_progress + word_progress)
# 只在进度变化时发送
if estimated_progress > last_progress:
@@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream(
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
last_progress = estimated_progress
# 每20个chunk发送心跳
if chunk_count % 20 == 0:
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
await asyncio.sleep(0) # 让出控制权
# 发送保存进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
@@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream(
)
# 发送最终进度100%
yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
# 发送完成事件(包含分析任务ID
completion_data = {
@@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch(
else:
prompt = base_prompt
# 🎨 方案一:将写作风格注入到系统提示词(批量生成)
system_prompt_with_style = None
if style_content:
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
{style_content}
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
确保在整个章节创作过程中始终保持风格的一致性。"""
logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)")
# 非流式生成内容
full_content = ""
# 准备生成参数
generate_kwargs = {"prompt": prompt}
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
}
# 如果传入了自定义模型,使用指定的模型
if custom_model:
generate_kwargs["model"] = custom_model
logger.info(f" 批量生成使用自定义模型: {custom_model}")
# 批量生成中的流式生成(非SSE,不需要修改进度显示)
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
full_content += chunk
+120 -20
View File
@@ -662,10 +662,10 @@ async def generate_character_stream(
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(request.project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成角色...", 0)
yield await SSEResponse.send_progress("开始生成角色...", 1)
# 获取已存在的角色列表
yield await SSEResponse.send_progress("获取项目上下文...", 10)
yield await SSEResponse.send_progress("获取项目上下文...", 2)
existing_chars_result = await db.execute(
select(Character)
@@ -757,7 +757,7 @@ async def generate_character_stream(
- 其他要求:{request.requirements or ''}
"""
yield await SSEResponse.send_progress("构建AI提示词...", 20)
yield await SSEResponse.send_progress("构建AI提示词...", 3)
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
@@ -768,11 +768,14 @@ async def generate_character_stream(
user_input=user_input
)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
try:
# 🔧 MCP工具增强:静默检查并收集参考资料
ai_response = ""
chunk_count = 0
if user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
@@ -789,7 +792,7 @@ async def generate_character_stream(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -797,22 +800,119 @@ async def generate_character_stream(
if isinstance(result, dict):
ai_response = result.get('content', '')
if result.get('tool_calls_made', 0) > 0:
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
finish_reason = result.get('finish_reason', '')
tool_calls_made = result.get('tool_calls_made', 0)
# 🔧 修复:检查工具调用是否真正成功
if tool_calls_made > 0:
if finish_reason == 'tool_error':
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
# 工具调用失败,重新用基础模式生成
ai_response = ""
elif not ai_response.strip():
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
# 工具调用成功但返回空内容,重新生成
ai_response = ""
else:
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
# MCP成功且有内容,模拟流式输出(分块发送)
chunk_size = 50
for i in range(0, len(ai_response), chunk_size):
chunk = ai_response[i:i+chunk_size]
chunk_count += 1
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 3 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
)
# 跳过后续的流式生成
ai_response = result.get('content', '')
else:
ai_response = result
# 如果MCP调用失败或返回空,继续走流式生成
if not ai_response or not ai_response.strip():
logger.info(f"🔄 开始流式生成...")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as mcp_error:
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"未登录用户,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
@@ -823,7 +923,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 60)
yield await SSEResponse.send_progress("解析AI响应...", 96)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -836,7 +936,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建角色记录...", 75)
yield await SSEResponse.send_progress("创建角色记录...", 97)
# 转换traits
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
# 如果是组织,创建Organization详情
if is_organization:
yield await SSEResponse.send_progress("创建组织详情...", 85)
yield await SSEResponse.send_progress("创建组织详情...", 98)
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
yield await SSEResponse.send_progress("保存生成历史...", 95)
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
project_id=request.project_id,
prompt=prompt,
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
generated_content=ai_response,
model=user_ai_service.default_model
)
db.add(history)
+204 -28
View File
@@ -105,23 +105,27 @@ async def generate_options(
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
template_key_map = {
"title": "INSPIRATION_TITLE",
"description": "INSPIRATION_DESCRIPTION",
"theme": "INSPIRATION_THEME",
"genre": "INSPIRATION_GENRE"
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_key = template_key_map.get(step)
template_keys = template_key_map.get(step)
if not template_key:
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
system_key, user_key = template_keys
# 获取自定义提示词模板(分别获取 system 和 user
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
@@ -131,19 +135,9 @@ async def generate_options(
"theme": context.get("theme", "")
}
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
# 尝试解析为JSON格式的字典
try:
prompt_template = json.loads(prompt_template_str)
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
except (json.JSONDecodeError, KeyError):
# 如果不是JSON格式,降级使用原有方法
prompt_template = prompt_service.get_inspiration_prompt(step)
if not prompt_template:
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
@@ -153,13 +147,18 @@ async def generate_options(
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
response = await ai_service.generate_text(
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
)
):
accumulated_text += chunk
content = response.get("content", "")
response = {"content": accumulated_text}
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON(使用统一的JSON清洗方法)
@@ -222,6 +221,180 @@ async def generate_options(
}
@router.post("/refine-options")
async def refine_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
基于用户反馈重新生成选项(支持多轮对话)
Request:
{
"step": "title", // 当前步骤
"context": {
"initial_idea": "...",
"title": "...",
"description": "...",
"theme": "..."
},
"feedback": "我想要更悲剧一些的主题", // 用户反馈
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
}
Response:
{
"prompt": "引导语",
"options": ["新选项1", "新选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
feedback = data.get("feedback", "")
previous_options = data.get("previous_options", [])
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
logger.info(f"用户反馈: {feedback}")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 添加反馈信息到提示词
feedback_instruction = f"""
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
{feedback}
之前生成的选项:
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
注意:
1. 仔细理解用户的反馈意图
2. 生成的新选项要明显体现用户要求的调整方向
3. 保持与已有上下文的一致性
4. 确保返回6个有效选项
"""
system_prompt += feedback_instruction
# 如果是重试,强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
# 调用AI生成选项
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
# 反馈生成时使用稍高的temperature以获得更多样化的结果
temperature = min(temperature + 0.1, 0.9)
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON
try:
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}"
}
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}"
}
except Exception as e:
logger.error(f"{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
@@ -280,14 +453,17 @@ async def quick_generate(
# 降级使用原有方法
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
# 调用AI
response = await ai_service.generate_text(
# 调用AI - 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
)
):
accumulated_text += chunk
content = response.get("content", "")
response = {"content": accumulated_text}
content = accumulated_text
# 解析JSON(使用统一的JSON清洗方法)
try:
+27 -6
View File
@@ -512,8 +512,29 @@ async def generate_organization_stream(
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
try:
ai_response = await user_ai_service.generate_text(prompt=prompt)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response)
# 使用流式生成替代非流式
ai_content = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_content += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新字数(5-95%,AI生成占90%)
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 5), 95)
yield await SSEResponse.send_progress(
f"AI生成组织中... ({len(ai_content)}字符)",
progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
@@ -523,7 +544,7 @@ async def generate_organization_stream(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 60)
yield await SSEResponse.send_progress("解析AI响应...", 96)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -536,7 +557,7 @@ async def generate_organization_stream(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建组织记录...", 75)
yield await SSEResponse.send_progress("创建组织记录...", 97)
# 创建角色记录(组织也是角色的一种)
character = Character(
@@ -563,7 +584,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
yield await SSEResponse.send_progress("创建组织详情...", 85)
yield await SSEResponse.send_progress("创建组织详情...", 98)
# 自动创建Organization详情记录
organization = Organization(
@@ -580,7 +601,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
yield await SSEResponse.send_progress("保存生成历史...", 95)
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
+89 -30
View File
@@ -470,7 +470,7 @@ async def _generate_new_outline(
project: Project,
db: AsyncSession,
user_ai_service: AIService,
user_id: str = None
user_id: str
) -> OutlineListResponse:
"""全新生成大纲(MCP增强版)"""
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
@@ -534,7 +534,7 @@ async def _generate_new_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -573,15 +573,23 @@ async def _generate_new_outline(
mcp_references=mcp_reference_materials
)
# 调用AI生成大纲
ai_response = await user_ai_service.generate_text(
# 调用AI流式生成大纲(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 如果未来需要转SSE,可以在这里yield
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -732,7 +740,7 @@ async def _continue_outline(
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService,
user_id: str = "system"
user_id: str
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
@@ -1000,7 +1008,7 @@ async def _continue_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1045,15 +1053,22 @@ async def _continue_outline(
)
# 调用AI生成当前批次
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
ai_response = await user_ai_service.generate_text(
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
user_id=user_id_for_mcp,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
mcp_references=mcp_reference_materials
)
# 调用AI
# 调用AI流式生成
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
# 添加调试日志
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
logger.info(f"=== 大纲生成AI调用参数 ===")
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
logger.info(f" 完整data: {data}")
ai_response = await user_ai_service.generate_text(
# ✅ 流式生成(带字数统计和进度)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数(30-95%,AI生成占65%
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 2), 95)
yield await SSEResponse.send_progress(
f"AI生成大纲中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
# 全新生成模式:删除旧大纲和关联的所有章节
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode}")
from sqlalchemy import delete as sql_delete
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
# 保存新大纲
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
outlines = await _save_outlines(
project_id, outline_data, db, start_index=1
)
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
for outline in outlines:
await db.refresh(outline)
yield await SSEResponse.send_progress("整理结果数据...", 95)
yield await SSEResponse.send_progress("整理结果数据...", 99)
logger.info(f"全新生成完成 - {len(outlines)}")
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
ai_response = await user_ai_service.generate_text(
# 流式生成并累积文本
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度(每批占用约50%的进度空间)
if chunk_count % 5 == 0:
# 在批次范围内平滑递增(从10到85,总共75%)
batch_range = 60 // total_batches # 总进度60%分配给所有批次
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
yield await SSEResponse.send_progress(
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
progress_in_batch
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress(
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
batch_progress + 10
)
# 提取内容generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
# 提取内容
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
+25 -11
View File
@@ -73,14 +73,15 @@ async def get_user_ai_service(
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例
# 使用用户设置创建AI服务实例(包括系统提示词)
return create_user_ai_service(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens
max_tokens=settings.max_tokens,
system_prompt=settings.system_prompt # 传递系统提示词
)
@@ -271,17 +272,30 @@ async def get_available_models(
}
elif provider == "anthropic":
# Anthropic 没有公开的模型列表API
raise HTTPException(
status_code=400,
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
)
# Anthropic models API
url = f"{api_base_url.rstrip('/')}/v1/models"
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
return {"provider": provider, "models": models, "count": len(models)}
elif provider == "gemini":
# Gemini models API
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
response = await client.get(url)
response.raise_for_status()
data = response.json()
models = []
for m in data.get("models", []):
if "generateContent" in m.get("supportedGenerationMethods", []):
mid = m.get("name", "").replace("models/", "")
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
return {"provider": provider, "models": models, "count": len(models)}
else:
raise HTTPException(
status_code=400,
detail=f"不支持的提供商: {provider}"
)
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
except httpx.HTTPStatusError as e:
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
+362 -126
View File
@@ -99,7 +99,7 @@ async def world_building_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -139,51 +139,118 @@ async def world_building_generator(
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
# ===== 流式生成世界观(带重试机制) =====
MAX_WORLD_RETRIES = 3 # 最多重试3次
world_retry_count = 0
world_generation_success = False
world_data = {}
try:
# ✅ 使用 AIService 的统一清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
world_data = json.loads(cleaned_text)
logger.info(f"世界观JSON解析成功")
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
try:
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败: {e}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 世界观生成独立进度:5-95%
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 检查是否返回空响应
if not accumulated_text or not accumulated_text.strip():
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
logger.error("❌ 世界观生成多次返回空响应")
world_data = {
"time_period": "AI多次返回为空,请稍后重试",
"location": "AI多次返回为空,请稍后重试",
"atmosphere": "AI多次返回为空,请稍后重试",
"rules": "AI多次返回为空,请稍后重试"
}
world_generation_success = True # 标记为成功以继续流程
break
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析世界观数据...", 96)
try:
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
logger.info(f" 原始内容预览: {accumulated_text[:300]}...")
# ✅ 使用 AIService 的统一清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_generation_success = True # 解析成功,标记完成
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {e}")
logger.error(f" 原始内容长度: {len(accumulated_text)}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
world_generation_success = True # 标记为成功以继续流程
except Exception as e:
logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {type(e).__name__}: {e}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 最后一次重试仍失败,抛出异常
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
raise
# 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90)
yield await SSEResponse.send_progress("保存世界观到数据库...", 99)
# 确保user_id存在
if not user_id:
@@ -240,41 +307,81 @@ async def world_building_generator(
project.wizard_step = 1
await db.commit()
# ===== 自动生成职业体系 =====
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75)
# ===== 自动生成职业体系(带重试机制+流式) =====
yield await SSEResponse.send_progress("世界观完成!", 100, "success")
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5)
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
try:
# 获取职业生成提示词模板(支持用户自定义)
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
career_prompt = PromptService.format_prompt(
template,
title=project.title,
genre=genre or '未设定',
theme=theme or '未设定',
time_period=world_data.get('time_period', '未设定'),
location=world_data.get('location', '未设定'),
atmosphere=world_data.get('atmosphere', '未设定'),
rules=world_data.get('rules', '未设定')
)
yield await SSEResponse.send_progress("正在生成职业体系...", 78)
# 调用AI生成职业
result = await user_ai_service.generate_text(prompt=career_prompt)
career_response = result.get('content', '') if isinstance(result, dict) else result
if not career_response or not career_response.strip():
logger.warning("⚠️ AI返回空职业体系,跳过职业生成")
yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85)
else:
yield await SSEResponse.send_progress("解析职业体系数据...", 82)
MAX_CAREER_RETRIES = 3 # 最多重试3次
career_retry_count = 0
career_generation_success = False
while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success:
try:
retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10)
# 获取职业生成提示词模板(支持用户自定义)
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
career_prompt = PromptService.format_prompt(
template,
title=project.title,
genre=genre or '未设定',
theme=theme or '未设定',
time_period=world_data.get('time_period', '未设定'),
location=world_data.get('location', '未设定'),
atmosphere=world_data.get('atmosphere', '未设定'),
rules=world_data.get('rules', '未设定')
)
# ✅ 使用流式生成职业体系
career_response = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=career_prompt,
provider=provider,
model=model
):
chunk_count += 1
career_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 职业体系生成独立进度:10-95%
if chunk_count % 5 == 0:
progress = min(10 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(
f"生成职业体系中... ({len(career_response)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
if not career_response or not career_response.strip():
logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99)
break
yield await SSEResponse.send_progress("解析职业体系数据...", 96)
# 清洗并解析JSON
try:
cleaned_response = user_ai_service._clean_json_response(career_response)
career_data = json.loads(cleaned_response)
logger.info(f"✅ 职业体系JSON解析成功")
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}")
# 保存主职业
main_careers_created = []
@@ -338,22 +445,51 @@ async def world_building_generator(
await db.commit()
# 标记成功
career_generation_success = True
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}")
yield await SSEResponse.send_progress(
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)}",
90
99
)
except json.JSONDecodeError as e:
logger.error(f"❌ 职业体系JSON解析失败: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85)
logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99)
except Exception as e:
logger.error(f"❌ 职业体系保存失败: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85)
except Exception as e:
logger.error(f"❌ 职业体系生成异常: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85)
logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 保存失败,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99)
except Exception as e:
logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99)
db_committed = True
@@ -366,7 +502,8 @@ async def world_building_generator(
"rules": world_data.get("rules")
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_progress("职业体系完成", 100, "success")
yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
@@ -473,7 +610,7 @@ async def characters_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮
tool_choice="auto",
provider=None,
model=None
@@ -611,15 +748,32 @@ async def characters_generator(
else:
prompt = base_prompt
# 流式生成
# 流式生成(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数
if chunk_count % 5 == 0:
progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15)
yield await SSEResponse.send_progress(
f"生成角色中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析批次结果 - 使用统一的JSON清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
@@ -1184,18 +1338,35 @@ async def outline_generator(
requirements=outline_requirements
)
# 流式生成大纲
# 流式生成大纲(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=outline_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数(5-95%,AI生成占90%)
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(
f"生成大纲中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析大纲结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析大纲...", 40)
yield await SSEResponse.send_progress("解析大纲...", 96)
try:
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
@@ -1208,7 +1379,7 @@ async def outline_generator(
return
# 保存大纲到数据库
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
yield await SSEResponse.send_progress("保存大纲到数据库...", 97)
created_outlines = []
for index, outline_item in enumerate(outline_data[:outline_count], 1):
outline = Outline(
@@ -1231,7 +1402,7 @@ async def outline_generator(
created_chapters = []
if project.outline_mode == 'one-to-one':
# 一对一模式:自动为每个大纲创建对应的章节
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98)
for outline in created_outlines:
chapter = Chapter(
@@ -1250,10 +1421,10 @@ async def outline_generator(
await db.refresh(chapter)
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99)
else:
# 一对多模式:跳过自动创建,用户可手动展开
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99)
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
# 更新项目信息
@@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator(
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
# ===== 流式生成世界观(带重试机制) =====
MAX_WORLD_RETRIES = 3 # 最多重试3次
world_retry_count = 0
world_generation_success = False
world_data = {}
try:
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观重新生成JSON解析成功")
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
try:
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
except json.JSONDecodeError as e:
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 85)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 检查是否返回空响应
if not accumulated_text or not accumulated_text.strip():
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
logger.error("❌ 世界观重新生成多次返回空响应")
world_data = {
"time_period": "AI多次返回为空,请稍后重试",
"location": "AI多次返回为空,请稍后重试",
"atmosphere": "AI多次返回为空,请稍后重试",
"rules": "AI多次返回为空,请稍后重试"
}
world_generation_success = True
break
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
try:
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_generation_success = True
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {e}")
logger.error(f" 原始内容长度: {len(accumulated_text)}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
world_generation_success = True
except Exception as e:
logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {type(e).__name__}: {e}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 最后一次重试仍失败,抛出异常
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
raise
# 不保存到数据库,仅返回生成结果供用户预览
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
+37 -15
View File
@@ -15,7 +15,6 @@ from ..schemas.writing_style import (
WritingStyleListResponse,
SetDefaultStyleRequest
)
from ..services.prompt_service import WritingStyleManager
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
@@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str:
@router.get("/presets/list", response_model=List[dict])
async def get_preset_styles():
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
"""
获取所有预设风格列表
获取所有预设风格列表从数据库读取
返回格式数组形式的预设风格列表
[
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": "classical", "name": "古典优雅", ...}
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
]
"""
presets = WritingStyleManager.get_all_presets()
# 将字典转换为数组,添加 id 字段
# 从数据库获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = result.scalars().all()
# 转换为响应格式
return [
{"id": preset_id, **preset_data}
for preset_id, preset_data in presets.items()
{
"id": style.id,
"preset_id": style.preset_id,
"name": style.name,
"description": style.description,
"prompt_content": style.prompt_content,
"style_type": style.style_type,
"order_index": style.order_index
}
for style in preset_styles
]
@@ -58,25 +72,33 @@ async def create_writing_style(
"""
创建新的写作风格用户级别
- **基于预设创建**提供 preset_id系统会自动填充预设内容
- **基于预设创建**提供 preset_id系统会从数据库查询预设内容自动填充
- **完全自定义**不提供 preset_id需要手动填写所有字段
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 如果基于预设创建,获取预设内容
# 如果基于预设创建,从数据库获取预设内容
if style_data.preset_id:
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
result = await db.execute(
select(WritingStyle)
.where(
WritingStyle.user_id.is_(None),
WritingStyle.preset_id == style_data.preset_id
)
)
preset = result.scalar_one_or_none()
if not preset:
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
# 使用预设内容填充(如果用户未提供)
if not style_data.name:
style_data.name = preset["name"]
style_data.name = preset.name
if not style_data.description:
style_data.description = preset["description"]
style_data.description = preset.description
if not style_data.prompt_content:
style_data.prompt_content = preset["prompt_content"]
style_data.prompt_content = preset.prompt_content
# 验证必填字段
if not style_data.name or not style_data.prompt_content: