update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
+89
-30
@@ -470,7 +470,7 @@ async def _generate_new_outline(
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = None
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲(MCP增强版)"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
||||
@@ -534,7 +534,7 @@ async def _generate_new_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -573,15 +573,23 @@ async def _generate_new_outline(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI生成大纲
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 调用AI流式生成大纲(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
# 如果未来需要转SSE,可以在这里yield
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -732,7 +740,7 @@ async def _continue_outline(
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
|
||||
@@ -1000,7 +1008,7 @@ async def _continue_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1045,15 +1053,22 @@ async def _continue_outline(
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
|
||||
user_id=user_id_for_mcp,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
# 调用AI流式生成
|
||||
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
|
||||
|
||||
# 添加调试日志
|
||||
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
|
||||
logger.info(f"=== 大纲生成AI调用参数 ===")
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
logger.info(f" 完整data: {data}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# ✅ 流式生成(带字数统计和进度)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数(30-95%,AI生成占65%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 2), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成大纲中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 全新生成模式:删除旧大纲和关联的所有章节
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
|
||||
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})")
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
|
||||
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
|
||||
|
||||
# 保存新大纲
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
|
||||
outlines = await _save_outlines(
|
||||
project_id, outline_data, db, start_index=1
|
||||
)
|
||||
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
|
||||
for outline in outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 95)
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 99)
|
||||
|
||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||
|
||||
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度(每批占用约50%的进度空间)
|
||||
if chunk_count % 5 == 0:
|
||||
# 在批次范围内平滑递增(从10到85,总共75%)
|
||||
batch_range = 60 // total_batches # 总进度60%分配给所有批次
|
||||
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
|
||||
progress_in_batch
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
|
||||
batch_progress + 10
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
# 提取内容
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
Reference in New Issue
Block a user