update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词

This commit is contained in:
xiamuceer
2025-12-28 19:35:23 +08:00
parent f32e51b594
commit 89848e2258
40 changed files with 2752 additions and 1824 deletions
+89 -30
View File
@@ -470,7 +470,7 @@ async def _generate_new_outline(
project: Project,
db: AsyncSession,
user_ai_service: AIService,
user_id: str = None
user_id: str
) -> OutlineListResponse:
"""全新生成大纲(MCP增强版)"""
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
@@ -534,7 +534,7 @@ async def _generate_new_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -573,15 +573,23 @@ async def _generate_new_outline(
mcp_references=mcp_reference_materials
)
# 调用AI生成大纲
ai_response = await user_ai_service.generate_text(
# 调用AI流式生成大纲(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 如果未来需要转SSE,可以在这里yield
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -732,7 +740,7 @@ async def _continue_outline(
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService,
user_id: str = "system"
user_id: str
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
@@ -1000,7 +1008,7 @@ async def _continue_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1045,15 +1053,22 @@ async def _continue_outline(
)
# 调用AI生成当前批次
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
ai_response = await user_ai_service.generate_text(
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
user_id=user_id_for_mcp,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
mcp_references=mcp_reference_materials
)
# 调用AI
# 调用AI流式生成
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
# 添加调试日志
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
logger.info(f"=== 大纲生成AI调用参数 ===")
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
logger.info(f" 完整data: {data}")
ai_response = await user_ai_service.generate_text(
# ✅ 流式生成(带字数统计和进度)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数(30-95%,AI生成占65%
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 2), 95)
yield await SSEResponse.send_progress(
f"AI生成大纲中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
# 全新生成模式:删除旧大纲和关联的所有章节
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode}")
from sqlalchemy import delete as sql_delete
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
# 保存新大纲
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
outlines = await _save_outlines(
project_id, outline_data, db, start_index=1
)
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
for outline in outlines:
await db.refresh(outline)
yield await SSEResponse.send_progress("整理结果数据...", 95)
yield await SSEResponse.send_progress("整理结果数据...", 99)
logger.info(f"全新生成完成 - {len(outlines)}")
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
ai_response = await user_ai_service.generate_text(
# 流式生成并累积文本
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度(每批占用约50%的进度空间)
if chunk_count % 5 == 0:
# 在批次范围内平滑递增(从10到85,总共75%)
batch_range = 60 // total_batches # 总进度60%分配给所有批次
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
yield await SSEResponse.send_progress(
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
progress_in_batch
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress(
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
batch_progress + 10
)
# 提取内容generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
# 提取内容
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)