update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词

This commit is contained in:
xiamuceer
2025-12-28 19:35:23 +08:00
parent f32e51b594
commit 89848e2258
40 changed files with 2752 additions and 1824 deletions
@@ -0,0 +1,30 @@
"""添加system_prompt字段到settings表
Revision ID: a7e4408e1d5b
Revises: e411428f00c0
Create Date: 2025-12-27 15:41:22.310160
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'a7e4408e1d5b'
down_revision: Union[str, None] = 'e411428f00c0'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('settings', 'system_prompt')
# ### end Alembic commands ###
@@ -0,0 +1,177 @@
"""初始化SQLite预置数据
Revision ID: a1b2c3d4e5f6
Revises: fbeb1038c728
Create Date: 2025-12-27 08:56:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Text
# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, None] = 'fbeb1038c728'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""插入预置数据"""
# ==================== 1. 插入关系类型数据 ====================
relationship_types_table = table(
'relationship_types',
column('name', String),
column('category', String),
column('reverse_name', String),
column('intimacy_range', String),
column('icon', String),
column('description', Text),
)
relationship_types_data = [
# 家庭关系
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
# 社交关系
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
# 职业关系
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
# 敌对关系
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "", "description": "宿命之敌"},
]
op.bulk_insert(relationship_types_table, relationship_types_data)
print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据")
# ==================== 2. 插入全局写作风格预设 ====================
writing_styles_table = table(
'writing_styles',
column('user_id', String),
column('name', String),
column('style_type', String),
column('preset_id', String),
column('description', Text),
column('prompt_content', Text),
column('order_index', Integer),
)
writing_styles_data = [
{
"user_id": None, # NULL 表示全局预设
"name": "自然流畅",
"style_type": "preset",
"preset_id": "natural",
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
"prompt_content": """写作风格要求:
1. 语言简洁明快,贴近现代口语
2. 多用短句,节奏流畅
3. 注重情感细节的自然流露
4. 避免过度修饰和复杂句式""",
"order_index": 1
},
{
"user_id": None,
"name": "古典优雅",
"style_type": "preset",
"preset_id": "classical",
"description": "古典文雅的写作风格,适合古装、仙侠题材",
"prompt_content": """写作风格要求:
1. 使用文言、半文言或典雅的白话
2. 适当运用古典诗词意象
3. 注重意境营造和韵味
4. 对话和描写保持古典美感""",
"order_index": 2
},
{
"user_id": None,
"name": "现代简约",
"style_type": "preset",
"preset_id": "modern",
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
"prompt_content": """写作风格要求:
1. 语言直白简练,信息密度高
2. 多用对话推进情节
3. 避免冗长描写,突出关键动作
4. 节奏明快,适合快速阅读""",
"order_index": 3
},
{
"user_id": None,
"name": "文艺细腻",
"style_type": "preset",
"preset_id": "literary",
"description": "文艺细腻风格,注重心理描写和氛围营造",
"prompt_content": """写作风格要求:
1. 注重心理活动和情感细节
2. 善用环境描写烘托氛围
3. 语言优美,富有文学性
4. 适当使用比喻、象征等修辞手法""",
"order_index": 4
},
{
"user_id": None,
"name": "紧张悬疑",
"style_type": "preset",
"preset_id": "suspense",
"description": "紧张悬疑风格,适合推理、惊悚题材",
"prompt_content": """写作风格要求:
1. 营造紧张压迫的氛围
2. 多用短句加快节奏
3. 善于设置悬念和伏笔
4. 注重细节描写,为推理埋下线索""",
"order_index": 5
},
{
"user_id": None,
"name": "幽默诙谐",
"style_type": "preset",
"preset_id": "humorous",
"description": "幽默诙谐风格,适合轻松搞笑题材",
"prompt_content": """写作风格要求:
1. 语言活泼风趣,善用俏皮话
2. 注重对话的喜剧效果
3. 适当夸张和反转制造笑点
4. 保持轻松愉快的基调""",
"order_index": 6
},
]
op.bulk_insert(writing_styles_table, writing_styles_data)
print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设")
def downgrade() -> None:
"""删除预置数据"""
# 删除写作风格预设(只删除全局预设)
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
print("✅ SQLite: 已删除全局写作风格预设")
# 删除关系类型
op.execute("DELETE FROM relationship_types")
print("✅ SQLite: 已删除关系类型数据")
@@ -0,0 +1,34 @@
"""添加system_prompt字段到settings表
Revision ID: 7899f8d4d839
Revises: a1b2c3d4e5f6
Create Date: 2025-12-27 17:00:35.440551
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '7899f8d4d839'
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('settings', schema=None) as batch_op:
batch_op.drop_column('system_prompt')
# ### end Alembic commands ###
+31 -7
View File
@@ -309,13 +309,37 @@ async def generate_career_system(
7. 只返回纯JSON,不要添加任何解释文字
"""
yield await SSEResponse.send_progress("调用AI生成新职业...", 30)
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
try:
# 调用AI生成
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
# 使用流式生成替代非流式
ai_response = ""
chunk_count = 0
last_progress = 10
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 平滑更新进度(10-90%,AI生成占60%
# 每10个chunk增加约1%的进度,最多到90%
if chunk_count % 10 == 0:
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
current_progress = min(10 + (chunk_count // 10), 90)
if current_progress > last_progress:
last_progress = current_progress
yield await SSEResponse.send_progress(
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
current_progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
@@ -326,7 +350,7 @@ async def generate_career_system(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 50)
yield await SSEResponse.send_progress("解析AI响应...", 91)
# 清洗并解析JSON
try:
@@ -339,7 +363,7 @@ async def generate_career_system(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("保存主职业...", 60)
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
# 保存主职业
main_careers_created = []
@@ -371,7 +395,7 @@ async def generate_career_system(
logger.error(f" ❌ 创建主职业失败:{str(e)}")
continue
yield await SSEResponse.send_progress("保存副职业...", 80)
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
# 保存副职业
sub_careers_created = []
+47 -12
View File
@@ -1070,8 +1070,7 @@ async def analyze_chapter_background(
if career_update_result['updated_count'] > 0:
logger.info(
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: "
f"{', '.join(career_update_result['updated_characters'])}"
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息"
)
if career_update_result['changes']:
for change in career_update_result['changes']:
@@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream(
user_id=current_user_id,
db_session=db_session,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream(
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 发送开始生成的进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
system_prompt_with_style = None
if style_content:
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
{style_content}
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
确保在整个章节创作过程中始终保持风格的一致性。"""
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
# 准备生成参数
generate_kwargs = {"prompt": prompt}
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
}
if custom_model:
logger.info(f" 使用自定义模型: {custom_model}")
generate_kwargs["model"] = custom_model
@@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream(
# 发送内容块
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
# 每20个chunk发送一次进度更新(提高频率
if chunk_count % 20 == 0:
# 每5个chunk发送一次进度更新(10-95%,更平滑
if chunk_count % 5 == 0:
current_word_count = len(full_content)
# 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%)
estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55))
# 优化进度计算:使用更平滑的递增方式
# 基于chunk数量和字数的混合计算,避免大幅跳跃
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
estimated_progress = min(95, 10 + chunk_progress + word_progress)
# 只在进度变化时发送
if estimated_progress > last_progress:
@@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream(
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
last_progress = estimated_progress
# 每20个chunk发送心跳
if chunk_count % 20 == 0:
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
await asyncio.sleep(0) # 让出控制权
# 发送保存进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
@@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream(
)
# 发送最终进度100%
yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
# 发送完成事件(包含分析任务ID
completion_data = {
@@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch(
else:
prompt = base_prompt
# 🎨 方案一:将写作风格注入到系统提示词(批量生成)
system_prompt_with_style = None
if style_content:
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
{style_content}
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
确保在整个章节创作过程中始终保持风格的一致性。"""
logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)")
# 非流式生成内容
full_content = ""
# 准备生成参数
generate_kwargs = {"prompt": prompt}
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
}
# 如果传入了自定义模型,使用指定的模型
if custom_model:
generate_kwargs["model"] = custom_model
logger.info(f" 批量生成使用自定义模型: {custom_model}")
# 批量生成中的流式生成(非SSE,不需要修改进度显示)
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
full_content += chunk
+120 -20
View File
@@ -662,10 +662,10 @@ async def generate_character_stream(
user_id = getattr(http_request.state, 'user_id', None)
project = await verify_project_access(request.project_id, user_id, db)
yield await SSEResponse.send_progress("开始生成角色...", 0)
yield await SSEResponse.send_progress("开始生成角色...", 1)
# 获取已存在的角色列表
yield await SSEResponse.send_progress("获取项目上下文...", 10)
yield await SSEResponse.send_progress("获取项目上下文...", 2)
existing_chars_result = await db.execute(
select(Character)
@@ -757,7 +757,7 @@ async def generate_character_stream(
- 其他要求:{request.requirements or ''}
"""
yield await SSEResponse.send_progress("构建AI提示词...", 20)
yield await SSEResponse.send_progress("构建AI提示词...", 3)
# 获取自定义提示词模板
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
@@ -768,11 +768,14 @@ async def generate_character_stream(
user_input=user_input
)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
try:
# 🔧 MCP工具增强:静默检查并收集参考资料
ai_response = ""
chunk_count = 0
if user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
@@ -789,7 +792,7 @@ async def generate_character_stream(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -797,22 +800,119 @@ async def generate_character_stream(
if isinstance(result, dict):
ai_response = result.get('content', '')
if result.get('tool_calls_made', 0) > 0:
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
finish_reason = result.get('finish_reason', '')
tool_calls_made = result.get('tool_calls_made', 0)
# 🔧 修复:检查工具调用是否真正成功
if tool_calls_made > 0:
if finish_reason == 'tool_error':
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
# 工具调用失败,重新用基础模式生成
ai_response = ""
elif not ai_response.strip():
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
# 工具调用成功但返回空内容,重新生成
ai_response = ""
else:
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
# MCP成功且有内容,模拟流式输出(分块发送)
chunk_size = 50
for i in range(0, len(ai_response), chunk_size):
chunk = ai_response[i:i+chunk_size]
chunk_count += 1
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 3 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
)
# 跳过后续的流式生成
ai_response = result.get('content', '')
else:
ai_response = result
# 如果MCP调用失败或返回空,继续走流式生成
if not ai_response or not ai_response.strip():
logger.info(f"🔄 开始流式生成...")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as mcp_error:
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
result = await user_ai_service.generate_text(prompt=prompt)
ai_response = result.get('content', '') if isinstance(result, dict) else result
logger.debug(f"未登录用户,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
@@ -823,7 +923,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 60)
yield await SSEResponse.send_progress("解析AI响应...", 96)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -836,7 +936,7 @@ async def generate_character_stream(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建角色记录...", 75)
yield await SSEResponse.send_progress("创建角色记录...", 97)
# 转换traits
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
# 如果是组织,创建Organization详情
if is_organization:
yield await SSEResponse.send_progress("创建组织详情...", 85)
yield await SSEResponse.send_progress("创建组织详情...", 98)
org_check = await db.execute(
select(Organization).where(Organization.character_id == character.id)
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
yield await SSEResponse.send_progress("保存生成历史...", 95)
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
project_id=request.project_id,
prompt=prompt,
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
generated_content=ai_response,
model=user_ai_service.default_model
)
db.add(history)
+204 -28
View File
@@ -105,23 +105,27 @@ async def generate_options(
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板(根据step确定模板key)
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
template_key_map = {
"title": "INSPIRATION_TITLE",
"description": "INSPIRATION_DESCRIPTION",
"theme": "INSPIRATION_THEME",
"genre": "INSPIRATION_GENRE"
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_key = template_key_map.get(step)
template_keys = template_key_map.get(step)
if not template_key:
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
# 获取自定义提示词模板
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
system_key, user_key = template_keys
# 获取自定义提示词模板(分别获取 system 和 user
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
@@ -131,19 +135,9 @@ async def generate_options(
"theme": context.get("theme", "")
}
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
# 尝试解析为JSON格式的字典
try:
prompt_template = json.loads(prompt_template_str)
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
except (json.JSONDecodeError, KeyError):
# 如果不是JSON格式,降级使用原有方法
prompt_template = prompt_service.get_inspiration_prompt(step)
if not prompt_template:
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
system_prompt = prompt_template["system"].format(**format_params)
user_prompt = prompt_template["user"].format(**format_params)
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 如果是重试,在提示词中强调格式要求
if attempt > 0:
@@ -153,13 +147,18 @@ async def generate_options(
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
response = await ai_service.generate_text(
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
)
):
accumulated_text += chunk
content = response.get("content", "")
response = {"content": accumulated_text}
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON(使用统一的JSON清洗方法)
@@ -222,6 +221,180 @@ async def generate_options(
}
@router.post("/refine-options")
async def refine_options(
data: Dict[str, Any],
http_request: Request,
db: AsyncSession = Depends(get_db),
ai_service: AIService = Depends(get_user_ai_service)
) -> Dict[str, Any]:
"""
基于用户反馈重新生成选项(支持多轮对话)
Request:
{
"step": "title", // 当前步骤
"context": {
"initial_idea": "...",
"title": "...",
"description": "...",
"theme": "..."
},
"feedback": "我想要更悲剧一些的主题", // 用户反馈
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
}
Response:
{
"prompt": "引导语",
"options": ["新选项1", "新选项2", ...]
}
"""
max_retries = 3
for attempt in range(max_retries):
try:
step = data.get("step", "title")
context = data.get("context", {})
feedback = data.get("feedback", "")
previous_options = data.get("previous_options", [])
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
logger.info(f"用户反馈: {feedback}")
# 获取用户ID
user_id = getattr(http_request.state, 'user_id', None)
# 获取对应的提示词模板
template_key_map = {
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
}
template_keys = template_key_map.get(step)
if not template_keys:
return {
"error": f"不支持的步骤: {step}",
"prompt": "",
"options": []
}
system_key, user_key = template_keys
# 获取自定义提示词模板
system_template = await PromptService.get_template(system_key, user_id, db)
user_template = await PromptService.get_template(user_key, user_id, db)
# 准备格式化参数
format_params = {
"initial_idea": context.get("initial_idea", context.get("description", "")),
"title": context.get("title", ""),
"description": context.get("description", ""),
"theme": context.get("theme", "")
}
# 格式化提示词
system_prompt = system_template.format(**format_params)
user_prompt = user_template.format(**format_params)
# 添加反馈信息到提示词
feedback_instruction = f"""
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
{feedback}
之前生成的选项:
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
注意:
1. 仔细理解用户的反馈意图
2. 生成的新选项要明显体现用户要求的调整方向
3. 保持与已有上下文的一致性
4. 确保返回6个有效选项
"""
system_prompt += feedback_instruction
# 如果是重试,强调格式要求
if attempt > 0:
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
# 调用AI生成选项
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
# 反馈生成时使用稍高的temperature以获得更多样化的结果
temperature = min(temperature + 0.1, 0.9)
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
# 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=user_prompt,
system_prompt=system_prompt,
temperature=temperature
):
accumulated_text += chunk
content = accumulated_text
logger.info(f"AI返回内容长度: {len(content)}")
# 解析JSON
try:
cleaned_content = ai_service._clean_json_response(content)
result = json.loads(cleaned_content)
# 校验返回格式
is_valid, error_msg = validate_options_response(result, step)
if not is_valid:
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
if attempt < max_retries - 1:
logger.info("准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}"
}
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
return result
except json.JSONDecodeError as e:
logger.error(f"{attempt + 1}次JSON解析失败: {e}")
if attempt < max_retries - 1:
logger.info("JSON解析失败,准备重试...")
continue
else:
return {
"prompt": f"请为【{step}】提供内容:",
"options": ["让AI重新生成", "我自己输入"],
"error": f"AI返回格式错误,已自动重试{max_retries}"
}
except Exception as e:
logger.error(f"{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
if attempt < max_retries - 1:
logger.info("发生异常,准备重试...")
continue
else:
return {
"error": str(e),
"prompt": "生成失败,请重试",
"options": ["重新生成", "我自己输入"]
}
return {
"error": "生成失败",
"prompt": "请重试",
"options": []
}
@router.post("/quick-generate")
async def quick_generate(
data: Dict[str, Any],
@@ -280,14 +453,17 @@ async def quick_generate(
# 降级使用原有方法
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
# 调用AI
response = await ai_service.generate_text(
# 调用AI - 流式生成并累积文本
accumulated_text = ""
async for chunk in ai_service.generate_text_stream(
prompt=prompts["user"],
system_prompt=prompts["system"],
temperature=0.7
)
):
accumulated_text += chunk
content = response.get("content", "")
response = {"content": accumulated_text}
content = accumulated_text
# 解析JSON(使用统一的JSON清洗方法)
try:
+27 -6
View File
@@ -512,8 +512,29 @@ async def generate_organization_stream(
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
try:
ai_response = await user_ai_service.generate_text(prompt=prompt)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response)
# 使用流式生成替代非流式
ai_content = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_content += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新字数(5-95%,AI生成占90%)
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 5), 95)
yield await SSEResponse.send_progress(
f"AI生成组织中... ({len(ai_content)}字符)",
progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
@@ -523,7 +544,7 @@ async def generate_organization_stream(
yield await SSEResponse.send_error("AI服务返回空响应")
return
yield await SSEResponse.send_progress("解析AI响应...", 60)
yield await SSEResponse.send_progress("解析AI响应...", 96)
# ✅ 使用统一的 JSON 清洗方法
try:
@@ -536,7 +557,7 @@ async def generate_organization_stream(
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON{str(e)}")
return
yield await SSEResponse.send_progress("创建组织记录...", 75)
yield await SSEResponse.send_progress("创建组织记录...", 97)
# 创建角色记录(组织也是角色的一种)
character = Character(
@@ -563,7 +584,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
yield await SSEResponse.send_progress("创建组织详情...", 85)
yield await SSEResponse.send_progress("创建组织详情...", 98)
# 自动创建Organization详情记录
organization = Organization(
@@ -580,7 +601,7 @@ async def generate_organization_stream(
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
yield await SSEResponse.send_progress("保存生成历史...", 95)
yield await SSEResponse.send_progress("保存生成历史...", 99)
# 记录生成历史
history = GenerationHistory(
+89 -30
View File
@@ -470,7 +470,7 @@ async def _generate_new_outline(
project: Project,
db: AsyncSession,
user_ai_service: AIService,
user_id: str = None
user_id: str
) -> OutlineListResponse:
"""全新生成大纲(MCP增强版)"""
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
@@ -534,7 +534,7 @@ async def _generate_new_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -573,15 +573,23 @@ async def _generate_new_outline(
mcp_references=mcp_reference_materials
)
# 调用AI生成大纲
ai_response = await user_ai_service.generate_text(
# 调用AI流式生成大纲(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 如果未来需要转SSE,可以在这里yield
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -732,7 +740,7 @@ async def _continue_outline(
existing_outlines: List[Outline],
db: AsyncSession,
user_ai_service: AIService,
user_id: str = "system"
user_id: str
) -> OutlineListResponse:
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
@@ -1000,7 +1008,7 @@ async def _continue_outline(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1045,15 +1053,22 @@ async def _continue_outline(
)
# 调用AI生成当前批次
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
ai_response = await user_ai_service.generate_text(
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=request.provider,
model=request.model
)
):
chunk_count += 1
accumulated_text += chunk
# 这里是非SSE接口,不需要发送chunk
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
user_id=user_id_for_mcp,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
mcp_references=mcp_reference_materials
)
# 调用AI
# 调用AI流式生成
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
# 添加调试日志
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
logger.info(f"=== 大纲生成AI调用参数 ===")
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
logger.info(f" 完整data: {data}")
ai_response = await user_ai_service.generate_text(
# ✅ 流式生成(带字数统计和进度)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数(30-95%,AI生成占65%
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 2), 95)
yield await SSEResponse.send_progress(
f"AI生成大纲中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
# 提取内容(generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
# 全新生成模式:删除旧大纲和关联的所有章节
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode}")
from sqlalchemy import delete as sql_delete
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
# 保存新大纲
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
outlines = await _save_outlines(
project_id, outline_data, db, start_index=1
)
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
for outline in outlines:
await db.refresh(outline)
yield await SSEResponse.send_progress("整理结果数据...", 95)
yield await SSEResponse.send_progress("整理结果数据...", 99)
logger.info(f"全新生成完成 - {len(outlines)}")
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
logger.info(f" provider参数: {provider_param}")
logger.info(f" model参数: {model_param}")
ai_response = await user_ai_service.generate_text(
# 流式生成并累积文本
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider_param,
model=model_param
)
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度(每批占用约50%的进度空间)
if chunk_count % 5 == 0:
# 在批次范围内平滑递增(从10到85,总共75%)
batch_range = 60 // total_batches # 总进度60%分配给所有批次
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
yield await SSEResponse.send_progress(
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
progress_in_batch
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
yield await SSEResponse.send_progress(
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
batch_progress + 10
)
# 提取内容generate_text返回字典)
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
# 提取内容
ai_content = accumulated_text
ai_response = {"content": ai_content}
# 解析响应
outline_data = _parse_ai_response(ai_content)
+25 -11
View File
@@ -73,14 +73,15 @@ async def get_user_ai_service(
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例
# 使用用户设置创建AI服务实例(包括系统提示词)
return create_user_ai_service(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens
max_tokens=settings.max_tokens,
system_prompt=settings.system_prompt # 传递系统提示词
)
@@ -271,17 +272,30 @@ async def get_available_models(
}
elif provider == "anthropic":
# Anthropic 没有公开的模型列表API
raise HTTPException(
status_code=400,
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
)
# Anthropic models API
url = f"{api_base_url.rstrip('/')}/v1/models"
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
return {"provider": provider, "models": models, "count": len(models)}
elif provider == "gemini":
# Gemini models API
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
response = await client.get(url)
response.raise_for_status()
data = response.json()
models = []
for m in data.get("models", []):
if "generateContent" in m.get("supportedGenerationMethods", []):
mid = m.get("name", "").replace("models/", "")
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
return {"provider": provider, "models": models, "count": len(models)}
else:
raise HTTPException(
status_code=400,
detail=f"不支持的提供商: {provider}"
)
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
except httpx.HTTPStatusError as e:
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
+362 -126
View File
@@ -99,7 +99,7 @@ async def world_building_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -139,51 +139,118 @@ async def world_building_generator(
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
# ===== 流式生成世界观(带重试机制) =====
MAX_WORLD_RETRIES = 3 # 最多重试3次
world_retry_count = 0
world_generation_success = False
world_data = {}
try:
# ✅ 使用 AIService 的统一清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
world_data = json.loads(cleaned_text)
logger.info(f"世界观JSON解析成功")
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
try:
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败: {e}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 世界观生成独立进度:5-95%
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 检查是否返回空响应
if not accumulated_text or not accumulated_text.strip():
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
logger.error("❌ 世界观生成多次返回空响应")
world_data = {
"time_period": "AI多次返回为空,请稍后重试",
"location": "AI多次返回为空,请稍后重试",
"atmosphere": "AI多次返回为空,请稍后重试",
"rules": "AI多次返回为空,请稍后重试"
}
world_generation_success = True # 标记为成功以继续流程
break
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析世界观数据...", 96)
try:
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
logger.info(f" 原始内容预览: {accumulated_text[:300]}...")
# ✅ 使用 AIService 的统一清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_generation_success = True # 解析成功,标记完成
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {e}")
logger.error(f" 原始内容长度: {len(accumulated_text)}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
world_generation_success = True # 标记为成功以继续流程
except Exception as e:
logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {type(e).__name__}: {e}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 最后一次重试仍失败,抛出异常
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
raise
# 保存到数据库
yield await SSEResponse.send_progress("保存到数据库...", 90)
yield await SSEResponse.send_progress("保存世界观到数据库...", 99)
# 确保user_id存在
if not user_id:
@@ -240,41 +307,81 @@ async def world_building_generator(
project.wizard_step = 1
await db.commit()
# ===== 自动生成职业体系 =====
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75)
# ===== 自动生成职业体系(带重试机制+流式) =====
yield await SSEResponse.send_progress("世界观完成!", 100, "success")
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5)
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
try:
# 获取职业生成提示词模板(支持用户自定义)
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
career_prompt = PromptService.format_prompt(
template,
title=project.title,
genre=genre or '未设定',
theme=theme or '未设定',
time_period=world_data.get('time_period', '未设定'),
location=world_data.get('location', '未设定'),
atmosphere=world_data.get('atmosphere', '未设定'),
rules=world_data.get('rules', '未设定')
)
yield await SSEResponse.send_progress("正在生成职业体系...", 78)
# 调用AI生成职业
result = await user_ai_service.generate_text(prompt=career_prompt)
career_response = result.get('content', '') if isinstance(result, dict) else result
if not career_response or not career_response.strip():
logger.warning("⚠️ AI返回空职业体系,跳过职业生成")
yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85)
else:
yield await SSEResponse.send_progress("解析职业体系数据...", 82)
MAX_CAREER_RETRIES = 3 # 最多重试3次
career_retry_count = 0
career_generation_success = False
while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success:
try:
retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10)
# 获取职业生成提示词模板(支持用户自定义)
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
career_prompt = PromptService.format_prompt(
template,
title=project.title,
genre=genre or '未设定',
theme=theme or '未设定',
time_period=world_data.get('time_period', '未设定'),
location=world_data.get('location', '未设定'),
atmosphere=world_data.get('atmosphere', '未设定'),
rules=world_data.get('rules', '未设定')
)
# ✅ 使用流式生成职业体系
career_response = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=career_prompt,
provider=provider,
model=model
):
chunk_count += 1
career_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 职业体系生成独立进度:10-95%
if chunk_count % 5 == 0:
progress = min(10 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(
f"生成职业体系中... ({len(career_response)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
if not career_response or not career_response.strip():
logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99)
break
yield await SSEResponse.send_progress("解析职业体系数据...", 96)
# 清洗并解析JSON
try:
cleaned_response = user_ai_service._clean_json_response(career_response)
career_data = json.loads(cleaned_response)
logger.info(f"✅ 职业体系JSON解析成功")
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}")
# 保存主职业
main_careers_created = []
@@ -338,22 +445,51 @@ async def world_building_generator(
await db.commit()
# 标记成功
career_generation_success = True
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}")
yield await SSEResponse.send_progress(
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)}",
90
99
)
except json.JSONDecodeError as e:
logger.error(f"❌ 职业体系JSON解析失败: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85)
logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99)
except Exception as e:
logger.error(f"❌ 职业体系保存失败: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85)
except Exception as e:
logger.error(f"❌ 职业体系生成异常: {e}")
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85)
logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 保存失败,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99)
except Exception as e:
logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}: {e}")
career_retry_count += 1
if career_retry_count < MAX_CAREER_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
10,
"warning"
)
continue
else:
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99)
db_committed = True
@@ -366,7 +502,8 @@ async def world_building_generator(
"rules": world_data.get("rules")
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_progress("职业体系完成", 100, "success")
yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
@@ -473,7 +610,7 @@ async def characters_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮
tool_choice="auto",
provider=None,
model=None
@@ -611,15 +748,32 @@ async def characters_generator(
else:
prompt = base_prompt
# 流式生成
# 流式生成(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数
if chunk_count % 5 == 0:
progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15)
yield await SSEResponse.send_progress(
f"生成角色中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析批次结果 - 使用统一的JSON清洗方法
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
@@ -1184,18 +1338,35 @@ async def outline_generator(
requirements=outline_requirements
)
# 流式生成大纲
# 流式生成大纲(带字数统计)
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=outline_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度和字数(5-95%,AI生成占90%)
if chunk_count % 5 == 0:
progress = min(5 + (chunk_count // 3), 95)
yield await SSEResponse.send_progress(
f"生成大纲中... ({len(accumulated_text)}字符)",
progress
)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析大纲结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析大纲...", 40)
yield await SSEResponse.send_progress("解析大纲...", 96)
try:
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
@@ -1208,7 +1379,7 @@ async def outline_generator(
return
# 保存大纲到数据库
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
yield await SSEResponse.send_progress("保存大纲到数据库...", 97)
created_outlines = []
for index, outline_item in enumerate(outline_data[:outline_count], 1):
outline = Outline(
@@ -1231,7 +1402,7 @@ async def outline_generator(
created_chapters = []
if project.outline_mode == 'one-to-one':
# 一对一模式:自动为每个大纲创建对应的章节
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98)
for outline in created_outlines:
chapter = Chapter(
@@ -1250,10 +1421,10 @@ async def outline_generator(
await db.refresh(chapter)
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99)
else:
# 一对多模式:跳过自动创建,用户可手动展开
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99)
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
# 更新项目信息
@@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator(
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
@@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator(
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
# ===== 流式生成世界观(带重试机制) =====
MAX_WORLD_RETRIES = 3 # 最多重试3次
world_retry_count = 0
world_generation_success = False
world_data = {}
try:
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观重新生成JSON解析成功")
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
try:
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
except json.JSONDecodeError as e:
logger.error(f"世界构建JSON解析失败: {e}")
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 85)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 检查是否返回空响应
if not accumulated_text or not accumulated_text.strip():
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ AI返回为空,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
logger.error("❌ 世界观重新生成多次返回空响应")
world_data = {
"time_period": "AI多次返回为空,请稍后重试",
"location": "AI多次返回为空,请稍后重试",
"atmosphere": "AI多次返回为空,请稍后重试",
"rules": "AI多次返回为空,请稍后重试"
}
world_generation_success = True
break
# 解析结果 - 使用统一的JSON清洗方法
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
try:
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
world_data = json.loads(cleaned_text)
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}")
world_generation_success = True
except json.JSONDecodeError as e:
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {e}")
logger.error(f" 原始内容长度: {len(accumulated_text)}")
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ JSON解析失败,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 达到最大重试次数,使用默认值
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
world_generation_success = True
except Exception as e:
logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}: {type(e).__name__}: {e}")
world_retry_count += 1
if world_retry_count < MAX_WORLD_RETRIES:
yield await SSEResponse.send_progress(
f"⚠️ 生成异常,准备重试...",
30 + world_retry_count * 5,
"warning"
)
continue
else:
# 最后一次重试仍失败,抛出异常
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
raise
# 不保存到数据库,仅返回生成结果供用户预览
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
+37 -15
View File
@@ -15,7 +15,6 @@ from ..schemas.writing_style import (
WritingStyleListResponse,
SetDefaultStyleRequest
)
from ..services.prompt_service import WritingStyleManager
from ..logger import get_logger
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
@@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str:
@router.get("/presets/list", response_model=List[dict])
async def get_preset_styles():
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
"""
获取所有预设风格列表
获取所有预设风格列表从数据库读取
返回格式数组形式的预设风格列表
[
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": "classical", "name": "古典优雅", ...}
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
]
"""
presets = WritingStyleManager.get_all_presets()
# 将字典转换为数组,添加 id 字段
# 从数据库获取全局预设风格(user_id 为 NULL
result = await db.execute(
select(WritingStyle)
.where(WritingStyle.user_id.is_(None))
.order_by(WritingStyle.order_index)
)
preset_styles = result.scalars().all()
# 转换为响应格式
return [
{"id": preset_id, **preset_data}
for preset_id, preset_data in presets.items()
{
"id": style.id,
"preset_id": style.preset_id,
"name": style.name,
"description": style.description,
"prompt_content": style.prompt_content,
"style_type": style.style_type,
"order_index": style.order_index
}
for style in preset_styles
]
@@ -58,25 +72,33 @@ async def create_writing_style(
"""
创建新的写作风格用户级别
- **基于预设创建**提供 preset_id系统会自动填充预设内容
- **基于预设创建**提供 preset_id系统会从数据库查询预设内容自动填充
- **完全自定义**不提供 preset_id需要手动填写所有字段
"""
# 获取当前用户ID
user_id = get_current_user_id(request)
# 如果基于预设创建,获取预设内容
# 如果基于预设创建,从数据库获取预设内容
if style_data.preset_id:
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
result = await db.execute(
select(WritingStyle)
.where(
WritingStyle.user_id.is_(None),
WritingStyle.preset_id == style_data.preset_id
)
)
preset = result.scalar_one_or_none()
if not preset:
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
# 使用预设内容填充(如果用户未提供)
if not style_data.name:
style_data.name = preset["name"]
style_data.name = preset.name
if not style_data.description:
style_data.description = preset["description"]
style_data.description = preset.description
if not style_data.prompt_content:
style_data.prompt_content = preset["prompt_content"]
style_data.prompt_content = preset.prompt_content
# 验证必填字段
if not style_data.name or not style_data.prompt_content:
+74 -35
View File
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from anyio import ClosedResourceError
from app.logger import get_logger
@@ -141,51 +142,89 @@ class HTTPMCPClient:
async def call_tool(
self,
tool_name: str,
arguments: Dict[str, Any]
arguments: Dict[str, Any],
max_reconnect_attempts: int = 2
) -> Any:
"""
调用工具
调用工具带自动重连
Args:
tool_name: 工具名称
arguments: 工具参数
max_reconnect_attempts: 最大重连尝试次数
Returns:
工具执行结果
"""
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f"参数: {arguments}")
result = await self._session.call_tool(tool_name, arguments)
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
# 提取第一个content的文本
for content in result.content:
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
return result.content[0] if result.content else None
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
return result.structuredContent
return None
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
raise MCPError(f"调用工具失败: {str(e)}")
for attempt in range(max_reconnect_attempts + 1):
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f" 参数类型: {type(arguments)}")
logger.debug(f" 参数内容: {arguments}")
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
result = await self._session.call_tool(tool_name, arguments)
logger.debug(f" 工具返回类型: {type(result)}")
logger.debug(f" 返回内容: {result}")
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
logger.debug(f" 返回content数量: {len(result.content)}")
# 提取第一个content的文本
for idx, content in enumerate(result.content):
logger.debug(f" content[{idx}]类型: {type(content)}")
if isinstance(content, types.TextContent):
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
return content.text
elif isinstance(content, types.ImageContent):
logger.debug(f" ✅ 返回ImageContent")
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
logger.debug(f" ⚠️ 返回原始content[0]")
return result.content[0] if result.content else None
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
logger.debug(f" ✅ 返回structuredContent")
return result.structuredContent
logger.warning(f" ⚠️ 工具返回为None")
return None
except ClosedResourceError as e:
# 连接已关闭,尝试重连
if attempt < max_reconnect_attempts:
logger.warning(
f"⚠️ MCP连接已关闭,尝试重新连接 "
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
)
await self._cleanup()
await asyncio.sleep(0.5) # 短暂延迟后重连
continue
else:
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
raise MCPError(error_msg)
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
logger.error(f" 参数: {arguments}")
logger.error(f" 错误类型: {type(e).__name__}")
logger.error(f" 错误详情: {repr(e)}")
logger.error(f" 错误字符串: '{str(e)}'")
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
raise MCPError(f"调用工具失败: {error_msg}")
# 理论上不会到这里
raise MCPError(f"工具调用失败: 未知错误")
async def list_resources(self) -> List[Dict[str, Any]]:
"""
+1
View File
@@ -17,6 +17,7 @@ class Settings(Base):
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
temperature = Column(Float, default=0.7, comment="温度参数")
max_tokens = Column(Integer, default=2000, comment="最大token数")
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
preferences = Column(Text, comment="其他偏好设置(JSON)")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
+1
View File
@@ -14,6 +14,7 @@ class SettingsBase(BaseModel):
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
@@ -0,0 +1,6 @@
"""AI 客户端模块"""
from .base_client import BaseAIClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
@@ -0,0 +1,86 @@
"""Anthropic 客户端"""
from typing import Any, AsyncGenerator, Dict, Optional
from anthropic import AsyncAnthropic
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
class AnthropicClient:
"""Anthropic API 客户端"""
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
self.config = config or default_config
kwargs = {"api_key": api_key}
if base_url:
kwargs["base_url"] = base_url
self.client = AsyncAnthropic(**kwargs)
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
kwargs = {
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"messages": messages,
}
if system_prompt:
kwargs["system"] = system_prompt
if tools:
kwargs["tools"] = tools
if tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "auto":
kwargs["tool_choice"] = {"type": "auto"}
response = await self.client.messages.create(**kwargs)
tool_calls = []
content = ""
for block in response.content:
if block.type == "tool_use":
tool_calls.append({
"id": block.id,
"type": "function",
"function": {"name": block.name, "arguments": block.input},
})
elif block.type == "text":
content += block.text
return {
"content": content,
"tool_calls": tool_calls if tool_calls else None,
"finish_reason": response.stop_reason,
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
kwargs = {
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"messages": messages,
}
if system_prompt:
kwargs["system"] = system_prompt
async with self.client.messages.stream(**kwargs) as stream:
async for text in stream.text_stream:
yield text
@@ -0,0 +1,154 @@
"""AI 客户端基类"""
import asyncio
import hashlib
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, Optional
import httpx
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
# 全局 HTTP 客户端池
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
_global_semaphore: Optional[asyncio.Semaphore] = None
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
"""获取全局信号量"""
global _global_semaphore
if _global_semaphore is None:
_global_semaphore = asyncio.Semaphore(max_concurrent)
return _global_semaphore
class BaseAIClient(ABC):
"""AI HTTP 客户端基类"""
def __init__(
self,
api_key: str,
base_url: str,
config: Optional[AIClientConfig] = None,
):
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.config = config or default_config
self.http_client = self._get_or_create_client()
def _get_client_key(self) -> str:
"""生成客户端唯一键"""
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
def _get_or_create_client(self) -> httpx.AsyncClient:
"""获取或创建 HTTP 客户端"""
client_key = self._get_client_key()
if client_key in _http_client_pool:
client = _http_client_pool[client_key]
if not client.is_closed:
return client
del _http_client_pool[client_key]
http_cfg = self.config.http
client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=http_cfg.connect_timeout,
read=http_cfg.read_timeout,
write=http_cfg.write_timeout,
pool=http_cfg.pool_timeout,
),
limits=httpx.Limits(
max_keepalive_connections=http_cfg.max_keepalive_connections,
max_connections=http_cfg.max_connections,
keepalive_expiry=http_cfg.keepalive_expiry,
),
)
_http_client_pool[client_key] = client
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
return client
@abstractmethod
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
pass
async def _request_with_retry(
self,
method: str,
endpoint: str,
payload: Dict[str, Any],
stream: bool = False,
) -> Any:
"""带重试的 HTTP 请求"""
url = f"{self.base_url}{endpoint}"
headers = self._build_headers()
retry_cfg = self.config.retry
rate_cfg = self.config.rate_limit
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
async with semaphore:
await asyncio.sleep(rate_cfg.request_delay)
for attempt in range(retry_cfg.max_retries):
try:
if attempt > 0:
delay = min(
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
retry_cfg.max_delay,
)
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
await asyncio.sleep(delay)
if stream:
return self.http_client.stream(method, url, headers=headers, json=payload)
response = await self.http_client.request(method, url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code in retry_cfg.non_retryable_status_codes:
raise
if attempt == retry_cfg.max_retries - 1:
raise
except (httpx.ConnectError, httpx.TimeoutException):
if attempt == retry_cfg.max_retries - 1:
raise
@abstractmethod
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天补全"""
pass
@abstractmethod
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
) -> AsyncGenerator[str, None]:
"""流式聊天补全"""
pass
async def cleanup_all_clients():
"""清理所有 HTTP 客户端"""
for key, client in list(_http_client_pool.items()):
if not client.is_closed:
await client.aclose()
_http_client_pool.clear()
logger.info("✅ HTTP 客户端池已清理")
@@ -0,0 +1,141 @@
"""Gemini 客户端"""
from typing import Any, AsyncGenerator, Dict, List, Optional
import httpx
from app.services.ai_config import AIClientConfig, default_config
class GeminiClient:
"""Google Gemini API 客户端"""
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
self.api_key = api_key
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
self.config = config or default_config
http_cfg = self.config.http
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(
connect=http_cfg.connect_timeout,
read=http_cfg.read_timeout,
write=http_cfg.write_timeout,
pool=http_cfg.pool_timeout
)
)
def _convert_tools_to_gemini(self, tools: list) -> list:
"""将 OpenAI 格式工具转换为 Gemini 格式"""
gemini_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
params.pop("$schema", None)
params.pop("additionalProperties", None)
if params and "type" not in params:
params["type"] = "object"
decl = {
"name": func["name"],
"description": func.get("description") or func["name"],
}
if params:
decl["parameters"] = params
gemini_tools.append(decl)
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
contents = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
payload = {
"contents": contents,
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
if tools:
payload["tools"] = self._convert_tools_to_gemini(tools)
response = await self.client.post(url, json=payload)
response.raise_for_status()
data = response.json()
candidates = data.get("candidates", [])
if not candidates or len(candidates) == 0:
# 返回空内容而不是报错,保持流程继续
return {
"content": "",
"tool_calls": None,
"finish_reason": "stop"
}
parts = candidates[0].get("content", {}).get("parts", [])
text = ""
tool_calls = []
for part in parts:
if "text" in part:
text += part["text"]
elif "functionCall" in part:
fc = part["functionCall"]
tool_calls.append({
"id": f"call_{fc['name']}",
"type": "function",
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
})
return {
"content": text,
"tool_calls": tool_calls if tool_calls else None,
"finish_reason": "tool_calls" if tool_calls else "stop"
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
contents = []
for msg in messages:
role = "user" if msg["role"] == "user" else "model"
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
payload = {
"contents": contents,
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
async with self.client.stream("POST", url, json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
import json
try:
data = json.loads(line[6:])
candidates = data.get("candidates", [])
if candidates and len(candidates) > 0:
parts = candidates[0].get("content", {}).get("parts", [])
if parts and len(parts) > 0:
text = parts[0].get("text", "")
if text:
yield text
except:
continue
@@ -0,0 +1,101 @@
"""OpenAI 客户端"""
import json
from typing import Any, AsyncGenerator, Dict, Optional
from app.logger import get_logger
from .base_client import BaseAIClient
logger = get_logger(__name__)
class OpenAIClient(BaseAIClient):
"""OpenAI API 客户端"""
def _build_headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def _build_payload(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
stream: bool = False,
) -> Dict[str, Any]:
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if stream:
payload["stream"] = True
if tools:
# 清理 $schema 字段
cleaned = []
for t in tools:
tc = t.copy()
if "function" in tc and "parameters" in tc["function"]:
tc["function"]["parameters"] = {
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
}
cleaned.append(tc)
payload["tools"] = cleaned
if tool_choice:
payload["tool_choice"] = tool_choice
return payload
async def chat_completion(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
data = await self._request_with_retry("POST", "/chat/completions", payload)
choices = data.get("choices", [])
if not choices or len(choices) == 0:
raise ValueError("API 返回空 choices 或 choices 为空列表")
choice = choices[0]
message = choice.get("message", {})
return {
"content": message.get("content", ""),
"tool_calls": message.get("tool_calls"),
"finish_reason": choice.get("finish_reason"),
}
async def chat_completion_stream(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
) -> AsyncGenerator[str, None]:
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
break
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices and len(choices) > 0:
content = choices[0].get("delta", {}).get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
+44
View File
@@ -0,0 +1,44 @@
"""AI 服务配置管理"""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class HTTPClientConfig:
"""HTTP 客户端配置"""
connect_timeout: float = 90.0
read_timeout: float = 300.0
write_timeout: float = 90.0
pool_timeout: float = 90.0
max_keepalive_connections: int = 50
max_connections: int = 100
keepalive_expiry: float = 60.0
@dataclass
class RetryConfig:
"""重试配置"""
max_retries: int = 3
base_delay: float = 0.2
max_delay: float = 10.0
exponential_base: int = 2
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
@dataclass
class RateLimitConfig:
"""限流配置"""
max_concurrent_requests: int = 5
request_delay: float = 0.2
@dataclass
class AIClientConfig:
"""AI 客户端完整配置"""
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
retry: RetryConfig = field(default_factory=RetryConfig)
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
# 全局默认配置
default_config = AIClientConfig()
@@ -0,0 +1,6 @@
"""AI Provider 模块"""
from .base_provider import BaseAIProvider
from .openai_provider import OpenAIProvider
from .anthropic_provider import AnthropicProvider
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
@@ -0,0 +1,51 @@
"""Anthropic Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.services.ai_clients.anthropic_client import AnthropicClient
from .base_provider import BaseAIProvider
class AnthropicProvider(BaseAIProvider):
"""Anthropic 提供商"""
def __init__(self, client: AnthropicClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = [{"role": "user", "content": prompt}]
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
@@ -0,0 +1,33 @@
"""AI Provider 基类"""
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional
class BaseAIProvider(ABC):
"""AI 提供商抽象基类"""
@abstractmethod
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
"""生成文本"""
pass
@abstractmethod
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
pass
@@ -0,0 +1,48 @@
"""Gemini Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.services.ai_clients.gemini_client import GeminiClient
from .base_provider import BaseAIProvider
class GeminiProvider(BaseAIProvider):
def __init__(self, client: GeminiClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = [{"role": "user", "content": prompt}]
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
@@ -0,0 +1,57 @@
"""OpenAI Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.services.ai_clients.openai_client import OpenAIClient
from .base_provider import BaseAIProvider
class OpenAIProvider(BaseAIProvider):
"""OpenAI 提供商"""
def __init__(self, client: OpenAIClient):
self.client = client
async def generate(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
) -> Dict[str, Any]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
return await self.client.chat_completion(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=tool_choice,
)
async def generate_stream(
self,
prompt: str,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
yield chunk
File diff suppressed because it is too large Load Diff
@@ -263,7 +263,7 @@ class AutoCharacterService:
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
@@ -362,7 +362,7 @@ class AutoCharacterService:
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=1
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
+149
View File
@@ -0,0 +1,149 @@
"""JSON 处理工具类"""
import json
import re
from typing import Any, Dict, List, Union
from app.logger import get_logger
logger = get_logger(__name__)
def clean_json_response(text: str) -> str:
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
try:
if not text:
logger.warning("⚠️ clean_json_response: 输入为空")
return text
original_length = len(text)
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
# 去除 markdown 代码块
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
text = text.strip()
if len(text) != original_length:
logger.debug(f" 移除markdown后长度: {len(text)}")
# 尝试直接解析(快速路径)
try:
json.loads(text)
logger.debug(f"✅ 直接解析成功,无需清洗")
return text
except:
pass
# 找到第一个 { 或 [
start = -1
for i, c in enumerate(text):
if c in ('{', '['):
start = i
break
if start == -1:
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
logger.debug(f" 文本预览: {text[:200]}")
return text
if start > 0:
logger.debug(f" 跳过前{start}个字符")
text = text[start:]
# 改进的括号匹配算法(更严格的字符串处理)
stack = []
i = 0
end = -1
while i < len(text):
c = text[i]
# 处理字符串(关键:正确处理转义)
if c == '"':
# 计算前面有多少个连续的反斜杠
num_backslashes = 0
j = i - 1
while j >= 0 and text[j] == '\\':
num_backslashes += 1
j -= 1
# 偶数个反斜杠(包括0)表示引号未被转义
if num_backslashes % 2 == 0:
# 这是字符串边界,跳过整个字符串
i += 1
while i < len(text):
if text[i] == '"':
# 再次检查转义
num_backslashes = 0
j = i - 1
while j >= 0 and text[j] == '\\':
num_backslashes += 1
j -= 1
if num_backslashes % 2 == 0:
# 字符串结束
break
i += 1
i += 1
continue
# 处理括号(只有在字符串外部才有效)
if c == '{' or c == '[':
stack.append(c)
elif c == '}':
if len(stack) > 0 and stack[-1] == '{':
stack.pop()
if len(stack) == 0:
end = i + 1
logger.debug(f"✅ 找到JSON结束位置: {end}")
break
else:
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1] if stack else 'empty'}")
elif c == ']':
if len(stack) > 0 and stack[-1] == '[':
stack.pop()
if len(stack) == 0:
end = i + 1
logger.debug(f"✅ 找到JSON结束位置: {end}")
break
else:
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1] if stack else 'empty'}")
i += 1
# 提取结果
if end > 0:
result = text[:end]
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
else:
result = text
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)}")
logger.debug(f" 栈状态: {stack}")
# 验证清洗后的结果
try:
json.loads(result)
logger.debug(f"✅ 清洗后JSON验证成功")
except json.JSONDecodeError as e:
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
logger.debug(f" 结果预览: {result[:500]}")
logger.debug(f" 结果结尾: ...{result[-200:]}")
return result
except Exception as e:
logger.error(f"❌ clean_json_response 出错: {e}")
logger.error(f" 文本长度: {len(text) if text else 0}")
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
raise
def parse_json(text: str) -> Union[Dict, List]:
"""解析 JSON"""
try:
cleaned = clean_json_response(text)
return json.loads(cleaned)
except Exception as e:
logger.error(f"❌ parse_json 出错: {e}")
logger.error(f" 原始文本长度: {len(text) if text else 0}")
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
raise
+19 -5
View File
@@ -175,20 +175,34 @@ class MCPTestService:
db=db_session
)
ai_response = await ai_service.generate_text(
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
# 因此这里需要特殊处理
accumulated_text = ""
tool_calls = None
async for chunk in ai_service.generate_text_stream(
prompt=prompts["user"],
system_prompt=prompts["system"],
tools=openai_tools,
tool_choice="required"
)
):
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
if isinstance(chunk, dict):
if "tool_calls" in chunk:
tool_calls = chunk["tool_calls"]
if "content" in chunk:
accumulated_text += chunk.get("content", "")
else:
accumulated_text += chunk
# 5. 检查AI是否返回工具调用
if not ai_response.get("tool_calls"):
if not tool_calls:
logger.error(f"❌ AI未返回工具调用")
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
@@ -198,7 +212,7 @@ class MCPTestService:
)
# 6. 解析工具调用
tool_call = ai_response["tool_calls"][0]
tool_call = tool_calls[0]
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
+14 -1
View File
@@ -386,17 +386,30 @@ class MCPToolService:
try:
# 解析插件名和工具名
logger.debug(f"🔍 解析工具名称: {function_name}")
if "_" in function_name:
plugin_name, tool_name = function_name.split("_", 1)
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
else:
raise ValueError(f"无效的工具名称格式: {function_name}")
# 解析参数
arguments_str = tool_call["function"]["arguments"]
logger.debug(f"🔍 解析参数:")
logger.debug(f" 原始类型: {type(arguments_str)}")
logger.debug(f" 原始内容: {arguments_str}")
if isinstance(arguments_str, str):
arguments = json.loads(arguments_str)
try:
arguments = json.loads(arguments_str)
logger.debug(f" ✅ JSON解析成功: {arguments}")
except json.JSONDecodeError as je:
logger.error(f" ❌ JSON解析失败: {je}")
logger.error(f" 原始字符串: '{arguments_str}'")
raise ValueError(f"参数JSON解析失败: {je}")
else:
arguments = arguments_str
logger.debug(f" 直接使用dict类型参数")
logger.info(
f"执行工具: {plugin_name}.{tool_name}, "
+6 -15
View File
@@ -71,24 +71,15 @@ class PlotAnalyzer:
# 调用AI进行分析
# 注意:不指定max_tokens,使用用户在设置中配置的值
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
response = await self.ai_service.generate_text(
accumulated_text = ""
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
temperature=0.3 # 降低温度以获得更稳定的JSON输出
)
):
accumulated_text += chunk
# 🔍 添加调试日志:查看AI返回的原始内容
# logger.info(f"🔍 AI返回类型: {type(response)}")
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# 从返回的字典中提取content字段
if isinstance(response, dict):
response_text = response.get('content', '')
if not response_text:
logger.error("❌ AI返回的字典中没有content字段或content为空")
return None
else:
# 兼容旧的字符串返回格式
response_text = response
# 提取内容
response_text = accumulated_text
# 解析JSON结果
analysis_result = self._parse_analysis_response(response_text)
+10 -6
View File
@@ -133,14 +133,16 @@ class PlotExpansionService:
# 调用AI生成章节规划
logger.info(f"调用AI生成章节规划...")
ai_response = await self.ai_service.generate_text(
accumulated_text = ""
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
)
):
accumulated_text += chunk
# 提取内容
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
# 解析AI响应
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
@@ -236,14 +238,16 @@ class PlotExpansionService:
# 调用AI生成当前批次
logger.info(f"调用AI生成第{batch_num + 1}批...")
ai_response = await self.ai_service.generate_text(
accumulated_text = ""
async for chunk in self.ai_service.generate_text_stream(
prompt=prompt,
provider=provider,
model=model
)
):
accumulated_text += chunk
# 提取内容
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
ai_content = accumulated_text
# 解析AI响应
batch_plans = self._parse_expansion_response(ai_content, outline.id)
+108 -172
View File
@@ -6,142 +6,6 @@ import json
class WritingStyleManager:
"""写作风格管理器"""
# 预设风格配置
PRESET_STYLES = {
"natural": {
"name": "自然沉浸 (Natural & Immersive)",
"description": "祛除翻译腔,强调生活质感,像呼吸一样自然的叙事",
"prompt_content": """
### 核心指令:自然沉浸风格
请模拟人类作家在放松状态下的写作通过以下规则消除AI味
1. **拒绝翻译腔与书面化**
- 严禁使用一种...的感觉随着...与此同时等连接词
- 多用短句和流水句模拟人类视线的移动和思维的跳跃
- 口语化叙述但不要滥用语气词而是通过句子的长短节奏来体现语气
2. **生活化的颗粒度**
- 描写不要宏大要聚焦在具体的微小的生活细节杯子上的水渍衣服的褶皱
- 允许逻辑上的适度松散不要让每句话都像说明书一样严丝合缝
3. **具体的展示**
- 不要写他很生气要写他把烟头按灭在还没吃完的米饭里
- 避免使用抽象的形容词巨大的美丽的悲伤的必须用名词和动词来承载画面
"""
},
"classical": {
"name": "古典雅致 (Classical & Elegant)",
"description": "白话文与古典韵味的结合,强调留白与炼字",
"prompt_content": """
### 核心指令:古典雅致风格
请模仿民国时期或古典白话小说的笔触构建端庄且富有余味的叙事
1. **炼字与韵律**
- 尽量使用双音节词或四字短语但严禁堆砌辞藻
- 注重句子的声调韵律读起来要有金石之声或流水之韵
- 适当使用倒装句或定语后置增加古雅感
2. **克制的修辞**
- 少用现代的比喻像机器一样多用取自自然的比喻如风过林
- **意在言外**不要把话说透留三分余地写景即是写情不要将情感直接剖白
3. **禁忌**
- 严禁使用现代科技词汇除非题材需要网络用语或过于西化的句式如长定语从句
- 避免滥用之乎者也追求的是神似而非生硬的半文半白
"""
},
"modern": {
"name": "冷硬现代 (Modern & Hard-boiled)",
"description": "海明威式的冰山理论,节奏极快,零度情感",
"prompt_content": """
### 核心指令:冷硬现代风格
请采用极简主义零度写作手法去除所有矫饰
1. **冰山理论**
- **只写动作和对话完全剔除心理描写和形容词堆砌**
- 不要告诉读者角色感觉如何通过角色的反应和环境的冷峻反馈来体现
2. **电影蒙太奇节奏**
- 句子要短像手术刀一样切开场景
- 段落之间快速切换不要用过渡句连接直接跳切
3. **高信息密度**
- 删除所有废话如果一个词删掉不影响理解就删掉它
- 多用名词和强动词Strong Verbs少用副词Adverbs例如不要写他重重地关上门他摔上了门
"""
},
"poetic": {
"name": "意识流 (Stream of Consciousness)",
"description": "注重感官通感与内心独白,打破现实与幻想的边界",
"prompt_content": """
### 核心指令:意识流/诗意风格
请侧重于主观感受的流动而非客观事实的记录
1. **通感与陌生化**
- 打通五感听到了颜色的声音闻到了悲伤的气味
- 使用陌生化的语言把熟悉的事物写得陌生迫使读者重新审视
2. **情绪的具象化**
- **绝对禁止**直接出现开心痛苦等抽象词汇
- 必须寻找客观对应物Objective Correlative将情绪投射到具体的景物上生锈的铁轨发霉的橘子
3. **流动的句式**
- 句子可以很长包含多重意象的叠加
- 允许思维的非线性跳跃模拟梦境或深层潜意识的逻辑
"""
},
"concise": {
"name": "白描速写 (Sketch & Concise)",
"description": "只有骨架的叙事,强调绝对的精准和功能性",
"prompt_content": """
### 核心指令:白描速写风格
请像速写画家一样只勾勒线条不涂抹色彩
1. **功能性第一**
- 每一句话必须推动情节或者揭示关键信息
- 如果一句话只是为了渲染气氛删掉它
2. **主谓宾结构**
- 尽量使用简单的主谓宾结构减少修饰语
- 避免复杂的从句和嵌套结构
3. **直击核心**
- 对话直接进入主题去除寒暄和废话
- 环境描写仅限于对情节有物理影响的物体挡路的石头藏在桌下的枪
"""
},
"vivid": {
"name": "感官特写 (Sensory & Vivid)",
"description": "高分辨率的描写,强调材质、光影和微观细节",
"prompt_content": """
### 核心指令:感官特写风格
请将镜头推到特写级别Macro Lens捕捉常人忽略的细节
1. **反套路细节**
- 不要写大众化的细节蓝天白云要写具有**独特性**的细节云层边缘那抹像淤青一样的灰紫色
- 关注物体的**质感Texture**粗糙的粘稠的冰凉的颗粒感的
2. **动态捕捉**
- 不要写静止的画面要写光影的流变灰尘的飞舞肌肉的抽动
- 让读者产生生理性的反应痛感饥饿感窒息感
3. **禁用词汇**
- 禁止使用映入眼帘宛如画卷等陈词滥调
- 必须用具体的动词带动感官描写
"""
}
}
@classmethod
def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]:
"""获取预设风格配置"""
return cls.PRESET_STYLES.get(preset_id)
@classmethod
def get_all_presets(cls) -> Dict[str, Dict[str, str]]:
"""获取所有预设风格"""
return cls.PRESET_STYLES
@staticmethod
def apply_style_to_prompt(base_prompt: str, style_content: str) -> str:
"""
@@ -692,9 +556,8 @@ class PromptService:
6. **承上启下**
- 开头自然衔接上一章结尾但不重复上一章内容
- 结尾为下一章做好铺垫
6. **记忆系统使用指南**
7. **记忆系统使用指南**
- **最近章节记忆**保持情节连贯注意角色状态和剧情发展
- **语义相关记忆**参考相似情节的处理方式
- **未完结伏笔**适当时机可以回收伏笔制造呼应效果
@@ -1308,16 +1171,15 @@ class PromptService:
- 如果参数名是 snake_case next_thought就使用 snake_case
- 保持与 schema 中定义的完全一致包括大小写和命名风格"""
# 灵感模式提示词字典
INSPIRATION_PROMPTS = {
"title": {
"system": """你是一位专业的小说创作顾问。
# 灵感模式 - 书名生成(系统提示词
INSPIRATION_TITLE_SYSTEM = """你是一位专业的小说创作顾问。
用户的原始想法{initial_idea}
请根据用户的想法生成6个吸引人的书名建议要求
1. 紧扣用户的原始想法和核心故事构思
2. 富有创意和吸引力
3. 涵盖不同的风格倾向
4. 书名中不要带有"《》"符号
返回JSON格式
{{
@@ -1325,11 +1187,13 @@ class PromptService:
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
}}
只返回纯JSON不要有其他文字""",
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
},
"description": {
"system": """你是一位专业的小说创作顾问。
只返回纯JSON不要有其他文字"""
# 灵感模式 - 书名生成(用户提示词)
INSPIRATION_TITLE_USER = "用户的想法:{initial_idea}\n请生成6个书名建议"
# 灵感模式 - 简介生成(系统提示词)
INSPIRATION_DESCRIPTION_SYSTEM = """你是一位专业的小说创作顾问。
用户的原始想法{initial_idea}
已确定的书名{title}
@@ -1343,11 +1207,13 @@ class PromptService:
返回JSON格式
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
只返回纯JSON不要有其他文字不要换行""",
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
},
"theme": {
"system": """你是一位专业的小说创作顾问。
只返回纯JSON不要有其他文字不要换行"""
# 灵感模式 - 简介生成(用户提示词)
INSPIRATION_DESCRIPTION_USER = "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
# 灵感模式 - 主题生成(系统提示词)
INSPIRATION_THEME_SYSTEM = """你是一位专业的小说创作顾问。
用户的原始想法{initial_idea}
小说信息
- 书名{title}
@@ -1363,11 +1229,13 @@ class PromptService:
返回JSON格式
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
只返回纯JSON不要有其他文字不要换行""",
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
},
"genre": {
"system": """你是一位专业的小说创作顾问。
只返回纯JSON不要有其他文字不要换行"""
# 灵感模式 - 主题生成(用户提示词)
INSPIRATION_THEME_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
# 灵感模式 - 类型生成(系统提示词)
INSPIRATION_GENRE_SYSTEM = """你是一位专业的小说创作顾问。
用户的原始想法{initial_idea}
小说信息
- 书名{title}
@@ -1384,10 +1252,10 @@ class PromptService:
返回JSON格式
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
只返回紧凑的纯JSON不要换行不要有其他文字""",
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
}
}
只返回紧凑的纯JSON不要换行不要有其他文字"""
# 灵感模式 - 类型生成(用户提示词)
INSPIRATION_GENRE_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
# 灵感模式智能补全提示词
INSPIRATION_QUICK_COMPLETE = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
@@ -1887,7 +1755,26 @@ class PromptService:
@classmethod
def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]:
"""获取灵感模式指定步骤的提示词"""
return cls.INSPIRATION_PROMPTS.get(step)
# 根据步骤名称返回对应的system和user提示词
step_map = {
"title": {
"system": cls.INSPIRATION_TITLE_SYSTEM,
"user": cls.INSPIRATION_TITLE_USER
},
"description": {
"system": cls.INSPIRATION_DESCRIPTION_SYSTEM,
"user": cls.INSPIRATION_DESCRIPTION_USER
},
"theme": {
"system": cls.INSPIRATION_THEME_SYSTEM,
"user": cls.INSPIRATION_THEME_USER
},
"genre": {
"system": cls.INSPIRATION_GENRE_SYSTEM,
"user": cls.INSPIRATION_GENRE_USER
}
}
return step_map.get(step)
@classmethod
def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]:
@@ -1997,17 +1884,12 @@ class PromptService:
# 2. 降级到系统默认模板
logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)")
# 特殊处理灵感模式的提示词(存储在INSPIRATION_PROMPTS字典中
# 特殊处理灵感模式的提示词(直接从类属性获取
if template_key.startswith("INSPIRATION_"):
# 提取步骤名称(如 INSPIRATION_TITLE -> title
step = template_key.replace("INSPIRATION_", "").lower()
inspiration_prompt = cls.INSPIRATION_PROMPTS.get(step)
if inspiration_prompt:
# 返回JSON格式的提示词
return json.dumps(inspiration_prompt, ensure_ascii=False)
# 如果是INSPIRATION_QUICK_COMPLETE
if template_key == "INSPIRATION_QUICK_COMPLETE":
return cls.INSPIRATION_QUICK_COMPLETE
# 直接从类属性获取
template_content = getattr(cls, template_key, None)
if template_content:
return template_content
# 其他模板直接从类属性获取
template_content = getattr(cls, template_key, None)
@@ -2182,6 +2064,60 @@ class PromptService:
"category": "世界构建",
"description": "根据世界观自动生成完整的职业体系,包括主职业和副职业",
"parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"]
},
"INSPIRATION_TITLE_SYSTEM": {
"name": "灵感模式-书名生成(系统提示词)",
"category": "灵感模式",
"description": "根据用户的原始想法生成6个书名建议的系统提示词",
"parameters": ["initial_idea"]
},
"INSPIRATION_TITLE_USER": {
"name": "灵感模式-书名生成(用户提示词)",
"category": "灵感模式",
"description": "根据用户的原始想法生成6个书名建议的用户提示词",
"parameters": ["initial_idea"]
},
"INSPIRATION_DESCRIPTION_SYSTEM": {
"name": "灵感模式-简介生成(系统提示词)",
"category": "灵感模式",
"description": "根据用户想法和书名生成6个简介选项的系统提示词",
"parameters": ["initial_idea", "title"]
},
"INSPIRATION_DESCRIPTION_USER": {
"name": "灵感模式-简介生成(用户提示词)",
"category": "灵感模式",
"description": "根据用户想法和书名生成6个简介选项的用户提示词",
"parameters": ["initial_idea", "title"]
},
"INSPIRATION_THEME_SYSTEM": {
"name": "灵感模式-主题生成(系统提示词)",
"category": "灵感模式",
"description": "根据书名和简介生成6个深刻的主题选项的系统提示词",
"parameters": ["initial_idea", "title", "description"]
},
"INSPIRATION_THEME_USER": {
"name": "灵感模式-主题生成(用户提示词)",
"category": "灵感模式",
"description": "根据书名和简介生成6个深刻的主题选项的用户提示词",
"parameters": ["initial_idea", "title", "description"]
},
"INSPIRATION_GENRE_SYSTEM": {
"name": "灵感模式-类型生成(系统提示词)",
"category": "灵感模式",
"description": "根据小说信息生成6个合适的类型标签的系统提示词",
"parameters": ["initial_idea", "title", "description", "theme"]
},
"INSPIRATION_GENRE_USER": {
"name": "灵感模式-类型生成(用户提示词)",
"category": "灵感模式",
"description": "根据小说信息生成6个合适的类型标签的用户提示词",
"parameters": ["initial_idea", "title", "description", "theme"]
},
"INSPIRATION_QUICK_COMPLETE": {
"name": "灵感模式-智能补全",
"category": "灵感模式",
"description": "根据用户提供的部分信息智能补全完整的小说方案",
"parameters": ["existing"]
}
}
+16 -5
View File
@@ -23,11 +23,22 @@ class SSEResponse:
Returns:
格式化后的SSE消息字符串
"""
message = ""
if event:
message += f"event: {event}\n"
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
return message
try:
message = ""
if event:
message += f"event: {event}\n"
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
return message
except Exception as e:
logger.error(f"❌ SSE格式化失败: {type(e).__name__}: {e}")
logger.error(f" data类型: {type(data)}")
logger.error(f" data内容: {str(data)[:500]}")
# 返回错误消息而不是崩溃
error_message = ""
if event:
error_message += f"event: {event}\n"
error_message += f'data: {{"type": "error", "error": "SSE格式化失败: {str(e)}", "code": 500}}\n\n'
return error_message
@staticmethod
async def send_progress(