update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -309,13 +309,37 @@ async def generate_career_system(
|
||||
7. 只返回纯JSON,不要添加任何解释文字
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 30)
|
||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
|
||||
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
||||
|
||||
try:
|
||||
# 调用AI生成
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
# 使用流式生成替代非流式
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
last_progress = 10
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 平滑更新进度(10-90%,AI生成占60%)
|
||||
# 每10个chunk增加约1%的进度,最多到90%
|
||||
if chunk_count % 10 == 0:
|
||||
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
|
||||
current_progress = min(10 + (chunk_count // 10), 90)
|
||||
if current_progress > last_progress:
|
||||
last_progress = current_progress
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
|
||||
current_progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
@@ -326,7 +350,7 @@ async def generate_career_system(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 50)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 91)
|
||||
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
@@ -339,7 +363,7 @@ async def generate_career_system(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("保存主职业...", 60)
|
||||
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
|
||||
|
||||
# 保存主职业
|
||||
main_careers_created = []
|
||||
@@ -371,7 +395,7 @@ async def generate_career_system(
|
||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
yield await SSEResponse.send_progress("保存副职业...", 80)
|
||||
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
|
||||
|
||||
# 保存副职业
|
||||
sub_careers_created = []
|
||||
|
||||
+47
-12
@@ -1070,8 +1070,7 @@ async def analyze_chapter_background(
|
||||
|
||||
if career_update_result['updated_count'] > 0:
|
||||
logger.info(
|
||||
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: "
|
||||
f"{', '.join(career_update_result['updated_characters'])}"
|
||||
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息"
|
||||
)
|
||||
if career_update_result['changes']:
|
||||
for change in career_update_result['changes']:
|
||||
@@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream(
|
||||
user_id=current_user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream(
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 发送开始生成的进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||
|
||||
# 准备生成参数
|
||||
generate_kwargs = {"prompt": prompt}
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
}
|
||||
if custom_model:
|
||||
logger.info(f" 使用自定义模型: {custom_model}")
|
||||
generate_kwargs["model"] = custom_model
|
||||
@@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream(
|
||||
# 发送内容块
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 每20个chunk发送一次进度更新(提高频率)
|
||||
if chunk_count % 20 == 0:
|
||||
# 每5个chunk发送一次进度更新(10-95%,更平滑)
|
||||
if chunk_count % 5 == 0:
|
||||
current_word_count = len(full_content)
|
||||
# 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%)
|
||||
estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55))
|
||||
# 优化进度计算:使用更平滑的递增方式
|
||||
# 基于chunk数量和字数的混合计算,避免大幅跳跃
|
||||
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
|
||||
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
|
||||
estimated_progress = min(95, 10 + chunk_progress + word_progress)
|
||||
|
||||
# 只在进度变化时发送
|
||||
if estimated_progress > last_progress:
|
||||
@@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream(
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
last_progress = estimated_progress
|
||||
|
||||
# 每20个chunk发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 发送保存进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
@@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream(
|
||||
)
|
||||
|
||||
# 发送最终进度100%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送完成事件(包含分析任务ID)
|
||||
completion_data = {
|
||||
@@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch(
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(批量生成)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||
logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||
|
||||
# 非流式生成内容
|
||||
full_content = ""
|
||||
# 准备生成参数
|
||||
generate_kwargs = {"prompt": prompt}
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
}
|
||||
# 如果传入了自定义模型,使用指定的模型
|
||||
if custom_model:
|
||||
generate_kwargs["model"] = custom_model
|
||||
logger.info(f" 批量生成使用自定义模型: {custom_model}")
|
||||
|
||||
# 批量生成中的流式生成(非SSE,不需要修改进度显示)
|
||||
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
|
||||
full_content += chunk
|
||||
|
||||
|
||||
+120
-20
@@ -662,10 +662,10 @@ async def generate_character_stream(
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(request.project_id, user_id, db)
|
||||
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 0)
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 1)
|
||||
|
||||
# 获取已存在的角色列表
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 2)
|
||||
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
@@ -757,7 +757,7 @@ async def generate_character_stream(
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 3)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
||||
@@ -768,11 +768,14 @@ async def generate_character_stream(
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
@@ -789,7 +792,7 @@ async def generate_character_stream(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -797,22 +800,119 @@ async def generate_character_stream(
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
if result.get('tool_calls_made', 0) > 0:
|
||||
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
@@ -823,7 +923,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
@@ -836,7 +936,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 75)
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 97)
|
||||
|
||||
# 转换traits
|
||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
|
||||
|
||||
# 如果是组织,创建Organization详情
|
||||
if is_organization:
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
org_check = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
|
||||
generated_content=ai_response,
|
||||
model=user_ai_service.default_model
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
+204
-28
@@ -105,23 +105,27 @@ async def generate_options(
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板(根据step确定模板key)
|
||||
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
|
||||
template_key_map = {
|
||||
"title": "INSPIRATION_TITLE",
|
||||
"description": "INSPIRATION_DESCRIPTION",
|
||||
"theme": "INSPIRATION_THEME",
|
||||
"genre": "INSPIRATION_GENRE"
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_key = template_key_map.get(step)
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_key:
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
# 获取自定义提示词模板
|
||||
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板(分别获取 system 和 user)
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
@@ -131,19 +135,9 @@ async def generate_options(
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
|
||||
# 尝试解析为JSON格式的字典
|
||||
try:
|
||||
prompt_template = json.loads(prompt_template_str)
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 如果不是JSON格式,降级使用原有方法
|
||||
prompt_template = prompt_service.get_inspiration_prompt(step)
|
||||
if not prompt_template:
|
||||
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 如果是重试,在提示词中强调格式要求
|
||||
if attempt > 0:
|
||||
@@ -153,13 +147,18 @@ async def generate_options(
|
||||
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
|
||||
response = await ai_service.generate_text(
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = response.get("content", "")
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
@@ -222,6 +221,180 @@ async def generate_options(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refine-options")
|
||||
async def refine_options(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
基于用户反馈重新生成选项(支持多轮对话)
|
||||
|
||||
Request:
|
||||
{
|
||||
"step": "title", // 当前步骤
|
||||
"context": {
|
||||
"initial_idea": "...",
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"theme": "..."
|
||||
},
|
||||
"feedback": "我想要更悲剧一些的主题", // 用户反馈
|
||||
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"prompt": "引导语",
|
||||
"options": ["新选项1", "新选项2", ...]
|
||||
}
|
||||
"""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step = data.get("step", "title")
|
||||
context = data.get("context", {})
|
||||
feedback = data.get("feedback", "")
|
||||
previous_options = data.get("previous_options", [])
|
||||
|
||||
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||
logger.info(f"用户反馈: {feedback}")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板
|
||||
template_key_map = {
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||
"title": context.get("title", ""),
|
||||
"description": context.get("description", ""),
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 添加反馈信息到提示词
|
||||
feedback_instruction = f"""
|
||||
|
||||
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
|
||||
「{feedback}」
|
||||
|
||||
之前生成的选项:
|
||||
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
|
||||
|
||||
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
|
||||
注意:
|
||||
1. 仔细理解用户的反馈意图
|
||||
2. 生成的新选项要明显体现用户要求的调整方向
|
||||
3. 保持与已有上下文的一致性
|
||||
4. 确保返回6个有效选项
|
||||
"""
|
||||
|
||||
system_prompt += feedback_instruction
|
||||
|
||||
# 如果是重试,强调格式要求
|
||||
if attempt > 0:
|
||||
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
|
||||
|
||||
# 调用AI生成选项
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
# 反馈生成时使用稍高的temperature以获得更多样化的结果
|
||||
temperature = min(temperature + 0.1, 0.9)
|
||||
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
result = json.loads(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"第{attempt + 1}次JSON解析失败: {e}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("JSON解析失败,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI返回格式错误,已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("发生异常,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"error": str(e),
|
||||
"prompt": "生成失败,请重试",
|
||||
"options": ["重新生成", "我自己输入"]
|
||||
}
|
||||
|
||||
return {
|
||||
"error": "生成失败",
|
||||
"prompt": "请重试",
|
||||
"options": []
|
||||
}
|
||||
|
||||
|
||||
@router.post("/quick-generate")
|
||||
async def quick_generate(
|
||||
data: Dict[str, Any],
|
||||
@@ -280,14 +453,17 @@ async def quick_generate(
|
||||
# 降级使用原有方法
|
||||
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
|
||||
|
||||
# 调用AI
|
||||
response = await ai_service.generate_text(
|
||||
# 调用AI - 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
temperature=0.7
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = response.get("content", "")
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
try:
|
||||
|
||||
@@ -512,8 +512,29 @@ async def generate_organization_stream(
|
||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
||||
|
||||
try:
|
||||
ai_response = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response)
|
||||
# 使用流式生成替代非流式
|
||||
ai_content = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_content += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新字数(5-95%,AI生成占90%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 5), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成组织中... ({len(ai_content)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
@@ -523,7 +544,7 @@ async def generate_organization_stream(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
@@ -536,7 +557,7 @@ async def generate_organization_stream(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织记录...", 75)
|
||||
yield await SSEResponse.send_progress("创建组织记录...", 97)
|
||||
|
||||
# 创建角色记录(组织也是角色的一种)
|
||||
character = Character(
|
||||
@@ -563,7 +584,7 @@ async def generate_organization_stream(
|
||||
|
||||
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
# 自动创建Organization详情记录
|
||||
organization = Organization(
|
||||
@@ -580,7 +601,7 @@ async def generate_organization_stream(
|
||||
|
||||
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
|
||||
+89
-30
@@ -470,7 +470,7 @@ async def _generate_new_outline(
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = None
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲(MCP增强版)"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
||||
@@ -534,7 +534,7 @@ async def _generate_new_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -573,15 +573,23 @@ async def _generate_new_outline(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI生成大纲
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 调用AI流式生成大纲(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
# 如果未来需要转SSE,可以在这里yield
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -732,7 +740,7 @@ async def _continue_outline(
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
|
||||
@@ -1000,7 +1008,7 @@ async def _continue_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1045,15 +1053,22 @@ async def _continue_outline(
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
|
||||
user_id=user_id_for_mcp,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
# 调用AI流式生成
|
||||
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
|
||||
|
||||
# 添加调试日志
|
||||
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
|
||||
logger.info(f"=== 大纲生成AI调用参数 ===")
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
logger.info(f" 完整data: {data}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# ✅ 流式生成(带字数统计和进度)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数(30-95%,AI生成占65%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 2), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成大纲中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 全新生成模式:删除旧大纲和关联的所有章节
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
|
||||
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})")
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
|
||||
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
|
||||
|
||||
# 保存新大纲
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
|
||||
outlines = await _save_outlines(
|
||||
project_id, outline_data, db, start_index=1
|
||||
)
|
||||
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
|
||||
for outline in outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 95)
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 99)
|
||||
|
||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||
|
||||
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度(每批占用约50%的进度空间)
|
||||
if chunk_count % 5 == 0:
|
||||
# 在批次范围内平滑递增(从10到85,总共75%)
|
||||
batch_range = 60 // total_batches # 总进度60%分配给所有批次
|
||||
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
|
||||
progress_in_batch
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
|
||||
batch_progress + 10
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
# 提取内容
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
+25
-11
@@ -73,14 +73,15 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例
|
||||
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||
return create_user_ai_service(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
max_tokens=settings.max_tokens,
|
||||
system_prompt=settings.system_prompt # 传递系统提示词
|
||||
)
|
||||
|
||||
|
||||
@@ -271,17 +272,30 @@ async def get_available_models(
|
||||
}
|
||||
|
||||
elif provider == "anthropic":
|
||||
# Anthropic 没有公开的模型列表API
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
||||
)
|
||||
# Anthropic models API
|
||||
url = f"{api_base_url.rstrip('/')}/v1/models"
|
||||
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
elif provider == "gemini":
|
||||
# Gemini models API
|
||||
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = []
|
||||
for m in data.get("models", []):
|
||||
if "generateContent" in m.get("supportedGenerationMethods", []):
|
||||
mid = m.get("name", "").replace("models/", "")
|
||||
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的提供商: {provider}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||
|
||||
+362
-126
@@ -99,7 +99,7 @@ async def world_building_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -139,51 +139,118 @@ async def world_building_generator(
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
# ===== 流式生成世界观(带重试机制) =====
|
||||
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||
world_retry_count = 0
|
||||
world_generation_success = False
|
||||
world_data = {}
|
||||
try:
|
||||
# ✅ 使用 AIService 的统一清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观JSON解析成功")
|
||||
|
||||
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败: {e}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 世界观生成独立进度:5-95%
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 检查是否返回空响应
|
||||
if not accumulated_text or not accumulated_text.strip():
|
||||
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
logger.error("❌ 世界观生成多次返回空响应")
|
||||
world_data = {
|
||||
"time_period": "AI多次返回为空,请稍后重试",
|
||||
"location": "AI多次返回为空,请稍后重试",
|
||||
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||
"rules": "AI多次返回为空,请稍后重试"
|
||||
}
|
||||
world_generation_success = True # 标记为成功以继续流程
|
||||
break
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析世界观数据...", 96)
|
||||
|
||||
try:
|
||||
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||
logger.info(f" 原始内容预览: {accumulated_text[:300]}...")
|
||||
|
||||
# ✅ 使用 AIService 的统一清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True # 解析成功,标记完成
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
world_generation_success = True # 标记为成功以继续流程
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍失败,抛出异常
|
||||
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||
raise
|
||||
# 保存到数据库
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
yield await SSEResponse.send_progress("保存世界观到数据库...", 99)
|
||||
|
||||
# 确保user_id存在
|
||||
if not user_id:
|
||||
@@ -240,41 +307,81 @@ async def world_building_generator(
|
||||
project.wizard_step = 1
|
||||
await db.commit()
|
||||
|
||||
# ===== 自动生成职业体系 =====
|
||||
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75)
|
||||
# ===== 自动生成职业体系(带重试机制+流式) =====
|
||||
yield await SSEResponse.send_progress("世界观完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5)
|
||||
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
|
||||
|
||||
try:
|
||||
# 获取职业生成提示词模板(支持用户自定义)
|
||||
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
||||
career_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=genre or '未设定',
|
||||
theme=theme or '未设定',
|
||||
time_period=world_data.get('time_period', '未设定'),
|
||||
location=world_data.get('location', '未设定'),
|
||||
atmosphere=world_data.get('atmosphere', '未设定'),
|
||||
rules=world_data.get('rules', '未设定')
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("正在生成职业体系...", 78)
|
||||
|
||||
# 调用AI生成职业
|
||||
result = await user_ai_service.generate_text(prompt=career_prompt)
|
||||
career_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
|
||||
if not career_response or not career_response.strip():
|
||||
logger.warning("⚠️ AI返回空职业体系,跳过职业生成")
|
||||
yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85)
|
||||
else:
|
||||
yield await SSEResponse.send_progress("解析职业体系数据...", 82)
|
||||
MAX_CAREER_RETRIES = 3 # 最多重试3次
|
||||
career_retry_count = 0
|
||||
career_generation_success = False
|
||||
|
||||
while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10)
|
||||
|
||||
# 获取职业生成提示词模板(支持用户自定义)
|
||||
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
||||
career_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=genre or '未设定',
|
||||
theme=theme or '未设定',
|
||||
time_period=world_data.get('time_period', '未设定'),
|
||||
location=world_data.get('location', '未设定'),
|
||||
atmosphere=world_data.get('atmosphere', '未设定'),
|
||||
rules=world_data.get('rules', '未设定')
|
||||
)
|
||||
|
||||
# ✅ 使用流式生成职业体系
|
||||
career_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=career_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
career_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 职业体系生成独立进度:10-95%
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(10 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成职业体系中... ({len(career_response)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
if not career_response or not career_response.strip():
|
||||
logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99)
|
||||
break
|
||||
|
||||
yield await SSEResponse.send_progress("解析职业体系数据...", 96)
|
||||
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(career_response)
|
||||
career_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ 职业体系JSON解析成功")
|
||||
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||
|
||||
# 保存主职业
|
||||
main_careers_created = []
|
||||
@@ -338,22 +445,51 @@ async def world_building_generator(
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 标记成功
|
||||
career_generation_success = True
|
||||
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个")
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})",
|
||||
90
|
||||
99
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85)
|
||||
logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系保存失败: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系生成异常: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85)
|
||||
logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 保存失败,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99)
|
||||
|
||||
db_committed = True
|
||||
|
||||
@@ -366,7 +502,8 @@ async def world_building_generator(
|
||||
"rules": world_data.get("rules")
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("职业体系完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
@@ -473,7 +610,7 @@ async def characters_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
|
||||
max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -611,15 +748,32 @@ async def characters_generator(
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 流式生成
|
||||
# 流式生成(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成角色中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析批次结果 - 使用统一的JSON清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
@@ -1184,18 +1338,35 @@ async def outline_generator(
|
||||
requirements=outline_requirements
|
||||
)
|
||||
|
||||
# 流式生成大纲
|
||||
# 流式生成大纲(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=outline_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数(5-95%,AI生成占90%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成大纲中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析大纲结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析大纲...", 40)
|
||||
yield await SSEResponse.send_progress("解析大纲...", 96)
|
||||
|
||||
try:
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
@@ -1208,7 +1379,7 @@ async def outline_generator(
|
||||
return
|
||||
|
||||
# 保存大纲到数据库
|
||||
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
|
||||
yield await SSEResponse.send_progress("保存大纲到数据库...", 97)
|
||||
created_outlines = []
|
||||
for index, outline_item in enumerate(outline_data[:outline_count], 1):
|
||||
outline = Outline(
|
||||
@@ -1231,7 +1402,7 @@ async def outline_generator(
|
||||
created_chapters = []
|
||||
if project.outline_mode == 'one-to-one':
|
||||
# 一对一模式:自动为每个大纲创建对应的章节
|
||||
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
|
||||
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98)
|
||||
|
||||
for outline in created_outlines:
|
||||
chapter = Chapter(
|
||||
@@ -1250,10 +1421,10 @@ async def outline_generator(
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
|
||||
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
|
||||
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99)
|
||||
else:
|
||||
# 一对多模式:跳过自动创建,用户可手动展开
|
||||
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
|
||||
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99)
|
||||
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
|
||||
|
||||
# 更新项目信息
|
||||
@@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator(
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
# ===== 流式生成世界观(带重试机制) =====
|
||||
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||
world_retry_count = 0
|
||||
world_generation_success = False
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观重新生成JSON解析成功")
|
||||
|
||||
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"世界构建JSON解析失败: {e}")
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 85)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 检查是否返回空响应
|
||||
if not accumulated_text or not accumulated_text.strip():
|
||||
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
logger.error("❌ 世界观重新生成多次返回空响应")
|
||||
world_data = {
|
||||
"time_period": "AI多次返回为空,请稍后重试",
|
||||
"location": "AI多次返回为空,请稍后重试",
|
||||
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||
"rules": "AI多次返回为空,请稍后重试"
|
||||
}
|
||||
world_generation_success = True
|
||||
break
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
try:
|
||||
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
world_generation_success = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍失败,抛出异常
|
||||
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||
raise
|
||||
|
||||
# 不保存到数据库,仅返回生成结果供用户预览
|
||||
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
|
||||
|
||||
@@ -15,7 +15,6 @@ from ..schemas.writing_style import (
|
||||
WritingStyleListResponse,
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
from ..logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
@@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str:
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
async def get_preset_styles():
|
||||
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
获取所有预设风格列表
|
||||
获取所有预设风格列表(从数据库读取)
|
||||
|
||||
返回格式:数组形式的预设风格列表
|
||||
[
|
||||
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": "classical", "name": "古典优雅", ...}
|
||||
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
|
||||
]
|
||||
"""
|
||||
presets = WritingStyleManager.get_all_presets()
|
||||
# 将字典转换为数组,添加 id 字段
|
||||
# 从数据库获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = result.scalars().all()
|
||||
|
||||
# 转换为响应格式
|
||||
return [
|
||||
{"id": preset_id, **preset_data}
|
||||
for preset_id, preset_data in presets.items()
|
||||
{
|
||||
"id": style.id,
|
||||
"preset_id": style.preset_id,
|
||||
"name": style.name,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"style_type": style.style_type,
|
||||
"order_index": style.order_index
|
||||
}
|
||||
for style in preset_styles
|
||||
]
|
||||
|
||||
|
||||
@@ -58,25 +72,33 @@ async def create_writing_style(
|
||||
"""
|
||||
创建新的写作风格(用户级别)
|
||||
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
# 如果基于预设创建,从数据库获取预设内容
|
||||
if style_data.preset_id:
|
||||
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(
|
||||
WritingStyle.user_id.is_(None),
|
||||
WritingStyle.preset_id == style_data.preset_id
|
||||
)
|
||||
)
|
||||
preset = result.scalar_one_or_none()
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||||
|
||||
# 使用预设内容填充(如果用户未提供)
|
||||
if not style_data.name:
|
||||
style_data.name = preset["name"]
|
||||
style_data.name = preset.name
|
||||
if not style_data.description:
|
||||
style_data.description = preset["description"]
|
||||
style_data.description = preset.description
|
||||
if not style_data.prompt_content:
|
||||
style_data.prompt_content = preset["prompt_content"]
|
||||
style_data.prompt_content = preset.prompt_content
|
||||
|
||||
# 验证必填字段
|
||||
if not style_data.name or not style_data.prompt_content:
|
||||
|
||||
Reference in New Issue
Block a user