update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
+30
@@ -0,0 +1,30 @@
|
|||||||
|
"""添加system_prompt字段到settings表
|
||||||
|
|
||||||
|
Revision ID: a7e4408e1d5b
|
||||||
|
Revises: e411428f00c0
|
||||||
|
Create Date: 2025-12-27 15:41:22.310160
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a7e4408e1d5b'
|
||||||
|
down_revision: Union[str, None] = 'e411428f00c0'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('settings', 'system_prompt')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,177 @@
|
|||||||
|
"""初始化SQLite预置数据
|
||||||
|
|
||||||
|
Revision ID: a1b2c3d4e5f6
|
||||||
|
Revises: fbeb1038c728
|
||||||
|
Create Date: 2025-12-27 08:56:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import table, column, String, Integer, Text
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'a1b2c3d4e5f6'
|
||||||
|
down_revision: Union[str, None] = 'fbeb1038c728'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""插入预置数据"""
|
||||||
|
|
||||||
|
# ==================== 1. 插入关系类型数据 ====================
|
||||||
|
relationship_types_table = table(
|
||||||
|
'relationship_types',
|
||||||
|
column('name', String),
|
||||||
|
column('category', String),
|
||||||
|
column('reverse_name', String),
|
||||||
|
column('intimacy_range', String),
|
||||||
|
column('icon', String),
|
||||||
|
column('description', Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
relationship_types_data = [
|
||||||
|
# 家庭关系
|
||||||
|
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
|
||||||
|
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
|
||||||
|
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
|
||||||
|
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
|
||||||
|
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
|
||||||
|
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
|
||||||
|
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
|
||||||
|
|
||||||
|
# 社交关系
|
||||||
|
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
|
||||||
|
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
|
||||||
|
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
|
||||||
|
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
|
||||||
|
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
|
||||||
|
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
|
||||||
|
|
||||||
|
# 职业关系
|
||||||
|
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
|
||||||
|
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
|
||||||
|
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
|
||||||
|
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
|
||||||
|
|
||||||
|
# 敌对关系
|
||||||
|
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
|
||||||
|
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
|
||||||
|
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
|
||||||
|
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡", "description": "宿命之敌"},
|
||||||
|
]
|
||||||
|
|
||||||
|
op.bulk_insert(relationship_types_table, relationship_types_data)
|
||||||
|
print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 2. 插入全局写作风格预设 ====================
|
||||||
|
writing_styles_table = table(
|
||||||
|
'writing_styles',
|
||||||
|
column('user_id', String),
|
||||||
|
column('name', String),
|
||||||
|
column('style_type', String),
|
||||||
|
column('preset_id', String),
|
||||||
|
column('description', Text),
|
||||||
|
column('prompt_content', Text),
|
||||||
|
column('order_index', Integer),
|
||||||
|
)
|
||||||
|
|
||||||
|
writing_styles_data = [
|
||||||
|
{
|
||||||
|
"user_id": None, # NULL 表示全局预设
|
||||||
|
"name": "自然流畅",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "natural",
|
||||||
|
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 语言简洁明快,贴近现代口语
|
||||||
|
2. 多用短句,节奏流畅
|
||||||
|
3. 注重情感细节的自然流露
|
||||||
|
4. 避免过度修饰和复杂句式""",
|
||||||
|
"order_index": 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": None,
|
||||||
|
"name": "古典优雅",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "classical",
|
||||||
|
"description": "古典文雅的写作风格,适合古装、仙侠题材",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 使用文言、半文言或典雅的白话
|
||||||
|
2. 适当运用古典诗词意象
|
||||||
|
3. 注重意境营造和韵味
|
||||||
|
4. 对话和描写保持古典美感""",
|
||||||
|
"order_index": 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": None,
|
||||||
|
"name": "现代简约",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "modern",
|
||||||
|
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 语言直白简练,信息密度高
|
||||||
|
2. 多用对话推进情节
|
||||||
|
3. 避免冗长描写,突出关键动作
|
||||||
|
4. 节奏明快,适合快速阅读""",
|
||||||
|
"order_index": 3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": None,
|
||||||
|
"name": "文艺细腻",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "literary",
|
||||||
|
"description": "文艺细腻风格,注重心理描写和氛围营造",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 注重心理活动和情感细节
|
||||||
|
2. 善用环境描写烘托氛围
|
||||||
|
3. 语言优美,富有文学性
|
||||||
|
4. 适当使用比喻、象征等修辞手法""",
|
||||||
|
"order_index": 4
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": None,
|
||||||
|
"name": "紧张悬疑",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "suspense",
|
||||||
|
"description": "紧张悬疑风格,适合推理、惊悚题材",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 营造紧张压迫的氛围
|
||||||
|
2. 多用短句加快节奏
|
||||||
|
3. 善于设置悬念和伏笔
|
||||||
|
4. 注重细节描写,为推理埋下线索""",
|
||||||
|
"order_index": 5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": None,
|
||||||
|
"name": "幽默诙谐",
|
||||||
|
"style_type": "preset",
|
||||||
|
"preset_id": "humorous",
|
||||||
|
"description": "幽默诙谐风格,适合轻松搞笑题材",
|
||||||
|
"prompt_content": """写作风格要求:
|
||||||
|
1. 语言活泼风趣,善用俏皮话
|
||||||
|
2. 注重对话的喜剧效果
|
||||||
|
3. 适当夸张和反转制造笑点
|
||||||
|
4. 保持轻松愉快的基调""",
|
||||||
|
"order_index": 6
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
op.bulk_insert(writing_styles_table, writing_styles_data)
|
||||||
|
print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""删除预置数据"""
|
||||||
|
|
||||||
|
# 删除写作风格预设(只删除全局预设)
|
||||||
|
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
|
||||||
|
print("✅ SQLite: 已删除全局写作风格预设")
|
||||||
|
|
||||||
|
# 删除关系类型
|
||||||
|
op.execute("DELETE FROM relationship_types")
|
||||||
|
print("✅ SQLite: 已删除关系类型数据")
|
||||||
+34
@@ -0,0 +1,34 @@
|
|||||||
|
"""添加system_prompt字段到settings表
|
||||||
|
|
||||||
|
Revision ID: 7899f8d4d839
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2025-12-27 17:00:35.440551
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '7899f8d4d839'
|
||||||
|
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('system_prompt')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -309,13 +309,37 @@ async def generate_career_system(
|
|||||||
7. 只返回纯JSON,不要添加任何解释文字
|
7. 只返回纯JSON,不要添加任何解释文字
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 30)
|
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
|
||||||
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用AI生成
|
# 使用流式生成替代非流式
|
||||||
result = await user_ai_service.generate_text(prompt=prompt)
|
ai_response = ""
|
||||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
chunk_count = 0
|
||||||
|
last_progress = 10
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
|
chunk_count += 1
|
||||||
|
ai_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 平滑更新进度(10-90%,AI生成占60%)
|
||||||
|
# 每10个chunk增加约1%的进度,最多到90%
|
||||||
|
if chunk_count % 10 == 0:
|
||||||
|
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
|
||||||
|
current_progress = min(10 + (chunk_count // 10), 90)
|
||||||
|
if current_progress > last_progress:
|
||||||
|
last_progress = current_progress
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
|
||||||
|
current_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
@@ -326,7 +350,7 @@ async def generate_career_system(
|
|||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 50)
|
yield await SSEResponse.send_progress("解析AI响应...", 91)
|
||||||
|
|
||||||
# 清洗并解析JSON
|
# 清洗并解析JSON
|
||||||
try:
|
try:
|
||||||
@@ -339,7 +363,7 @@ async def generate_career_system(
|
|||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存主职业...", 60)
|
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
|
||||||
|
|
||||||
# 保存主职业
|
# 保存主职业
|
||||||
main_careers_created = []
|
main_careers_created = []
|
||||||
@@ -371,7 +395,7 @@ async def generate_career_system(
|
|||||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存副职业...", 80)
|
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
|
||||||
|
|
||||||
# 保存副职业
|
# 保存副职业
|
||||||
sub_careers_created = []
|
sub_careers_created = []
|
||||||
|
|||||||
+47
-12
@@ -1070,8 +1070,7 @@ async def analyze_chapter_background(
|
|||||||
|
|
||||||
if career_update_result['updated_count'] > 0:
|
if career_update_result['updated_count'] > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: "
|
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息"
|
||||||
f"{', '.join(career_update_result['updated_characters'])}"
|
|
||||||
)
|
)
|
||||||
if career_update_result['changes']:
|
if career_update_result['changes']:
|
||||||
for change in career_update_result['changes']:
|
for change in career_update_result['changes']:
|
||||||
@@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream(
|
|||||||
user_id=current_user_id,
|
user_id=current_user_id,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream(
|
|||||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||||
|
|
||||||
# 发送开始生成的进度
|
# 发送开始生成的进度
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
|
||||||
|
system_prompt_with_style = None
|
||||||
|
if style_content:
|
||||||
|
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||||
|
|
||||||
|
{style_content}
|
||||||
|
|
||||||
|
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||||
|
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||||
|
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||||
|
|
||||||
# 准备生成参数
|
# 准备生成参数
|
||||||
generate_kwargs = {"prompt": prompt}
|
generate_kwargs = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||||
|
}
|
||||||
if custom_model:
|
if custom_model:
|
||||||
logger.info(f" 使用自定义模型: {custom_model}")
|
logger.info(f" 使用自定义模型: {custom_model}")
|
||||||
generate_kwargs["model"] = custom_model
|
generate_kwargs["model"] = custom_model
|
||||||
@@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream(
|
|||||||
# 发送内容块
|
# 发送内容块
|
||||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 每20个chunk发送一次进度更新(提高频率)
|
# 每5个chunk发送一次进度更新(10-95%,更平滑)
|
||||||
if chunk_count % 20 == 0:
|
if chunk_count % 5 == 0:
|
||||||
current_word_count = len(full_content)
|
current_word_count = len(full_content)
|
||||||
# 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%)
|
# 优化进度计算:使用更平滑的递增方式
|
||||||
estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55))
|
# 基于chunk数量和字数的混合计算,避免大幅跳跃
|
||||||
|
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
|
||||||
|
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
|
||||||
|
estimated_progress = min(95, 10 + chunk_progress + word_progress)
|
||||||
|
|
||||||
# 只在进度变化时发送
|
# 只在进度变化时发送
|
||||||
if estimated_progress > last_progress:
|
if estimated_progress > last_progress:
|
||||||
@@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream(
|
|||||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||||
last_progress = estimated_progress
|
last_progress = estimated_progress
|
||||||
|
|
||||||
|
# 每20个chunk发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
await asyncio.sleep(0) # 让出控制权
|
await asyncio.sleep(0) # 让出控制权
|
||||||
|
|
||||||
# 发送保存进度
|
# 发送保存进度
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 更新章节内容到数据库
|
# 更新章节内容到数据库
|
||||||
old_word_count = current_chapter.word_count or 0
|
old_word_count = current_chapter.word_count or 0
|
||||||
@@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 发送最终进度100%
|
# 发送最终进度100%
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 发送完成事件(包含分析任务ID)
|
# 发送完成事件(包含分析任务ID)
|
||||||
completion_data = {
|
completion_data = {
|
||||||
@@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch(
|
|||||||
else:
|
else:
|
||||||
prompt = base_prompt
|
prompt = base_prompt
|
||||||
|
|
||||||
|
# 🎨 方案一:将写作风格注入到系统提示词(批量生成)
|
||||||
|
system_prompt_with_style = None
|
||||||
|
if style_content:
|
||||||
|
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||||
|
|
||||||
|
{style_content}
|
||||||
|
|
||||||
|
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||||
|
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||||
|
logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||||
|
|
||||||
# 非流式生成内容
|
# 非流式生成内容
|
||||||
full_content = ""
|
full_content = ""
|
||||||
# 准备生成参数
|
# 准备生成参数
|
||||||
generate_kwargs = {"prompt": prompt}
|
generate_kwargs = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||||
|
}
|
||||||
# 如果传入了自定义模型,使用指定的模型
|
# 如果传入了自定义模型,使用指定的模型
|
||||||
if custom_model:
|
if custom_model:
|
||||||
generate_kwargs["model"] = custom_model
|
generate_kwargs["model"] = custom_model
|
||||||
logger.info(f" 批量生成使用自定义模型: {custom_model}")
|
logger.info(f" 批量生成使用自定义模型: {custom_model}")
|
||||||
|
|
||||||
|
# 批量生成中的流式生成(非SSE,不需要修改进度显示)
|
||||||
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
|
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
|
|
||||||
|
|||||||
+120
-20
@@ -662,10 +662,10 @@ async def generate_character_stream(
|
|||||||
user_id = getattr(http_request.state, 'user_id', None)
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
project = await verify_project_access(request.project_id, user_id, db)
|
project = await verify_project_access(request.project_id, user_id, db)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("开始生成角色...", 0)
|
yield await SSEResponse.send_progress("开始生成角色...", 1)
|
||||||
|
|
||||||
# 获取已存在的角色列表
|
# 获取已存在的角色列表
|
||||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
yield await SSEResponse.send_progress("获取项目上下文...", 2)
|
||||||
|
|
||||||
existing_chars_result = await db.execute(
|
existing_chars_result = await db.execute(
|
||||||
select(Character)
|
select(Character)
|
||||||
@@ -757,7 +757,7 @@ async def generate_character_stream(
|
|||||||
- 其他要求:{request.requirements or '无'}
|
- 其他要求:{request.requirements or '无'}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
yield await SSEResponse.send_progress("构建AI提示词...", 3)
|
||||||
|
|
||||||
# 获取自定义提示词模板
|
# 获取自定义提示词模板
|
||||||
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
||||||
@@ -768,11 +768,14 @@ async def generate_character_stream(
|
|||||||
user_input=user_input
|
user_input=user_input
|
||||||
)
|
)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
|
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
|
||||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||||
|
ai_response = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.mcp_tool_service import mcp_tool_service
|
from app.services.mcp_tool_service import mcp_tool_service
|
||||||
@@ -789,7 +792,7 @@ async def generate_character_stream(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # 减少为1轮,避免超时
|
max_tool_rounds=2,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -797,22 +800,119 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
ai_response = result.get('content', '')
|
ai_response = result.get('content', '')
|
||||||
if result.get('tool_calls_made', 0) > 0:
|
finish_reason = result.get('finish_reason', '')
|
||||||
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
|
tool_calls_made = result.get('tool_calls_made', 0)
|
||||||
|
|
||||||
|
# 🔧 修复:检查工具调用是否真正成功
|
||||||
|
if tool_calls_made > 0:
|
||||||
|
if finish_reason == 'tool_error':
|
||||||
|
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||||
|
# 工具调用失败,重新用基础模式生成
|
||||||
|
ai_response = ""
|
||||||
|
elif not ai_response.strip():
|
||||||
|
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||||
|
# 工具调用成功但返回空内容,重新生成
|
||||||
|
ai_response = ""
|
||||||
|
else:
|
||||||
|
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||||
|
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||||
|
chunk_size = 50
|
||||||
|
for i in range(0, len(ai_response), chunk_size):
|
||||||
|
chunk = ai_response[i:i+chunk_size]
|
||||||
|
chunk_count += 1
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
if chunk_count % 3 == 0:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||||
|
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 跳过后续的流式生成
|
||||||
|
ai_response = result.get('content', '')
|
||||||
else:
|
else:
|
||||||
ai_response = result
|
ai_response = result
|
||||||
|
|
||||||
|
# 如果MCP调用失败或返回空,继续走流式生成
|
||||||
|
if not ai_response or not ai_response.strip():
|
||||||
|
logger.info(f"🔄 开始流式生成...")
|
||||||
|
ai_response = ""
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
|
chunk_count += 1
|
||||||
|
ai_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||||
|
10 + min(chunk_count // 2, 85)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
else:
|
else:
|
||||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
|
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||||
result = await user_ai_service.generate_text(prompt=prompt)
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
chunk_count += 1
|
||||||
|
ai_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||||
|
10 + min(chunk_count // 2, 85)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
except Exception as mcp_error:
|
except Exception as mcp_error:
|
||||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
|
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||||
result = await user_ai_service.generate_text(prompt=prompt)
|
ai_response = ""
|
||||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
|
chunk_count += 1
|
||||||
|
ai_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||||
|
10 + min(chunk_count // 2, 85)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
else:
|
else:
|
||||||
result = await user_ai_service.generate_text(prompt=prompt)
|
logger.debug(f"未登录用户,使用流式基础模式")
|
||||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
|
chunk_count += 1
|
||||||
|
ai_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||||
|
10 + min(chunk_count // 2, 85)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
@@ -823,7 +923,7 @@ async def generate_character_stream(
|
|||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||||
|
|
||||||
# ✅ 使用统一的 JSON 清洗方法
|
# ✅ 使用统一的 JSON 清洗方法
|
||||||
try:
|
try:
|
||||||
@@ -836,7 +936,7 @@ async def generate_character_stream(
|
|||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建角色记录...", 75)
|
yield await SSEResponse.send_progress("创建角色记录...", 97)
|
||||||
|
|
||||||
# 转换traits
|
# 转换traits
|
||||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||||
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
# 如果是组织,创建Organization详情
|
# 如果是组织,创建Organization详情
|
||||||
if is_organization:
|
if is_organization:
|
||||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||||
|
|
||||||
org_check = await db.execute(
|
org_check = await db.execute(
|
||||||
select(Organization).where(Organization.character_id == character.id)
|
select(Organization).where(Organization.character_id == character.id)
|
||||||
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||||
|
|
||||||
# 记录生成历史
|
# 记录生成历史
|
||||||
history = GenerationHistory(
|
history = GenerationHistory(
|
||||||
project_id=request.project_id,
|
project_id=request.project_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
|
generated_content=ai_response,
|
||||||
model=user_ai_service.default_model
|
model=user_ai_service.default_model
|
||||||
)
|
)
|
||||||
db.add(history)
|
db.add(history)
|
||||||
|
|||||||
+204
-28
@@ -105,23 +105,27 @@ async def generate_options(
|
|||||||
user_id = getattr(http_request.state, 'user_id', None)
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
|
|
||||||
# 获取对应的提示词模板(根据step确定模板key)
|
# 获取对应的提示词模板(根据step确定模板key)
|
||||||
|
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
|
||||||
template_key_map = {
|
template_key_map = {
|
||||||
"title": "INSPIRATION_TITLE",
|
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||||
"description": "INSPIRATION_DESCRIPTION",
|
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||||
"theme": "INSPIRATION_THEME",
|
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||||
"genre": "INSPIRATION_GENRE"
|
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||||
}
|
}
|
||||||
template_key = template_key_map.get(step)
|
template_keys = template_key_map.get(step)
|
||||||
|
|
||||||
if not template_key:
|
if not template_keys:
|
||||||
return {
|
return {
|
||||||
"error": f"不支持的步骤: {step}",
|
"error": f"不支持的步骤: {step}",
|
||||||
"prompt": "",
|
"prompt": "",
|
||||||
"options": []
|
"options": []
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取自定义提示词模板
|
system_key, user_key = template_keys
|
||||||
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
|
|
||||||
|
# 获取自定义提示词模板(分别获取 system 和 user)
|
||||||
|
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||||
|
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||||
|
|
||||||
# 准备格式化参数
|
# 准备格式化参数
|
||||||
format_params = {
|
format_params = {
|
||||||
@@ -131,19 +135,9 @@ async def generate_options(
|
|||||||
"theme": context.get("theme", "")
|
"theme": context.get("theme", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
|
# 格式化提示词
|
||||||
# 尝试解析为JSON格式的字典
|
system_prompt = system_template.format(**format_params)
|
||||||
try:
|
user_prompt = user_template.format(**format_params)
|
||||||
prompt_template = json.loads(prompt_template_str)
|
|
||||||
system_prompt = prompt_template["system"].format(**format_params)
|
|
||||||
user_prompt = prompt_template["user"].format(**format_params)
|
|
||||||
except (json.JSONDecodeError, KeyError):
|
|
||||||
# 如果不是JSON格式,降级使用原有方法
|
|
||||||
prompt_template = prompt_service.get_inspiration_prompt(step)
|
|
||||||
if not prompt_template:
|
|
||||||
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
|
|
||||||
system_prompt = prompt_template["system"].format(**format_params)
|
|
||||||
user_prompt = prompt_template["user"].format(**format_params)
|
|
||||||
|
|
||||||
# 如果是重试,在提示词中强调格式要求
|
# 如果是重试,在提示词中强调格式要求
|
||||||
if attempt > 0:
|
if attempt > 0:
|
||||||
@@ -153,13 +147,18 @@ async def generate_options(
|
|||||||
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
|
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
|
||||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||||
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
|
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
|
||||||
response = await ai_service.generate_text(
|
|
||||||
|
# 流式生成并累积文本
|
||||||
|
accumulated_text = ""
|
||||||
|
async for chunk in ai_service.generate_text_stream(
|
||||||
prompt=user_prompt,
|
prompt=user_prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
temperature=temperature
|
temperature=temperature
|
||||||
)
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
content = response.get("content", "")
|
response = {"content": accumulated_text}
|
||||||
|
content = accumulated_text
|
||||||
logger.info(f"AI返回内容长度: {len(content)}")
|
logger.info(f"AI返回内容长度: {len(content)}")
|
||||||
|
|
||||||
# 解析JSON(使用统一的JSON清洗方法)
|
# 解析JSON(使用统一的JSON清洗方法)
|
||||||
@@ -222,6 +221,180 @@ async def generate_options(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refine-options")
|
||||||
|
async def refine_options(
|
||||||
|
data: Dict[str, Any],
|
||||||
|
http_request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
ai_service: AIService = Depends(get_user_ai_service)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
基于用户反馈重新生成选项(支持多轮对话)
|
||||||
|
|
||||||
|
Request:
|
||||||
|
{
|
||||||
|
"step": "title", // 当前步骤
|
||||||
|
"context": {
|
||||||
|
"initial_idea": "...",
|
||||||
|
"title": "...",
|
||||||
|
"description": "...",
|
||||||
|
"theme": "..."
|
||||||
|
},
|
||||||
|
"feedback": "我想要更悲剧一些的主题", // 用户反馈
|
||||||
|
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
|
||||||
|
}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"prompt": "引导语",
|
||||||
|
"options": ["新选项1", "新选项2", ...]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
max_retries = 3
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
step = data.get("step", "title")
|
||||||
|
context = data.get("context", {})
|
||||||
|
feedback = data.get("feedback", "")
|
||||||
|
previous_options = data.get("previous_options", [])
|
||||||
|
|
||||||
|
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||||
|
logger.info(f"用户反馈: {feedback}")
|
||||||
|
|
||||||
|
# 获取用户ID
|
||||||
|
user_id = getattr(http_request.state, 'user_id', None)
|
||||||
|
|
||||||
|
# 获取对应的提示词模板
|
||||||
|
template_key_map = {
|
||||||
|
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||||
|
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||||
|
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||||
|
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||||
|
}
|
||||||
|
template_keys = template_key_map.get(step)
|
||||||
|
|
||||||
|
if not template_keys:
|
||||||
|
return {
|
||||||
|
"error": f"不支持的步骤: {step}",
|
||||||
|
"prompt": "",
|
||||||
|
"options": []
|
||||||
|
}
|
||||||
|
|
||||||
|
system_key, user_key = template_keys
|
||||||
|
|
||||||
|
# 获取自定义提示词模板
|
||||||
|
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||||
|
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||||
|
|
||||||
|
# 准备格式化参数
|
||||||
|
format_params = {
|
||||||
|
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||||
|
"title": context.get("title", ""),
|
||||||
|
"description": context.get("description", ""),
|
||||||
|
"theme": context.get("theme", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 格式化提示词
|
||||||
|
system_prompt = system_template.format(**format_params)
|
||||||
|
user_prompt = user_template.format(**format_params)
|
||||||
|
|
||||||
|
# 添加反馈信息到提示词
|
||||||
|
feedback_instruction = f"""
|
||||||
|
|
||||||
|
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
|
||||||
|
「{feedback}」
|
||||||
|
|
||||||
|
之前生成的选项:
|
||||||
|
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
|
||||||
|
|
||||||
|
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
|
||||||
|
注意:
|
||||||
|
1. 仔细理解用户的反馈意图
|
||||||
|
2. 生成的新选项要明显体现用户要求的调整方向
|
||||||
|
3. 保持与已有上下文的一致性
|
||||||
|
4. 确保返回6个有效选项
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt += feedback_instruction
|
||||||
|
|
||||||
|
# 如果是重试,强调格式要求
|
||||||
|
if attempt > 0:
|
||||||
|
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
|
||||||
|
|
||||||
|
# 调用AI生成选项
|
||||||
|
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||||
|
# 反馈生成时使用稍高的temperature以获得更多样化的结果
|
||||||
|
temperature = min(temperature + 0.1, 0.9)
|
||||||
|
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
|
||||||
|
|
||||||
|
# 流式生成并累积文本
|
||||||
|
accumulated_text = ""
|
||||||
|
async for chunk in ai_service.generate_text_stream(
|
||||||
|
prompt=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature
|
||||||
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
content = accumulated_text
|
||||||
|
logger.info(f"AI返回内容长度: {len(content)}")
|
||||||
|
|
||||||
|
# 解析JSON
|
||||||
|
try:
|
||||||
|
cleaned_content = ai_service._clean_json_response(content)
|
||||||
|
result = json.loads(cleaned_content)
|
||||||
|
|
||||||
|
# 校验返回格式
|
||||||
|
is_valid, error_msg = validate_options_response(result, step)
|
||||||
|
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info("准备重试...")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"prompt": f"请为【{step}】提供内容:",
|
||||||
|
"options": ["让AI重新生成", "我自己输入"],
|
||||||
|
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次"
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"第{attempt + 1}次JSON解析失败: {e}")
|
||||||
|
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info("JSON解析失败,准备重试...")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"prompt": f"请为【{step}】提供内容:",
|
||||||
|
"options": ["让AI重新生成", "我自己输入"],
|
||||||
|
"error": f"AI返回格式错误,已自动重试{max_retries}次"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"第{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logger.info("发生异常,准备重试...")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": str(e),
|
||||||
|
"prompt": "生成失败,请重试",
|
||||||
|
"options": ["重新生成", "我自己输入"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"error": "生成失败",
|
||||||
|
"prompt": "请重试",
|
||||||
|
"options": []
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/quick-generate")
|
@router.post("/quick-generate")
|
||||||
async def quick_generate(
|
async def quick_generate(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
@@ -280,14 +453,17 @@ async def quick_generate(
|
|||||||
# 降级使用原有方法
|
# 降级使用原有方法
|
||||||
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
|
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
|
||||||
|
|
||||||
# 调用AI
|
# 调用AI - 流式生成并累积文本
|
||||||
response = await ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
async for chunk in ai_service.generate_text_stream(
|
||||||
prompt=prompts["user"],
|
prompt=prompts["user"],
|
||||||
system_prompt=prompts["system"],
|
system_prompt=prompts["system"],
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
)
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
content = response.get("content", "")
|
response = {"content": accumulated_text}
|
||||||
|
content = accumulated_text
|
||||||
|
|
||||||
# 解析JSON(使用统一的JSON清洗方法)
|
# 解析JSON(使用统一的JSON清洗方法)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -512,8 +512,29 @@ async def generate_organization_stream(
|
|||||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ai_response = await user_ai_service.generate_text(prompt=prompt)
|
# 使用流式生成替代非流式
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response)
|
ai_content = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||||
|
chunk_count += 1
|
||||||
|
ai_content += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新字数(5-95%,AI生成占90%)
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
progress = min(5 + (chunk_count // 5), 95)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成组织中... ({len(ai_content)}字符)",
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
except Exception as ai_error:
|
except Exception as ai_error:
|
||||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||||
@@ -523,7 +544,7 @@ async def generate_organization_stream(
|
|||||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||||
|
|
||||||
# ✅ 使用统一的 JSON 清洗方法
|
# ✅ 使用统一的 JSON 清洗方法
|
||||||
try:
|
try:
|
||||||
@@ -536,7 +557,7 @@ async def generate_organization_stream(
|
|||||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建组织记录...", 75)
|
yield await SSEResponse.send_progress("创建组织记录...", 97)
|
||||||
|
|
||||||
# 创建角色记录(组织也是角色的一种)
|
# 创建角色记录(组织也是角色的一种)
|
||||||
character = Character(
|
character = Character(
|
||||||
@@ -563,7 +584,7 @@ async def generate_organization_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||||
|
|
||||||
# 自动创建Organization详情记录
|
# 自动创建Organization详情记录
|
||||||
organization = Organization(
|
organization = Organization(
|
||||||
@@ -580,7 +601,7 @@ async def generate_organization_stream(
|
|||||||
|
|
||||||
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||||
|
|
||||||
# 记录生成历史
|
# 记录生成历史
|
||||||
history = GenerationHistory(
|
history = GenerationHistory(
|
||||||
|
|||||||
+89
-30
@@ -470,7 +470,7 @@ async def _generate_new_outline(
|
|||||||
project: Project,
|
project: Project,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_ai_service: AIService,
|
user_ai_service: AIService,
|
||||||
user_id: str = None
|
user_id: str
|
||||||
) -> OutlineListResponse:
|
) -> OutlineListResponse:
|
||||||
"""全新生成大纲(MCP增强版)"""
|
"""全新生成大纲(MCP增强版)"""
|
||||||
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
||||||
@@ -534,7 +534,7 @@ async def _generate_new_outline(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
max_tool_rounds=2,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -573,15 +573,23 @@ async def _generate_new_outline(
|
|||||||
mcp_references=mcp_reference_materials
|
mcp_references=mcp_reference_materials
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI生成大纲
|
# 调用AI流式生成大纲(带字数统计)
|
||||||
ai_response = await user_ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=request.provider,
|
provider=request.provider,
|
||||||
model=request.model
|
model=request.model
|
||||||
)
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 这里是非SSE接口,不需要发送chunk
|
||||||
|
# 如果未来需要转SSE,可以在这里yield
|
||||||
|
|
||||||
# 提取内容(generate_text返回字典)
|
ai_content = accumulated_text
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_response = {"content": ai_content}
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
outline_data = _parse_ai_response(ai_content)
|
outline_data = _parse_ai_response(ai_content)
|
||||||
@@ -732,7 +740,7 @@ async def _continue_outline(
|
|||||||
existing_outlines: List[Outline],
|
existing_outlines: List[Outline],
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
user_ai_service: AIService,
|
user_ai_service: AIService,
|
||||||
user_id: str = "system"
|
user_id: str
|
||||||
) -> OutlineListResponse:
|
) -> OutlineListResponse:
|
||||||
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
|
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
|
||||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
|
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
|
||||||
@@ -1000,7 +1008,7 @@ async def _continue_outline(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -1045,15 +1053,22 @@ async def _continue_outline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI生成当前批次
|
# 调用AI生成当前批次
|
||||||
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
|
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
|
||||||
ai_response = await user_ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=request.provider,
|
provider=request.provider,
|
||||||
model=request.model
|
model=request.model
|
||||||
)
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 这里是非SSE接口,不需要发送chunk
|
||||||
|
|
||||||
# 提取内容(generate_text返回字典)
|
ai_content = accumulated_text
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_response = {"content": ai_content}
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
outline_data = _parse_ai_response(ai_content)
|
outline_data = _parse_ai_response(ai_content)
|
||||||
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
|
|||||||
user_id=user_id_for_mcp,
|
user_id=user_id_for_mcp,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
|
|||||||
mcp_references=mcp_reference_materials
|
mcp_references=mcp_reference_materials
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用AI
|
# 调用AI流式生成
|
||||||
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
|
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
|
||||||
|
|
||||||
# 添加调试日志
|
# 添加调试日志
|
||||||
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
|
|||||||
logger.info(f"=== 大纲生成AI调用参数 ===")
|
logger.info(f"=== 大纲生成AI调用参数 ===")
|
||||||
logger.info(f" provider参数: {provider_param}")
|
logger.info(f" provider参数: {provider_param}")
|
||||||
logger.info(f" model参数: {model_param}")
|
logger.info(f" model参数: {model_param}")
|
||||||
logger.info(f" 完整data: {data}")
|
|
||||||
|
|
||||||
ai_response = await user_ai_service.generate_text(
|
# ✅ 流式生成(带字数统计和进度)
|
||||||
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider_param,
|
provider=provider_param,
|
||||||
model=model_param
|
model=model_param
|
||||||
)
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度和字数(30-95%,AI生成占65%)
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
progress = min(30 + (chunk_count // 2), 95)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"AI生成大纲中... ({len(accumulated_text)}字符)",
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
|
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
|
||||||
|
|
||||||
# 提取内容(generate_text返回字典)
|
ai_content = accumulated_text
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_response = {"content": ai_content}
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
outline_data = _parse_ai_response(ai_content)
|
outline_data = _parse_ai_response(ai_content)
|
||||||
|
|
||||||
# 全新生成模式:删除旧大纲和关联的所有章节
|
# 全新生成模式:删除旧大纲和关联的所有章节
|
||||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
|
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
|
||||||
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})")
|
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})")
|
||||||
|
|
||||||
from sqlalchemy import delete as sql_delete
|
from sqlalchemy import delete as sql_delete
|
||||||
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
|
|||||||
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
|
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
|
||||||
|
|
||||||
# 保存新大纲
|
# 保存新大纲
|
||||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
|
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
|
||||||
outlines = await _save_outlines(
|
outlines = await _save_outlines(
|
||||||
project_id, outline_data, db, start_index=1
|
project_id, outline_data, db, start_index=1
|
||||||
)
|
)
|
||||||
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
|
|||||||
for outline in outlines:
|
for outline in outlines:
|
||||||
await db.refresh(outline)
|
await db.refresh(outline)
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("整理结果数据...", 95)
|
yield await SSEResponse.send_progress("整理结果数据...", 99)
|
||||||
|
|
||||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||||
|
|
||||||
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
|
|||||||
logger.info(f" provider参数: {provider_param}")
|
logger.info(f" provider参数: {provider_param}")
|
||||||
logger.info(f" model参数: {model_param}")
|
logger.info(f" model参数: {model_param}")
|
||||||
|
|
||||||
ai_response = await user_ai_service.generate_text(
|
# 流式生成并累积文本
|
||||||
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider_param,
|
provider=provider_param,
|
||||||
model=model_param
|
model=model_param
|
||||||
)
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度(每批占用约50%的进度空间)
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
# 在批次范围内平滑递增(从10到85,总共75%)
|
||||||
|
batch_range = 60 // total_batches # 总进度60%分配给所有批次
|
||||||
|
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
|
||||||
|
progress_in_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
yield await SSEResponse.send_progress(
|
yield await SSEResponse.send_progress(
|
||||||
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
|
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
|
||||||
batch_progress + 10
|
batch_progress + 10
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提取内容(generate_text返回字典)
|
# 提取内容
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_content = accumulated_text
|
||||||
|
ai_response = {"content": ai_content}
|
||||||
|
|
||||||
# 解析响应
|
# 解析响应
|
||||||
outline_data = _parse_ai_response(ai_content)
|
outline_data = _parse_ai_response(ai_content)
|
||||||
|
|||||||
+25
-11
@@ -73,14 +73,15 @@ async def get_user_ai_service(
|
|||||||
await db.refresh(settings)
|
await db.refresh(settings)
|
||||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||||
|
|
||||||
# 使用用户设置创建AI服务实例
|
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||||
return create_user_ai_service(
|
return create_user_ai_service(
|
||||||
api_provider=settings.api_provider,
|
api_provider=settings.api_provider,
|
||||||
api_key=settings.api_key,
|
api_key=settings.api_key,
|
||||||
api_base_url=settings.api_base_url or "",
|
api_base_url=settings.api_base_url or "",
|
||||||
model_name=settings.llm_model,
|
model_name=settings.llm_model,
|
||||||
temperature=settings.temperature,
|
temperature=settings.temperature,
|
||||||
max_tokens=settings.max_tokens
|
max_tokens=settings.max_tokens,
|
||||||
|
system_prompt=settings.system_prompt # 传递系统提示词
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -271,17 +272,30 @@ async def get_available_models(
|
|||||||
}
|
}
|
||||||
|
|
||||||
elif provider == "anthropic":
|
elif provider == "anthropic":
|
||||||
# Anthropic 没有公开的模型列表API
|
# Anthropic models API
|
||||||
raise HTTPException(
|
url = f"{api_base_url.rstrip('/')}/v1/models"
|
||||||
status_code=400,
|
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
|
||||||
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
response = await client.get(url, headers=headers)
|
||||||
)
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
|
||||||
|
return {"provider": provider, "models": models, "count": len(models)}
|
||||||
|
|
||||||
|
elif provider == "gemini":
|
||||||
|
# Gemini models API
|
||||||
|
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
for m in data.get("models", []):
|
||||||
|
if "generateContent" in m.get("supportedGenerationMethods", []):
|
||||||
|
mid = m.get("name", "").replace("models/", "")
|
||||||
|
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
|
||||||
|
return {"provider": provider, "models": models, "count": len(models)}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
|
||||||
status_code=400,
|
|
||||||
detail=f"不支持的提供商: {provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||||
|
|||||||
+362
-126
@@ -99,7 +99,7 @@ async def world_building_generator(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1,
|
max_tool_rounds=2,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -139,51 +139,118 @@ async def world_building_generator(
|
|||||||
final_prompt = base_prompt
|
final_prompt = base_prompt
|
||||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||||
|
|
||||||
# 流式生成世界观
|
# ===== 流式生成世界观(带重试机制) =====
|
||||||
accumulated_text = ""
|
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||||
chunk_count = 0
|
world_retry_count = 0
|
||||||
|
world_generation_success = False
|
||||||
async for chunk in user_ai_service.generate_text_stream(
|
|
||||||
prompt=final_prompt,
|
|
||||||
provider=provider,
|
|
||||||
model=model
|
|
||||||
):
|
|
||||||
chunk_count += 1
|
|
||||||
accumulated_text += chunk
|
|
||||||
|
|
||||||
# 发送内容块
|
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
|
||||||
|
|
||||||
# 定期更新进度
|
|
||||||
if chunk_count % 5 == 0:
|
|
||||||
progress = min(30 + (chunk_count // 5), 70)
|
|
||||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
|
||||||
|
|
||||||
# 每20个块发送心跳
|
|
||||||
if chunk_count % 20 == 0:
|
|
||||||
yield await SSEResponse.send_heartbeat()
|
|
||||||
|
|
||||||
# 解析结果 - 使用统一的JSON清洗方法
|
|
||||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
|
||||||
|
|
||||||
world_data = {}
|
world_data = {}
|
||||||
try:
|
|
||||||
# ✅ 使用 AIService 的统一清洗方法
|
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
try:
|
||||||
world_data = json.loads(cleaned_text)
|
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||||
logger.info(f"✅ 世界观JSON解析成功")
|
yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||||
|
|
||||||
|
# 流式生成世界观
|
||||||
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
|
prompt=final_prompt,
|
||||||
|
provider=provider,
|
||||||
|
model=model
|
||||||
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
# 发送内容块
|
||||||
logger.error(f"❌ 世界构建JSON解析失败: {e}")
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
|
||||||
world_data = {
|
# 世界观生成独立进度:5-95%
|
||||||
"time_period": "AI返回格式错误,请重试",
|
if chunk_count % 5 == 0:
|
||||||
"location": "AI返回格式错误,请重试",
|
progress = min(5 + (chunk_count // 3), 95)
|
||||||
"atmosphere": "AI返回格式错误,请重试",
|
yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress)
|
||||||
"rules": "AI返回格式错误,请重试"
|
|
||||||
}
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
|
# 检查是否返回空响应
|
||||||
|
if not accumulated_text or not accumulated_text.strip():
|
||||||
|
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ AI返回为空,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 达到最大重试次数,使用默认值
|
||||||
|
logger.error("❌ 世界观生成多次返回空响应")
|
||||||
|
world_data = {
|
||||||
|
"time_period": "AI多次返回为空,请稍后重试",
|
||||||
|
"location": "AI多次返回为空,请稍后重试",
|
||||||
|
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||||
|
"rules": "AI多次返回为空,请稍后重试"
|
||||||
|
}
|
||||||
|
world_generation_success = True # 标记为成功以继续流程
|
||||||
|
break
|
||||||
|
|
||||||
|
# 解析结果 - 使用统一的JSON清洗方法
|
||||||
|
yield await SSEResponse.send_progress("解析世界观数据...", 96)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||||
|
logger.info(f" 原始内容预览: {accumulated_text[:300]}...")
|
||||||
|
|
||||||
|
# ✅ 使用 AIService 的统一清洗方法
|
||||||
|
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||||
|
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||||
|
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
|
||||||
|
|
||||||
|
world_data = json.loads(cleaned_text)
|
||||||
|
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||||
|
world_generation_success = True # 解析成功,标记完成
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||||
|
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||||
|
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ JSON解析失败,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 达到最大重试次数,使用默认值
|
||||||
|
world_data = {
|
||||||
|
"time_period": "AI返回格式错误,请重试",
|
||||||
|
"location": "AI返回格式错误,请重试",
|
||||||
|
"atmosphere": "AI返回格式错误,请重试",
|
||||||
|
"rules": "AI返回格式错误,请重试"
|
||||||
|
}
|
||||||
|
world_generation_success = True # 标记为成功以继续流程
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ 生成异常,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 最后一次重试仍失败,抛出异常
|
||||||
|
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||||
|
raise
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
yield await SSEResponse.send_progress("保存世界观到数据库...", 99)
|
||||||
|
|
||||||
# 确保user_id存在
|
# 确保user_id存在
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -240,41 +307,81 @@ async def world_building_generator(
|
|||||||
project.wizard_step = 1
|
project.wizard_step = 1
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# ===== 自动生成职业体系 =====
|
# ===== 自动生成职业体系(带重试机制+流式) =====
|
||||||
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75)
|
yield await SSEResponse.send_progress("世界观完成!", 100, "success")
|
||||||
|
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5)
|
||||||
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
|
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
|
||||||
|
|
||||||
try:
|
MAX_CAREER_RETRIES = 3 # 最多重试3次
|
||||||
# 获取职业生成提示词模板(支持用户自定义)
|
career_retry_count = 0
|
||||||
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
career_generation_success = False
|
||||||
career_prompt = PromptService.format_prompt(
|
|
||||||
template,
|
while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success:
|
||||||
title=project.title,
|
try:
|
||||||
genre=genre or '未设定',
|
retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else ""
|
||||||
theme=theme or '未设定',
|
yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10)
|
||||||
time_period=world_data.get('time_period', '未设定'),
|
|
||||||
location=world_data.get('location', '未设定'),
|
# 获取职业生成提示词模板(支持用户自定义)
|
||||||
atmosphere=world_data.get('atmosphere', '未设定'),
|
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
||||||
rules=world_data.get('rules', '未设定')
|
career_prompt = PromptService.format_prompt(
|
||||||
)
|
template,
|
||||||
|
title=project.title,
|
||||||
yield await SSEResponse.send_progress("正在生成职业体系...", 78)
|
genre=genre or '未设定',
|
||||||
|
theme=theme or '未设定',
|
||||||
# 调用AI生成职业
|
time_period=world_data.get('time_period', '未设定'),
|
||||||
result = await user_ai_service.generate_text(prompt=career_prompt)
|
location=world_data.get('location', '未设定'),
|
||||||
career_response = result.get('content', '') if isinstance(result, dict) else result
|
atmosphere=world_data.get('atmosphere', '未设定'),
|
||||||
|
rules=world_data.get('rules', '未设定')
|
||||||
if not career_response or not career_response.strip():
|
)
|
||||||
logger.warning("⚠️ AI返回空职业体系,跳过职业生成")
|
|
||||||
yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85)
|
# ✅ 使用流式生成职业体系
|
||||||
else:
|
career_response = ""
|
||||||
yield await SSEResponse.send_progress("解析职业体系数据...", 82)
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
|
prompt=career_prompt,
|
||||||
|
provider=provider,
|
||||||
|
model=model
|
||||||
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
career_response += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 职业体系生成独立进度:10-95%
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
progress = min(10 + (chunk_count // 3), 95)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"生成职业体系中... ({len(career_response)}字符)",
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
|
if not career_response or not career_response.strip():
|
||||||
|
logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||||
|
career_retry_count += 1
|
||||||
|
if career_retry_count < MAX_CAREER_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ AI返回为空,准备重试...",
|
||||||
|
10,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99)
|
||||||
|
break
|
||||||
|
|
||||||
|
yield await SSEResponse.send_progress("解析职业体系数据...", 96)
|
||||||
|
|
||||||
# 清洗并解析JSON
|
# 清洗并解析JSON
|
||||||
try:
|
try:
|
||||||
cleaned_response = user_ai_service._clean_json_response(career_response)
|
cleaned_response = user_ai_service._clean_json_response(career_response)
|
||||||
career_data = json.loads(cleaned_response)
|
career_data = json.loads(cleaned_response)
|
||||||
logger.info(f"✅ 职业体系JSON解析成功")
|
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||||
|
|
||||||
# 保存主职业
|
# 保存主职业
|
||||||
main_careers_created = []
|
main_careers_created = []
|
||||||
@@ -338,22 +445,51 @@ async def world_building_generator(
|
|||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
# 标记成功
|
||||||
|
career_generation_success = True
|
||||||
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个")
|
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个")
|
||||||
yield await SSEResponse.send_progress(
|
yield await SSEResponse.send_progress(
|
||||||
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})",
|
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})",
|
||||||
90
|
99
|
||||||
)
|
)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||||
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85)
|
career_retry_count += 1
|
||||||
|
if career_retry_count < MAX_CAREER_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ JSON解析失败,准备重试...",
|
||||||
|
10,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 职业体系保存失败: {e}")
|
logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||||
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85)
|
career_retry_count += 1
|
||||||
|
if career_retry_count < MAX_CAREER_RETRIES:
|
||||||
except Exception as e:
|
yield await SSEResponse.send_progress(
|
||||||
logger.error(f"❌ 职业体系生成异常: {e}")
|
f"⚠️ 保存失败,准备重试...",
|
||||||
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85)
|
10,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||||
|
career_retry_count += 1
|
||||||
|
if career_retry_count < MAX_CAREER_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ 生成异常,准备重试...",
|
||||||
|
10,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99)
|
||||||
|
|
||||||
db_committed = True
|
db_committed = True
|
||||||
|
|
||||||
@@ -366,7 +502,8 @@ async def world_building_generator(
|
|||||||
"rules": world_data.get("rules")
|
"rules": world_data.get("rules")
|
||||||
})
|
})
|
||||||
|
|
||||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
yield await SSEResponse.send_progress("职业体系完成!", 100, "success")
|
||||||
|
yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success")
|
||||||
yield await SSEResponse.send_done()
|
yield await SSEResponse.send_done()
|
||||||
|
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
@@ -473,7 +610,7 @@ async def characters_generator(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
|
max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -611,15 +748,32 @@ async def characters_generator(
|
|||||||
else:
|
else:
|
||||||
prompt = base_prompt
|
prompt = base_prompt
|
||||||
|
|
||||||
# 流式生成
|
# 流式生成(带字数统计)
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
async for chunk in user_ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
):
|
):
|
||||||
|
chunk_count += 1
|
||||||
accumulated_text += chunk
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度和字数
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"生成角色中... ({len(accumulated_text)}字符)",
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
# 解析批次结果 - 使用统一的JSON清洗方法
|
# 解析批次结果 - 使用统一的JSON清洗方法
|
||||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||||
@@ -1184,18 +1338,35 @@ async def outline_generator(
|
|||||||
requirements=outline_requirements
|
requirements=outline_requirements
|
||||||
)
|
)
|
||||||
|
|
||||||
# 流式生成大纲
|
# 流式生成大纲(带字数统计)
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
async for chunk in user_ai_service.generate_text_stream(
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
prompt=outline_prompt,
|
prompt=outline_prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
):
|
):
|
||||||
|
chunk_count += 1
|
||||||
accumulated_text += chunk
|
accumulated_text += chunk
|
||||||
|
|
||||||
|
# 发送内容块
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
|
|
||||||
|
# 定期更新进度和字数(5-95%,AI生成占90%)
|
||||||
|
if chunk_count % 5 == 0:
|
||||||
|
progress = min(5 + (chunk_count // 3), 95)
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"生成大纲中... ({len(accumulated_text)}字符)",
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每20个块发送心跳
|
||||||
|
if chunk_count % 20 == 0:
|
||||||
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
# 解析大纲结果 - 使用统一的JSON清洗方法
|
# 解析大纲结果 - 使用统一的JSON清洗方法
|
||||||
yield await SSEResponse.send_progress("解析大纲...", 40)
|
yield await SSEResponse.send_progress("解析大纲...", 96)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||||
@@ -1208,7 +1379,7 @@ async def outline_generator(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 保存大纲到数据库
|
# 保存大纲到数据库
|
||||||
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
|
yield await SSEResponse.send_progress("保存大纲到数据库...", 97)
|
||||||
created_outlines = []
|
created_outlines = []
|
||||||
for index, outline_item in enumerate(outline_data[:outline_count], 1):
|
for index, outline_item in enumerate(outline_data[:outline_count], 1):
|
||||||
outline = Outline(
|
outline = Outline(
|
||||||
@@ -1231,7 +1402,7 @@ async def outline_generator(
|
|||||||
created_chapters = []
|
created_chapters = []
|
||||||
if project.outline_mode == 'one-to-one':
|
if project.outline_mode == 'one-to-one':
|
||||||
# 一对一模式:自动为每个大纲创建对应的章节
|
# 一对一模式:自动为每个大纲创建对应的章节
|
||||||
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
|
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98)
|
||||||
|
|
||||||
for outline in created_outlines:
|
for outline in created_outlines:
|
||||||
chapter = Chapter(
|
chapter = Chapter(
|
||||||
@@ -1250,10 +1421,10 @@ async def outline_generator(
|
|||||||
await db.refresh(chapter)
|
await db.refresh(chapter)
|
||||||
|
|
||||||
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
|
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
|
||||||
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
|
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99)
|
||||||
else:
|
else:
|
||||||
# 一对多模式:跳过自动创建,用户可手动展开
|
# 一对多模式:跳过自动创建,用户可手动展开
|
||||||
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
|
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99)
|
||||||
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
|
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
|
||||||
|
|
||||||
# 更新项目信息
|
# 更新项目信息
|
||||||
@@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1,
|
max_tool_rounds=2,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
provider=None,
|
provider=None,
|
||||||
model=None
|
model=None
|
||||||
@@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator(
|
|||||||
final_prompt = base_prompt
|
final_prompt = base_prompt
|
||||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||||
|
|
||||||
# 流式生成世界观
|
# ===== 流式生成世界观(带重试机制) =====
|
||||||
accumulated_text = ""
|
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||||
chunk_count = 0
|
world_retry_count = 0
|
||||||
|
world_generation_success = False
|
||||||
async for chunk in user_ai_service.generate_text_stream(
|
|
||||||
prompt=final_prompt,
|
|
||||||
provider=provider,
|
|
||||||
model=model
|
|
||||||
):
|
|
||||||
chunk_count += 1
|
|
||||||
accumulated_text += chunk
|
|
||||||
|
|
||||||
yield await SSEResponse.send_chunk(chunk)
|
|
||||||
|
|
||||||
if chunk_count % 5 == 0:
|
|
||||||
progress = min(30 + (chunk_count // 5), 70)
|
|
||||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
|
||||||
|
|
||||||
if chunk_count % 20 == 0:
|
|
||||||
yield await SSEResponse.send_heartbeat()
|
|
||||||
|
|
||||||
# 解析结果 - 使用统一的JSON清洗方法
|
|
||||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
|
||||||
|
|
||||||
world_data = {}
|
world_data = {}
|
||||||
try:
|
|
||||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||||
world_data = json.loads(cleaned_text)
|
try:
|
||||||
logger.info(f"✅ 世界观重新生成JSON解析成功")
|
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||||
|
yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||||
|
|
||||||
|
# 流式生成世界观
|
||||||
|
accumulated_text = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
async for chunk in user_ai_service.generate_text_stream(
|
||||||
|
prompt=final_prompt,
|
||||||
|
provider=provider,
|
||||||
|
model=model
|
||||||
|
):
|
||||||
|
chunk_count += 1
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
yield await SSEResponse.send_chunk(chunk)
|
||||||
logger.error(f"世界构建JSON解析失败: {e}")
|
|
||||||
world_data = {
|
if chunk_count % 5 == 0:
|
||||||
"time_period": "AI返回格式错误,请重试",
|
progress = min(30 + (chunk_count // 5), 85)
|
||||||
"location": "AI返回格式错误,请重试",
|
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||||
"atmosphere": "AI返回格式错误,请重试",
|
|
||||||
"rules": "AI返回格式错误,请重试"
|
if chunk_count % 20 == 0:
|
||||||
}
|
yield await SSEResponse.send_heartbeat()
|
||||||
|
|
||||||
|
# 检查是否返回空响应
|
||||||
|
if not accumulated_text or not accumulated_text.strip():
|
||||||
|
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ AI返回为空,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 达到最大重试次数,使用默认值
|
||||||
|
logger.error("❌ 世界观重新生成多次返回空响应")
|
||||||
|
world_data = {
|
||||||
|
"time_period": "AI多次返回为空,请稍后重试",
|
||||||
|
"location": "AI多次返回为空,请稍后重试",
|
||||||
|
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||||
|
"rules": "AI多次返回为空,请稍后重试"
|
||||||
|
}
|
||||||
|
world_generation_success = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# 解析结果 - 使用统一的JSON清洗方法
|
||||||
|
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||||
|
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||||
|
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||||
|
|
||||||
|
world_data = json.loads(cleaned_text)
|
||||||
|
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||||
|
world_generation_success = True
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||||
|
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||||
|
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ JSON解析失败,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 达到最大重试次数,使用默认值
|
||||||
|
world_data = {
|
||||||
|
"time_period": "AI返回格式错误,请重试",
|
||||||
|
"location": "AI返回格式错误,请重试",
|
||||||
|
"atmosphere": "AI返回格式错误,请重试",
|
||||||
|
"rules": "AI返回格式错误,请重试"
|
||||||
|
}
|
||||||
|
world_generation_success = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||||
|
world_retry_count += 1
|
||||||
|
if world_retry_count < MAX_WORLD_RETRIES:
|
||||||
|
yield await SSEResponse.send_progress(
|
||||||
|
f"⚠️ 生成异常,准备重试...",
|
||||||
|
30 + world_retry_count * 5,
|
||||||
|
"warning"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 最后一次重试仍失败,抛出异常
|
||||||
|
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||||
|
raise
|
||||||
|
|
||||||
# 不保存到数据库,仅返回生成结果供用户预览
|
# 不保存到数据库,仅返回生成结果供用户预览
|
||||||
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
|
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from ..schemas.writing_style import (
|
|||||||
WritingStyleListResponse,
|
WritingStyleListResponse,
|
||||||
SetDefaultStyleRequest
|
SetDefaultStyleRequest
|
||||||
)
|
)
|
||||||
from ..services.prompt_service import WritingStyleManager
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
|
|
||||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||||
@@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/presets/list", response_model=List[dict])
|
@router.get("/presets/list", response_model=List[dict])
|
||||||
async def get_preset_styles():
|
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
|
||||||
"""
|
"""
|
||||||
获取所有预设风格列表
|
获取所有预设风格列表(从数据库读取)
|
||||||
|
|
||||||
返回格式:数组形式的预设风格列表
|
返回格式:数组形式的预设风格列表
|
||||||
[
|
[
|
||||||
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||||
{"id": "classical", "name": "古典优雅", ...}
|
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
presets = WritingStyleManager.get_all_presets()
|
# 从数据库获取全局预设风格(user_id 为 NULL)
|
||||||
# 将字典转换为数组,添加 id 字段
|
result = await db.execute(
|
||||||
|
select(WritingStyle)
|
||||||
|
.where(WritingStyle.user_id.is_(None))
|
||||||
|
.order_by(WritingStyle.order_index)
|
||||||
|
)
|
||||||
|
preset_styles = result.scalars().all()
|
||||||
|
|
||||||
|
# 转换为响应格式
|
||||||
return [
|
return [
|
||||||
{"id": preset_id, **preset_data}
|
{
|
||||||
for preset_id, preset_data in presets.items()
|
"id": style.id,
|
||||||
|
"preset_id": style.preset_id,
|
||||||
|
"name": style.name,
|
||||||
|
"description": style.description,
|
||||||
|
"prompt_content": style.prompt_content,
|
||||||
|
"style_type": style.style_type,
|
||||||
|
"order_index": style.order_index
|
||||||
|
}
|
||||||
|
for style in preset_styles
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -58,25 +72,33 @@ async def create_writing_style(
|
|||||||
"""
|
"""
|
||||||
创建新的写作风格(用户级别)
|
创建新的写作风格(用户级别)
|
||||||
|
|
||||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
|
||||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||||
"""
|
"""
|
||||||
# 获取当前用户ID
|
# 获取当前用户ID
|
||||||
user_id = get_current_user_id(request)
|
user_id = get_current_user_id(request)
|
||||||
|
|
||||||
# 如果基于预设创建,获取预设内容
|
# 如果基于预设创建,从数据库获取预设内容
|
||||||
if style_data.preset_id:
|
if style_data.preset_id:
|
||||||
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
|
result = await db.execute(
|
||||||
|
select(WritingStyle)
|
||||||
|
.where(
|
||||||
|
WritingStyle.user_id.is_(None),
|
||||||
|
WritingStyle.preset_id == style_data.preset_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
preset = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not preset:
|
if not preset:
|
||||||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||||||
|
|
||||||
# 使用预设内容填充(如果用户未提供)
|
# 使用预设内容填充(如果用户未提供)
|
||||||
if not style_data.name:
|
if not style_data.name:
|
||||||
style_data.name = preset["name"]
|
style_data.name = preset.name
|
||||||
if not style_data.description:
|
if not style_data.description:
|
||||||
style_data.description = preset["description"]
|
style_data.description = preset.description
|
||||||
if not style_data.prompt_content:
|
if not style_data.prompt_content:
|
||||||
style_data.prompt_content = preset["prompt_content"]
|
style_data.prompt_content = preset.prompt_content
|
||||||
|
|
||||||
# 验证必填字段
|
# 验证必填字段
|
||||||
if not style_data.name or not style_data.prompt_content:
|
if not style_data.name or not style_data.prompt_content:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
|||||||
from mcp import ClientSession, types
|
from mcp import ClientSession, types
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from pydantic import AnyUrl
|
from pydantic import AnyUrl
|
||||||
|
from anyio import ClosedResourceError
|
||||||
|
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
|
||||||
@@ -141,51 +142,89 @@ class HTTPMCPClient:
|
|||||||
async def call_tool(
|
async def call_tool(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
arguments: Dict[str, Any]
|
arguments: Dict[str, Any],
|
||||||
|
max_reconnect_attempts: int = 2
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
调用工具
|
调用工具(带自动重连)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: 工具名称
|
tool_name: 工具名称
|
||||||
arguments: 工具参数
|
arguments: 工具参数
|
||||||
|
max_reconnect_attempts: 最大重连尝试次数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工具执行结果
|
工具执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
for attempt in range(max_reconnect_attempts + 1):
|
||||||
await self._ensure_connected()
|
try:
|
||||||
|
await self._ensure_connected()
|
||||||
logger.info(f"调用工具: {tool_name}")
|
|
||||||
logger.debug(f"参数: {arguments}")
|
logger.info(f"调用工具: {tool_name}")
|
||||||
|
logger.debug(f" 参数类型: {type(arguments)}")
|
||||||
result = await self._session.call_tool(tool_name, arguments)
|
logger.debug(f" 参数内容: {arguments}")
|
||||||
|
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
|
||||||
# 处理返回结果
|
|
||||||
# MCP SDK 返回 CallToolResult 对象
|
result = await self._session.call_tool(tool_name, arguments)
|
||||||
if result.content:
|
|
||||||
# 提取第一个content的文本
|
logger.debug(f" 工具返回类型: {type(result)}")
|
||||||
for content in result.content:
|
logger.debug(f" 返回内容: {result}")
|
||||||
if isinstance(content, types.TextContent):
|
|
||||||
return content.text
|
# 处理返回结果
|
||||||
elif isinstance(content, types.ImageContent):
|
# MCP SDK 返回 CallToolResult 对象
|
||||||
return {
|
if result.content:
|
||||||
"type": "image",
|
logger.debug(f" 返回content数量: {len(result.content)}")
|
||||||
"data": content.data,
|
# 提取第一个content的文本
|
||||||
"mimeType": content.mimeType
|
for idx, content in enumerate(result.content):
|
||||||
}
|
logger.debug(f" content[{idx}]类型: {type(content)}")
|
||||||
# 如果没有文本内容,返回原始内容
|
if isinstance(content, types.TextContent):
|
||||||
return result.content[0] if result.content else None
|
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
|
||||||
|
return content.text
|
||||||
# 如果有结构化内容(2025-06-18规范)
|
elif isinstance(content, types.ImageContent):
|
||||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
logger.debug(f" ✅ 返回ImageContent")
|
||||||
return result.structuredContent
|
return {
|
||||||
|
"type": "image",
|
||||||
return None
|
"data": content.data,
|
||||||
|
"mimeType": content.mimeType
|
||||||
except Exception as e:
|
}
|
||||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
|
# 如果没有文本内容,返回原始内容
|
||||||
raise MCPError(f"调用工具失败: {str(e)}")
|
logger.debug(f" ⚠️ 返回原始content[0]")
|
||||||
|
return result.content[0] if result.content else None
|
||||||
|
|
||||||
|
# 如果有结构化内容(2025-06-18规范)
|
||||||
|
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||||
|
logger.debug(f" ✅ 返回structuredContent")
|
||||||
|
return result.structuredContent
|
||||||
|
|
||||||
|
logger.warning(f" ⚠️ 工具返回为None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except ClosedResourceError as e:
|
||||||
|
# 连接已关闭,尝试重连
|
||||||
|
if attempt < max_reconnect_attempts:
|
||||||
|
logger.warning(
|
||||||
|
f"⚠️ MCP连接已关闭,尝试重新连接 "
|
||||||
|
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
|
||||||
|
)
|
||||||
|
await self._cleanup()
|
||||||
|
await asyncio.sleep(0.5) # 短暂延迟后重连
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
|
||||||
|
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
|
||||||
|
raise MCPError(error_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||||
|
logger.error(f" 参数: {arguments}")
|
||||||
|
logger.error(f" 错误类型: {type(e).__name__}")
|
||||||
|
logger.error(f" 错误详情: {repr(e)}")
|
||||||
|
logger.error(f" 错误字符串: '{str(e)}'")
|
||||||
|
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
|
||||||
|
raise MCPError(f"调用工具失败: {error_msg}")
|
||||||
|
|
||||||
|
# 理论上不会到这里
|
||||||
|
raise MCPError(f"工具调用失败: 未知错误")
|
||||||
|
|
||||||
async def list_resources(self) -> List[Dict[str, Any]]:
|
async def list_resources(self) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class Settings(Base):
|
|||||||
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
|
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
|
||||||
temperature = Column(Float, default=0.7, comment="温度参数")
|
temperature = Column(Float, default=0.7, comment="温度参数")
|
||||||
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
||||||
|
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
|
||||||
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
||||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class SettingsBase(BaseModel):
|
|||||||
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
|
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
|
||||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||||
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
|
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
|
||||||
|
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
|
||||||
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""AI 客户端模块"""
|
||||||
|
from .base_client import BaseAIClient
|
||||||
|
from .openai_client import OpenAIClient
|
||||||
|
from .anthropic_client import AnthropicClient
|
||||||
|
|
||||||
|
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""Anthropic 客户端"""
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
|
||||||
|
from app.logger import get_logger
|
||||||
|
from app.services.ai_config import AIClientConfig, default_config
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicClient:
|
||||||
|
"""Anthropic API 客户端"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||||
|
self.config = config or default_config
|
||||||
|
kwargs = {"api_key": api_key}
|
||||||
|
if base_url:
|
||||||
|
kwargs["base_url"] = base_url
|
||||||
|
self.client = AsyncAnthropic(**kwargs)
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
kwargs = {
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
if system_prompt:
|
||||||
|
kwargs["system"] = system_prompt
|
||||||
|
if tools:
|
||||||
|
kwargs["tools"] = tools
|
||||||
|
if tool_choice == "required":
|
||||||
|
kwargs["tool_choice"] = {"type": "any"}
|
||||||
|
elif tool_choice == "auto":
|
||||||
|
kwargs["tool_choice"] = {"type": "auto"}
|
||||||
|
|
||||||
|
response = await self.client.messages.create(**kwargs)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
content = ""
|
||||||
|
for block in response.content:
|
||||||
|
if block.type == "tool_use":
|
||||||
|
tool_calls.append({
|
||||||
|
"id": block.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": block.name, "arguments": block.input},
|
||||||
|
})
|
||||||
|
elif block.type == "text":
|
||||||
|
content += block.text
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"tool_calls": tool_calls if tool_calls else None,
|
||||||
|
"finish_reason": response.stop_reason,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def chat_completion_stream(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
kwargs = {
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
if system_prompt:
|
||||||
|
kwargs["system"] = system_prompt
|
||||||
|
|
||||||
|
async with self.client.messages.stream(**kwargs) as stream:
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
yield text
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""AI 客户端基类"""
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.logger import get_logger
|
||||||
|
from app.services.ai_config import AIClientConfig, default_config
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# 全局 HTTP 客户端池
|
||||||
|
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||||
|
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||||
|
"""获取全局信号量"""
|
||||||
|
global _global_semaphore
|
||||||
|
if _global_semaphore is None:
|
||||||
|
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
return _global_semaphore
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAIClient(ABC):
|
||||||
|
"""AI HTTP 客户端基类"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str,
|
||||||
|
config: Optional[AIClientConfig] = None,
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.config = config or default_config
|
||||||
|
self.http_client = self._get_or_create_client()
|
||||||
|
|
||||||
|
def _get_client_key(self) -> str:
|
||||||
|
"""生成客户端唯一键"""
|
||||||
|
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||||
|
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||||
|
|
||||||
|
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||||
|
"""获取或创建 HTTP 客户端"""
|
||||||
|
client_key = self._get_client_key()
|
||||||
|
|
||||||
|
if client_key in _http_client_pool:
|
||||||
|
client = _http_client_pool[client_key]
|
||||||
|
if not client.is_closed:
|
||||||
|
return client
|
||||||
|
del _http_client_pool[client_key]
|
||||||
|
|
||||||
|
http_cfg = self.config.http
|
||||||
|
client = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(
|
||||||
|
connect=http_cfg.connect_timeout,
|
||||||
|
read=http_cfg.read_timeout,
|
||||||
|
write=http_cfg.write_timeout,
|
||||||
|
pool=http_cfg.pool_timeout,
|
||||||
|
),
|
||||||
|
limits=httpx.Limits(
|
||||||
|
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||||
|
max_connections=http_cfg.max_connections,
|
||||||
|
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_http_client_pool[client_key] = client
|
||||||
|
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||||
|
return client
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
"""构建请求头"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _request_with_retry(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Any:
|
||||||
|
"""带重试的 HTTP 请求"""
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
headers = self._build_headers()
|
||||||
|
retry_cfg = self.config.retry
|
||||||
|
rate_cfg = self.config.rate_limit
|
||||||
|
|
||||||
|
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
await asyncio.sleep(rate_cfg.request_delay)
|
||||||
|
|
||||||
|
for attempt in range(retry_cfg.max_retries):
|
||||||
|
try:
|
||||||
|
if attempt > 0:
|
||||||
|
delay = min(
|
||||||
|
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||||
|
retry_cfg.max_delay,
|
||||||
|
)
|
||||||
|
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||||
|
|
||||||
|
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||||
|
raise
|
||||||
|
if attempt == retry_cfg.max_retries - 1:
|
||||||
|
raise
|
||||||
|
except (httpx.ConnectError, httpx.TimeoutException):
|
||||||
|
if attempt == retry_cfg.max_retries - 1:
|
||||||
|
raise
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""聊天补全"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_completion_stream(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""流式聊天补全"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_all_clients():
|
||||||
|
"""清理所有 HTTP 客户端"""
|
||||||
|
for key, client in list(_http_client_pool.items()):
|
||||||
|
if not client.is_closed:
|
||||||
|
await client.aclose()
|
||||||
|
_http_client_pool.clear()
|
||||||
|
logger.info("✅ HTTP 客户端池已清理")
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""Gemini 客户端"""
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
import httpx
|
||||||
|
from app.services.ai_config import AIClientConfig, default_config
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiClient:
|
||||||
|
"""Google Gemini API 客户端"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||||
|
self.config = config or default_config
|
||||||
|
http_cfg = self.config.http
|
||||||
|
self.client = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(
|
||||||
|
connect=http_cfg.connect_timeout,
|
||||||
|
read=http_cfg.read_timeout,
|
||||||
|
write=http_cfg.write_timeout,
|
||||||
|
pool=http_cfg.pool_timeout
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||||
|
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||||
|
gemini_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
if tool.get("type") == "function":
|
||||||
|
func = tool["function"]
|
||||||
|
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||||
|
params.pop("$schema", None)
|
||||||
|
params.pop("additionalProperties", None)
|
||||||
|
if params and "type" not in params:
|
||||||
|
params["type"] = "object"
|
||||||
|
decl = {
|
||||||
|
"name": func["name"],
|
||||||
|
"description": func.get("description") or func["name"],
|
||||||
|
}
|
||||||
|
if params:
|
||||||
|
decl["parameters"] = params
|
||||||
|
gemini_tools.append(decl)
|
||||||
|
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
contents = []
|
||||||
|
for msg in messages:
|
||||||
|
role = "user" if msg["role"] == "user" else "model"
|
||||||
|
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||||
|
}
|
||||||
|
if system_prompt:
|
||||||
|
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||||
|
if tools:
|
||||||
|
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||||
|
|
||||||
|
response = await self.client.post(url, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if not candidates or len(candidates) == 0:
|
||||||
|
# 返回空内容而不是报错,保持流程继续
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": None,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
parts = candidates[0].get("content", {}).get("parts", [])
|
||||||
|
text = ""
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if "text" in part:
|
||||||
|
text += part["text"]
|
||||||
|
elif "functionCall" in part:
|
||||||
|
fc = part["functionCall"]
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{fc['name']}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": text,
|
||||||
|
"tool_calls": tool_calls if tool_calls else None,
|
||||||
|
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def chat_completion_stream(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||||
|
|
||||||
|
contents = []
|
||||||
|
for msg in messages:
|
||||||
|
role = "user" if msg["role"] == "user" else "model"
|
||||||
|
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||||
|
}
|
||||||
|
if system_prompt:
|
||||||
|
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||||
|
|
||||||
|
async with self.client.stream("POST", url, json=payload) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
data = json.loads(line[6:])
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if candidates and len(candidates) > 0:
|
||||||
|
parts = candidates[0].get("content", {}).get("parts", [])
|
||||||
|
if parts and len(parts) > 0:
|
||||||
|
text = parts[0].get("text", "")
|
||||||
|
if text:
|
||||||
|
yield text
|
||||||
|
except:
|
||||||
|
continue
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
"""OpenAI 客户端"""
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
|
from app.logger import get_logger
|
||||||
|
from .base_client import BaseAIClient
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIClient(BaseAIClient):
|
||||||
|
"""OpenAI API 客户端"""
|
||||||
|
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_payload(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
payload["stream"] = True
|
||||||
|
if tools:
|
||||||
|
# 清理 $schema 字段
|
||||||
|
cleaned = []
|
||||||
|
for t in tools:
|
||||||
|
tc = t.copy()
|
||||||
|
if "function" in tc and "parameters" in tc["function"]:
|
||||||
|
tc["function"]["parameters"] = {
|
||||||
|
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
|
||||||
|
}
|
||||||
|
cleaned.append(tc)
|
||||||
|
payload["tools"] = cleaned
|
||||||
|
if tool_choice:
|
||||||
|
payload["tool_choice"] = tool_choice
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
|
||||||
|
data = await self._request_with_retry("POST", "/chat/completions", payload)
|
||||||
|
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if not choices or len(choices) == 0:
|
||||||
|
raise ValueError("API 返回空 choices 或 choices 为空列表")
|
||||||
|
|
||||||
|
choice = choices[0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
return {
|
||||||
|
"content": message.get("content", ""),
|
||||||
|
"tool_calls": message.get("tool_calls"),
|
||||||
|
"finish_reason": choice.get("finish_reason"),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def chat_completion_stream(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||||
|
|
||||||
|
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if choices and len(choices) > 0:
|
||||||
|
content = choices[0].get("delta", {}).get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""AI 服务配置管理"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HTTPClientConfig:
|
||||||
|
"""HTTP 客户端配置"""
|
||||||
|
connect_timeout: float = 90.0
|
||||||
|
read_timeout: float = 300.0
|
||||||
|
write_timeout: float = 90.0
|
||||||
|
pool_timeout: float = 90.0
|
||||||
|
max_keepalive_connections: int = 50
|
||||||
|
max_connections: int = 100
|
||||||
|
keepalive_expiry: float = 60.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetryConfig:
|
||||||
|
"""重试配置"""
|
||||||
|
max_retries: int = 3
|
||||||
|
base_delay: float = 0.2
|
||||||
|
max_delay: float = 10.0
|
||||||
|
exponential_base: int = 2
|
||||||
|
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitConfig:
|
||||||
|
"""限流配置"""
|
||||||
|
max_concurrent_requests: int = 5
|
||||||
|
request_delay: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AIClientConfig:
|
||||||
|
"""AI 客户端完整配置"""
|
||||||
|
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
|
||||||
|
retry: RetryConfig = field(default_factory=RetryConfig)
|
||||||
|
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局默认配置
|
||||||
|
default_config = AIClientConfig()
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""AI Provider 模块"""
|
||||||
|
from .base_provider import BaseAIProvider
|
||||||
|
from .openai_provider import OpenAIProvider
|
||||||
|
from .anthropic_provider import AnthropicProvider
|
||||||
|
|
||||||
|
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
"""Anthropic Provider"""
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||||
|
from .base_provider import BaseAIProvider
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(BaseAIProvider):
|
||||||
|
"""Anthropic 提供商"""
|
||||||
|
|
||||||
|
def __init__(self, client: AnthropicClient):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return await self.client.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
async for chunk in self.client.chat_completion_stream(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""AI Provider 基类"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAIProvider(ABC):
|
||||||
|
"""AI 提供商抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""生成文本"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""流式生成"""
|
||||||
|
pass
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""Gemini Provider"""
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
from app.services.ai_clients.gemini_client import GeminiClient
|
||||||
|
from .base_provider import BaseAIProvider
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiProvider(BaseAIProvider):
|
||||||
|
def __init__(self, client: GeminiClient):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return await self.client.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
async for chunk in self.client.chat_completion_stream(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""OpenAI Provider"""
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.services.ai_clients.openai_client import OpenAIClient
|
||||||
|
from .base_provider import BaseAIProvider
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseAIProvider):
|
||||||
|
"""OpenAI 提供商"""
|
||||||
|
|
||||||
|
def __init__(self, client: OpenAIClient):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
tool_choice: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
return await self.client.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
async for chunk in self.client.chat_completion_stream(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
+177
-1295
File diff suppressed because it is too large
Load Diff
@@ -263,7 +263,7 @@ class AutoCharacterService:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1
|
max_tool_rounds=2
|
||||||
)
|
)
|
||||||
content = result.get("content", "")
|
content = result.get("content", "")
|
||||||
# 使用统一的JSON清洗方法
|
# 使用统一的JSON清洗方法
|
||||||
@@ -362,7 +362,7 @@ class AutoCharacterService:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
db_session=db,
|
db_session=db,
|
||||||
enable_mcp=True,
|
enable_mcp=True,
|
||||||
max_tool_rounds=1
|
max_tool_rounds=2
|
||||||
)
|
)
|
||||||
content = result.get("content", "")
|
content = result.get("content", "")
|
||||||
# 使用统一的JSON清洗方法
|
# 使用统一的JSON清洗方法
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
"""JSON 处理工具类"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
from app.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def clean_json_response(text: str) -> str:
|
||||||
|
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||||
|
try:
|
||||||
|
if not text:
|
||||||
|
logger.warning("⚠️ clean_json_response: 输入为空")
|
||||||
|
return text
|
||||||
|
|
||||||
|
original_length = len(text)
|
||||||
|
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||||
|
|
||||||
|
# 去除 markdown 代码块
|
||||||
|
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
|
||||||
|
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
|
||||||
|
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
if len(text) != original_length:
|
||||||
|
logger.debug(f" 移除markdown后长度: {len(text)}")
|
||||||
|
|
||||||
|
# 尝试直接解析(快速路径)
|
||||||
|
try:
|
||||||
|
json.loads(text)
|
||||||
|
logger.debug(f"✅ 直接解析成功,无需清洗")
|
||||||
|
return text
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 找到第一个 { 或 [
|
||||||
|
start = -1
|
||||||
|
for i, c in enumerate(text):
|
||||||
|
if c in ('{', '['):
|
||||||
|
start = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if start == -1:
|
||||||
|
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
|
||||||
|
logger.debug(f" 文本预览: {text[:200]}")
|
||||||
|
return text
|
||||||
|
|
||||||
|
if start > 0:
|
||||||
|
logger.debug(f" 跳过前{start}个字符")
|
||||||
|
text = text[start:]
|
||||||
|
|
||||||
|
# 改进的括号匹配算法(更严格的字符串处理)
|
||||||
|
stack = []
|
||||||
|
i = 0
|
||||||
|
end = -1
|
||||||
|
|
||||||
|
while i < len(text):
|
||||||
|
c = text[i]
|
||||||
|
|
||||||
|
# 处理字符串(关键:正确处理转义)
|
||||||
|
if c == '"':
|
||||||
|
# 计算前面有多少个连续的反斜杠
|
||||||
|
num_backslashes = 0
|
||||||
|
j = i - 1
|
||||||
|
while j >= 0 and text[j] == '\\':
|
||||||
|
num_backslashes += 1
|
||||||
|
j -= 1
|
||||||
|
|
||||||
|
# 偶数个反斜杠(包括0)表示引号未被转义
|
||||||
|
if num_backslashes % 2 == 0:
|
||||||
|
# 这是字符串边界,跳过整个字符串
|
||||||
|
i += 1
|
||||||
|
while i < len(text):
|
||||||
|
if text[i] == '"':
|
||||||
|
# 再次检查转义
|
||||||
|
num_backslashes = 0
|
||||||
|
j = i - 1
|
||||||
|
while j >= 0 and text[j] == '\\':
|
||||||
|
num_backslashes += 1
|
||||||
|
j -= 1
|
||||||
|
if num_backslashes % 2 == 0:
|
||||||
|
# 字符串结束
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理括号(只有在字符串外部才有效)
|
||||||
|
if c == '{' or c == '[':
|
||||||
|
stack.append(c)
|
||||||
|
elif c == '}':
|
||||||
|
if len(stack) > 0 and stack[-1] == '{':
|
||||||
|
stack.pop()
|
||||||
|
if len(stack) == 0:
|
||||||
|
end = i + 1
|
||||||
|
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||||
|
elif c == ']':
|
||||||
|
if len(stack) > 0 and stack[-1] == '[':
|
||||||
|
stack.pop()
|
||||||
|
if len(stack) == 0:
|
||||||
|
end = i + 1
|
||||||
|
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# 提取结果
|
||||||
|
if end > 0:
|
||||||
|
result = text[:end]
|
||||||
|
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
|
||||||
|
else:
|
||||||
|
result = text
|
||||||
|
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)})")
|
||||||
|
logger.debug(f" 栈状态: {stack}")
|
||||||
|
|
||||||
|
# 验证清洗后的结果
|
||||||
|
try:
|
||||||
|
json.loads(result)
|
||||||
|
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||||
|
logger.debug(f" 结果预览: {result[:500]}")
|
||||||
|
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ clean_json_response 出错: {e}")
|
||||||
|
logger.error(f" 文本长度: {len(text) if text else 0}")
|
||||||
|
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json(text: str) -> Union[Dict, List]:
|
||||||
|
"""解析 JSON"""
|
||||||
|
try:
|
||||||
|
cleaned = clean_json_response(text)
|
||||||
|
return json.loads(cleaned)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ parse_json 出错: {e}")
|
||||||
|
logger.error(f" 原始文本长度: {len(text) if text else 0}")
|
||||||
|
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
|
||||||
|
raise
|
||||||
@@ -175,20 +175,34 @@ class MCPTestService:
|
|||||||
db=db_session
|
db=db_session
|
||||||
)
|
)
|
||||||
|
|
||||||
ai_response = await ai_service.generate_text(
|
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
|
||||||
|
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
|
||||||
|
# 因此这里需要特殊处理
|
||||||
|
accumulated_text = ""
|
||||||
|
tool_calls = None
|
||||||
|
|
||||||
|
async for chunk in ai_service.generate_text_stream(
|
||||||
prompt=prompts["user"],
|
prompt=prompts["user"],
|
||||||
system_prompt=prompts["system"],
|
system_prompt=prompts["system"],
|
||||||
tools=openai_tools,
|
tools=openai_tools,
|
||||||
tool_choice="required"
|
tool_choice="required"
|
||||||
)
|
):
|
||||||
|
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
|
||||||
|
if isinstance(chunk, dict):
|
||||||
|
if "tool_calls" in chunk:
|
||||||
|
tool_calls = chunk["tool_calls"]
|
||||||
|
if "content" in chunk:
|
||||||
|
accumulated_text += chunk.get("content", "")
|
||||||
|
else:
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
# 5. 检查AI是否返回工具调用
|
# 5. 检查AI是否返回工具调用
|
||||||
if not ai_response.get("tool_calls"):
|
if not tool_calls:
|
||||||
logger.error(f"❌ AI未返回工具调用")
|
logger.error(f"❌ AI未返回工具调用")
|
||||||
return MCPTestResult(
|
return MCPTestResult(
|
||||||
success=False,
|
success=False,
|
||||||
message="❌ AI Function Calling失败",
|
message="❌ AI Function Calling失败",
|
||||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
|
||||||
tools_count=len(tools),
|
tools_count=len(tools),
|
||||||
suggestions=[
|
suggestions=[
|
||||||
"请确认使用的AI模型支持Function Calling",
|
"请确认使用的AI模型支持Function Calling",
|
||||||
@@ -198,7 +212,7 @@ class MCPTestService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 6. 解析工具调用
|
# 6. 解析工具调用
|
||||||
tool_call = ai_response["tool_calls"][0]
|
tool_call = tool_calls[0]
|
||||||
function = tool_call["function"]
|
function = tool_call["function"]
|
||||||
tool_name = function["name"]
|
tool_name = function["name"]
|
||||||
test_arguments = function["arguments"]
|
test_arguments = function["arguments"]
|
||||||
|
|||||||
@@ -386,17 +386,30 @@ class MCPToolService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析插件名和工具名
|
# 解析插件名和工具名
|
||||||
|
logger.debug(f"🔍 解析工具名称: {function_name}")
|
||||||
if "_" in function_name:
|
if "_" in function_name:
|
||||||
plugin_name, tool_name = function_name.split("_", 1)
|
plugin_name, tool_name = function_name.split("_", 1)
|
||||||
|
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||||
|
|
||||||
# 解析参数
|
# 解析参数
|
||||||
arguments_str = tool_call["function"]["arguments"]
|
arguments_str = tool_call["function"]["arguments"]
|
||||||
|
logger.debug(f"🔍 解析参数:")
|
||||||
|
logger.debug(f" 原始类型: {type(arguments_str)}")
|
||||||
|
logger.debug(f" 原始内容: {arguments_str}")
|
||||||
|
|
||||||
if isinstance(arguments_str, str):
|
if isinstance(arguments_str, str):
|
||||||
arguments = json.loads(arguments_str)
|
try:
|
||||||
|
arguments = json.loads(arguments_str)
|
||||||
|
logger.debug(f" ✅ JSON解析成功: {arguments}")
|
||||||
|
except json.JSONDecodeError as je:
|
||||||
|
logger.error(f" ❌ JSON解析失败: {je}")
|
||||||
|
logger.error(f" 原始字符串: '{arguments_str}'")
|
||||||
|
raise ValueError(f"参数JSON解析失败: {je}")
|
||||||
else:
|
else:
|
||||||
arguments = arguments_str
|
arguments = arguments_str
|
||||||
|
logger.debug(f" 直接使用dict类型参数")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"执行工具: {plugin_name}.{tool_name}, "
|
f"执行工具: {plugin_name}.{tool_name}, "
|
||||||
|
|||||||
@@ -71,24 +71,15 @@ class PlotAnalyzer:
|
|||||||
# 调用AI进行分析
|
# 调用AI进行分析
|
||||||
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
||||||
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
||||||
response = await self.ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
async for chunk in self.ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||||
)
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
# 🔍 添加调试日志:查看AI返回的原始内容
|
# 提取内容
|
||||||
# logger.info(f"🔍 AI返回类型: {type(response)}")
|
response_text = accumulated_text
|
||||||
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
|
|
||||||
|
|
||||||
# 从返回的字典中提取content字段
|
|
||||||
if isinstance(response, dict):
|
|
||||||
response_text = response.get('content', '')
|
|
||||||
if not response_text:
|
|
||||||
logger.error("❌ AI返回的字典中没有content字段或content为空")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# 兼容旧的字符串返回格式
|
|
||||||
response_text = response
|
|
||||||
|
|
||||||
# 解析JSON结果
|
# 解析JSON结果
|
||||||
analysis_result = self._parse_analysis_response(response_text)
|
analysis_result = self._parse_analysis_response(response_text)
|
||||||
|
|||||||
@@ -133,14 +133,16 @@ class PlotExpansionService:
|
|||||||
|
|
||||||
# 调用AI生成章节规划
|
# 调用AI生成章节规划
|
||||||
logger.info(f"调用AI生成章节规划...")
|
logger.info(f"调用AI生成章节规划...")
|
||||||
ai_response = await self.ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
async for chunk in self.ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
)
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
# 提取内容
|
# 提取内容
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_content = accumulated_text
|
||||||
|
|
||||||
# 解析AI响应
|
# 解析AI响应
|
||||||
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
|
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||||
@@ -236,14 +238,16 @@ class PlotExpansionService:
|
|||||||
|
|
||||||
# 调用AI生成当前批次
|
# 调用AI生成当前批次
|
||||||
logger.info(f"调用AI生成第{batch_num + 1}批...")
|
logger.info(f"调用AI生成第{batch_num + 1}批...")
|
||||||
ai_response = await self.ai_service.generate_text(
|
accumulated_text = ""
|
||||||
|
async for chunk in self.ai_service.generate_text_stream(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=model
|
model=model
|
||||||
)
|
):
|
||||||
|
accumulated_text += chunk
|
||||||
|
|
||||||
# 提取内容
|
# 提取内容
|
||||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
ai_content = accumulated_text
|
||||||
|
|
||||||
# 解析AI响应
|
# 解析AI响应
|
||||||
batch_plans = self._parse_expansion_response(ai_content, outline.id)
|
batch_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||||
|
|||||||
@@ -6,142 +6,6 @@ import json
|
|||||||
class WritingStyleManager:
|
class WritingStyleManager:
|
||||||
"""写作风格管理器"""
|
"""写作风格管理器"""
|
||||||
|
|
||||||
# 预设风格配置
|
|
||||||
PRESET_STYLES = {
|
|
||||||
"natural": {
|
|
||||||
"name": "自然沉浸 (Natural & Immersive)",
|
|
||||||
"description": "祛除翻译腔,强调生活质感,像呼吸一样自然的叙事",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:自然沉浸风格
|
|
||||||
请模拟人类作家在放松状态下的写作,通过以下规则消除“AI味”:
|
|
||||||
|
|
||||||
1. **拒绝翻译腔与书面化**:
|
|
||||||
- 严禁使用“一种...的感觉”、“随着...”、“与此同时”等连接词。
|
|
||||||
- 多用短句和“流水句”,模拟人类视线的移动和思维的跳跃。
|
|
||||||
- 口语化叙述,但不要滥用语气词,而是通过句子的长短节奏来体现语气。
|
|
||||||
|
|
||||||
2. **生活化的颗粒度**:
|
|
||||||
- 描写不要宏大,要聚焦在具体的、微小的生活细节(如:杯子上的水渍、衣服的褶皱)。
|
|
||||||
- 允许逻辑上的适度“松散”,不要让每句话都像说明书一样严丝合缝。
|
|
||||||
|
|
||||||
3. **具体的“展示”**:
|
|
||||||
- 不要写“他很生气”,要写他“把烟头按灭在还没吃完的米饭里”。
|
|
||||||
- 避免使用抽象的形容词(如:巨大的、美丽的、悲伤的),必须用名词和动词来承载画面。
|
|
||||||
"""
|
|
||||||
},
|
|
||||||
"classical": {
|
|
||||||
"name": "古典雅致 (Classical & Elegant)",
|
|
||||||
"description": "白话文与古典韵味的结合,强调留白与炼字",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:古典雅致风格
|
|
||||||
请模仿民国时期或古典白话小说的笔触,构建端庄且富有余味的叙事:
|
|
||||||
|
|
||||||
1. **炼字与韵律**:
|
|
||||||
- 尽量使用双音节词或四字短语,但严禁堆砌辞藻。
|
|
||||||
- 注重句子的声调韵律,读起来要有金石之声或流水之韵。
|
|
||||||
- 适当使用倒装句或定语后置,增加古雅感。
|
|
||||||
|
|
||||||
2. **克制的修辞**:
|
|
||||||
- 少用现代的比喻(如“像机器一样”),多用取自自然的比喻(如“如风过林”)。
|
|
||||||
- **意在言外**:不要把话说透,留三分余地。写景即是写情,不要将情感直接剖白。
|
|
||||||
|
|
||||||
3. **禁忌**:
|
|
||||||
- 严禁使用现代科技词汇(除非题材需要)、网络用语或过于西化的句式(如长定语从句)。
|
|
||||||
- 避免滥用“之乎者也”,追求的是“神似”而非生硬的半文半白。
|
|
||||||
"""
|
|
||||||
},
|
|
||||||
"modern": {
|
|
||||||
"name": "冷硬现代 (Modern & Hard-boiled)",
|
|
||||||
"description": "海明威式的冰山理论,节奏极快,零度情感",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:冷硬现代风格
|
|
||||||
请采用“极简主义”和“零度写作”手法,去除所有矫饰:
|
|
||||||
|
|
||||||
1. **冰山理论**:
|
|
||||||
- **只写动作和对话,完全剔除心理描写和形容词堆砌。**
|
|
||||||
- 不要告诉读者角色感觉如何,通过角色的反应和环境的冷峻反馈来体现。
|
|
||||||
|
|
||||||
2. **电影蒙太奇节奏**:
|
|
||||||
- 句子要短、脆、硬。像手术刀一样切开场景。
|
|
||||||
- 段落之间快速切换,不要用过渡句连接,直接跳切。
|
|
||||||
|
|
||||||
3. **高信息密度**:
|
|
||||||
- 删除所有废话。如果一个词删掉不影响理解,就删掉它。
|
|
||||||
- 多用名词和强动词(Strong Verbs),少用副词(Adverbs)。例如:不要写“他重重地关上门”,写“他摔上了门”。
|
|
||||||
"""
|
|
||||||
},
|
|
||||||
"poetic": {
|
|
||||||
"name": "意识流 (Stream of Consciousness)",
|
|
||||||
"description": "注重感官通感与内心独白,打破现实与幻想的边界",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:意识流/诗意风格
|
|
||||||
请侧重于主观感受的流动,而非客观事实的记录:
|
|
||||||
|
|
||||||
1. **通感与陌生化**:
|
|
||||||
- 打通五感(如:听到了颜色的声音,闻到了悲伤的气味)。
|
|
||||||
- 使用“陌生化”的语言,把熟悉的事物写得陌生,迫使读者重新审视。
|
|
||||||
|
|
||||||
2. **情绪的具象化**:
|
|
||||||
- **绝对禁止**直接出现“开心”、“痛苦”等抽象词汇。
|
|
||||||
- 必须寻找“客观对应物”(Objective Correlative),将情绪投射到具体的景物上(如:生锈的铁轨、发霉的橘子)。
|
|
||||||
|
|
||||||
3. **流动的句式**:
|
|
||||||
- 句子可以很长,包含多重意象的叠加。
|
|
||||||
- 允许思维的非线性跳跃,模拟梦境或深层潜意识的逻辑。
|
|
||||||
"""
|
|
||||||
},
|
|
||||||
"concise": {
|
|
||||||
"name": "白描速写 (Sketch & Concise)",
|
|
||||||
"description": "只有骨架的叙事,强调绝对的精准和功能性",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:白描速写风格
|
|
||||||
请像速写画家一样,只勾勒线条,不涂抹色彩:
|
|
||||||
|
|
||||||
1. **功能性第一**:
|
|
||||||
- 每一句话必须推动情节,或者揭示关键信息。
|
|
||||||
- 如果一句话只是为了渲染气氛,删掉它。
|
|
||||||
|
|
||||||
2. **主谓宾结构**:
|
|
||||||
- 尽量使用简单的主谓宾结构,减少修饰语。
|
|
||||||
- 避免复杂的从句和嵌套结构。
|
|
||||||
|
|
||||||
3. **直击核心**:
|
|
||||||
- 对话直接进入主题,去除寒暄和废话。
|
|
||||||
- 环境描写仅限于对情节有物理影响的物体(如:挡路的石头、藏在桌下的枪)。
|
|
||||||
"""
|
|
||||||
},
|
|
||||||
"vivid": {
|
|
||||||
"name": "感官特写 (Sensory & Vivid)",
|
|
||||||
"description": "高分辨率的描写,强调材质、光影和微观细节",
|
|
||||||
"prompt_content": """
|
|
||||||
### 核心指令:感官特写风格
|
|
||||||
请将镜头推到特写级别(Macro Lens),捕捉常人忽略的细节:
|
|
||||||
|
|
||||||
1. **反套路细节**:
|
|
||||||
- 不要写大众化的细节(如:蓝天白云),要写具有**独特性**的细节(如:云层边缘那抹像淤青一样的灰紫色)。
|
|
||||||
- 关注物体的**质感(Texture)**:粗糙的、粘稠的、冰凉的、颗粒感的。
|
|
||||||
|
|
||||||
2. **动态捕捉**:
|
|
||||||
- 不要写静止的画面,要写光影的流变、灰尘的飞舞、肌肉的抽动。
|
|
||||||
- 让读者产生生理性的反应(如:痛感、饥饿感、窒息感)。
|
|
||||||
|
|
||||||
3. **禁用词汇**:
|
|
||||||
- 禁止使用“映入眼帘”、“宛如画卷”等陈词滥调。
|
|
||||||
- 必须用具体的动词带动感官描写。
|
|
||||||
"""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]:
|
|
||||||
"""获取预设风格配置"""
|
|
||||||
return cls.PRESET_STYLES.get(preset_id)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_all_presets(cls) -> Dict[str, Dict[str, str]]:
|
|
||||||
"""获取所有预设风格"""
|
|
||||||
return cls.PRESET_STYLES
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_style_to_prompt(base_prompt: str, style_content: str) -> str:
|
def apply_style_to_prompt(base_prompt: str, style_content: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -692,9 +556,8 @@ class PromptService:
|
|||||||
|
|
||||||
6. **承上启下**:
|
6. **承上启下**:
|
||||||
- 开头自然衔接上一章结尾(但不重复上一章内容)
|
- 开头自然衔接上一章结尾(但不重复上一章内容)
|
||||||
- 结尾为下一章做好铺垫
|
|
||||||
|
|
||||||
6. **记忆系统使用指南**:
|
7. **记忆系统使用指南**:
|
||||||
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
|
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
|
||||||
- **语义相关记忆**:参考相似情节的处理方式
|
- **语义相关记忆**:参考相似情节的处理方式
|
||||||
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
|
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
|
||||||
@@ -1308,16 +1171,15 @@ class PromptService:
|
|||||||
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
|
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
|
||||||
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
|
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
|
||||||
|
|
||||||
# 灵感模式提示词字典
|
# 灵感模式 - 书名生成(系统提示词)
|
||||||
INSPIRATION_PROMPTS = {
|
INSPIRATION_TITLE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||||
"title": {
|
|
||||||
"system": """你是一位专业的小说创作顾问。
|
|
||||||
用户的原始想法:{initial_idea}
|
用户的原始想法:{initial_idea}
|
||||||
|
|
||||||
请根据用户的想法,生成6个吸引人的书名建议,要求:
|
请根据用户的想法,生成6个吸引人的书名建议,要求:
|
||||||
1. 紧扣用户的原始想法和核心故事构思
|
1. 紧扣用户的原始想法和核心故事构思
|
||||||
2. 富有创意和吸引力
|
2. 富有创意和吸引力
|
||||||
3. 涵盖不同的风格倾向
|
3. 涵盖不同的风格倾向
|
||||||
|
4. 书名中不要带有"《》"符号
|
||||||
|
|
||||||
返回JSON格式:
|
返回JSON格式:
|
||||||
{{
|
{{
|
||||||
@@ -1325,11 +1187,13 @@ class PromptService:
|
|||||||
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
|
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
|
||||||
}}
|
}}
|
||||||
|
|
||||||
只返回纯JSON,不要有其他文字。""",
|
只返回纯JSON,不要有其他文字。"""
|
||||||
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
|
|
||||||
},
|
# 灵感模式 - 书名生成(用户提示词)
|
||||||
"description": {
|
INSPIRATION_TITLE_USER = "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||||
"system": """你是一位专业的小说创作顾问。
|
|
||||||
|
# 灵感模式 - 简介生成(系统提示词)
|
||||||
|
INSPIRATION_DESCRIPTION_SYSTEM = """你是一位专业的小说创作顾问。
|
||||||
用户的原始想法:{initial_idea}
|
用户的原始想法:{initial_idea}
|
||||||
已确定的书名:{title}
|
已确定的书名:{title}
|
||||||
|
|
||||||
@@ -1343,11 +1207,13 @@ class PromptService:
|
|||||||
返回JSON格式:
|
返回JSON格式:
|
||||||
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
|
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
|
||||||
|
|
||||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||||
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
|
||||||
},
|
# 灵感模式 - 简介生成(用户提示词)
|
||||||
"theme": {
|
INSPIRATION_DESCRIPTION_USER = "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||||
"system": """你是一位专业的小说创作顾问。
|
|
||||||
|
# 灵感模式 - 主题生成(系统提示词)
|
||||||
|
INSPIRATION_THEME_SYSTEM = """你是一位专业的小说创作顾问。
|
||||||
用户的原始想法:{initial_idea}
|
用户的原始想法:{initial_idea}
|
||||||
小说信息:
|
小说信息:
|
||||||
- 书名:{title}
|
- 书名:{title}
|
||||||
@@ -1363,11 +1229,13 @@ class PromptService:
|
|||||||
返回JSON格式:
|
返回JSON格式:
|
||||||
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
|
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
|
||||||
|
|
||||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
|
||||||
},
|
# 灵感模式 - 主题生成(用户提示词)
|
||||||
"genre": {
|
INSPIRATION_THEME_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||||
"system": """你是一位专业的小说创作顾问。
|
|
||||||
|
# 灵感模式 - 类型生成(系统提示词)
|
||||||
|
INSPIRATION_GENRE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||||
用户的原始想法:{initial_idea}
|
用户的原始想法:{initial_idea}
|
||||||
小说信息:
|
小说信息:
|
||||||
- 书名:{title}
|
- 书名:{title}
|
||||||
@@ -1384,10 +1252,10 @@ class PromptService:
|
|||||||
返回JSON格式:
|
返回JSON格式:
|
||||||
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
|
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
|
||||||
|
|
||||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。""",
|
只返回紧凑的纯JSON,不要换行,不要有其他文字。"""
|
||||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
|
||||||
}
|
# 灵感模式 - 类型生成(用户提示词)
|
||||||
}
|
INSPIRATION_GENRE_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||||
|
|
||||||
# 灵感模式智能补全提示词
|
# 灵感模式智能补全提示词
|
||||||
INSPIRATION_QUICK_COMPLETE = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
|
INSPIRATION_QUICK_COMPLETE = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
|
||||||
@@ -1887,7 +1755,26 @@ class PromptService:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]:
|
def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]:
|
||||||
"""获取灵感模式指定步骤的提示词"""
|
"""获取灵感模式指定步骤的提示词"""
|
||||||
return cls.INSPIRATION_PROMPTS.get(step)
|
# 根据步骤名称返回对应的system和user提示词
|
||||||
|
step_map = {
|
||||||
|
"title": {
|
||||||
|
"system": cls.INSPIRATION_TITLE_SYSTEM,
|
||||||
|
"user": cls.INSPIRATION_TITLE_USER
|
||||||
|
},
|
||||||
|
"description": {
|
||||||
|
"system": cls.INSPIRATION_DESCRIPTION_SYSTEM,
|
||||||
|
"user": cls.INSPIRATION_DESCRIPTION_USER
|
||||||
|
},
|
||||||
|
"theme": {
|
||||||
|
"system": cls.INSPIRATION_THEME_SYSTEM,
|
||||||
|
"user": cls.INSPIRATION_THEME_USER
|
||||||
|
},
|
||||||
|
"genre": {
|
||||||
|
"system": cls.INSPIRATION_GENRE_SYSTEM,
|
||||||
|
"user": cls.INSPIRATION_GENRE_USER
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return step_map.get(step)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]:
|
def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]:
|
||||||
@@ -1997,17 +1884,12 @@ class PromptService:
|
|||||||
# 2. 降级到系统默认模板
|
# 2. 降级到系统默认模板
|
||||||
logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)")
|
logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)")
|
||||||
|
|
||||||
# 特殊处理灵感模式的提示词(存储在INSPIRATION_PROMPTS字典中)
|
# 特殊处理灵感模式的提示词(直接从类属性获取)
|
||||||
if template_key.startswith("INSPIRATION_"):
|
if template_key.startswith("INSPIRATION_"):
|
||||||
# 提取步骤名称(如 INSPIRATION_TITLE -> title)
|
# 直接从类属性获取
|
||||||
step = template_key.replace("INSPIRATION_", "").lower()
|
template_content = getattr(cls, template_key, None)
|
||||||
inspiration_prompt = cls.INSPIRATION_PROMPTS.get(step)
|
if template_content:
|
||||||
if inspiration_prompt:
|
return template_content
|
||||||
# 返回JSON格式的提示词
|
|
||||||
return json.dumps(inspiration_prompt, ensure_ascii=False)
|
|
||||||
# 如果是INSPIRATION_QUICK_COMPLETE
|
|
||||||
if template_key == "INSPIRATION_QUICK_COMPLETE":
|
|
||||||
return cls.INSPIRATION_QUICK_COMPLETE
|
|
||||||
|
|
||||||
# 其他模板直接从类属性获取
|
# 其他模板直接从类属性获取
|
||||||
template_content = getattr(cls, template_key, None)
|
template_content = getattr(cls, template_key, None)
|
||||||
@@ -2182,6 +2064,60 @@ class PromptService:
|
|||||||
"category": "世界构建",
|
"category": "世界构建",
|
||||||
"description": "根据世界观自动生成完整的职业体系,包括主职业和副职业",
|
"description": "根据世界观自动生成完整的职业体系,包括主职业和副职业",
|
||||||
"parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"]
|
"parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_TITLE_SYSTEM": {
|
||||||
|
"name": "灵感模式-书名生成(系统提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据用户的原始想法生成6个书名建议的系统提示词",
|
||||||
|
"parameters": ["initial_idea"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_TITLE_USER": {
|
||||||
|
"name": "灵感模式-书名生成(用户提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据用户的原始想法生成6个书名建议的用户提示词",
|
||||||
|
"parameters": ["initial_idea"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_DESCRIPTION_SYSTEM": {
|
||||||
|
"name": "灵感模式-简介生成(系统提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据用户想法和书名生成6个简介选项的系统提示词",
|
||||||
|
"parameters": ["initial_idea", "title"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_DESCRIPTION_USER": {
|
||||||
|
"name": "灵感模式-简介生成(用户提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据用户想法和书名生成6个简介选项的用户提示词",
|
||||||
|
"parameters": ["initial_idea", "title"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_THEME_SYSTEM": {
|
||||||
|
"name": "灵感模式-主题生成(系统提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据书名和简介生成6个深刻的主题选项的系统提示词",
|
||||||
|
"parameters": ["initial_idea", "title", "description"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_THEME_USER": {
|
||||||
|
"name": "灵感模式-主题生成(用户提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据书名和简介生成6个深刻的主题选项的用户提示词",
|
||||||
|
"parameters": ["initial_idea", "title", "description"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_GENRE_SYSTEM": {
|
||||||
|
"name": "灵感模式-类型生成(系统提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据小说信息生成6个合适的类型标签的系统提示词",
|
||||||
|
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_GENRE_USER": {
|
||||||
|
"name": "灵感模式-类型生成(用户提示词)",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据小说信息生成6个合适的类型标签的用户提示词",
|
||||||
|
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||||
|
},
|
||||||
|
"INSPIRATION_QUICK_COMPLETE": {
|
||||||
|
"name": "灵感模式-智能补全",
|
||||||
|
"category": "灵感模式",
|
||||||
|
"description": "根据用户提供的部分信息智能补全完整的小说方案",
|
||||||
|
"parameters": ["existing"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,11 +23,22 @@ class SSEResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
格式化后的SSE消息字符串
|
格式化后的SSE消息字符串
|
||||||
"""
|
"""
|
||||||
message = ""
|
try:
|
||||||
if event:
|
message = ""
|
||||||
message += f"event: {event}\n"
|
if event:
|
||||||
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
message += f"event: {event}\n"
|
||||||
return message
|
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
return message
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ SSE格式化失败: {type(e).__name__}: {e}")
|
||||||
|
logger.error(f" data类型: {type(data)}")
|
||||||
|
logger.error(f" data内容: {str(data)[:500]}")
|
||||||
|
# 返回错误消息而不是崩溃
|
||||||
|
error_message = ""
|
||||||
|
if event:
|
||||||
|
error_message += f"event: {event}\n"
|
||||||
|
error_message += f'data: {{"type": "error", "error": "SSE格式化失败: {str(e)}", "code": 500}}\n\n'
|
||||||
|
return error_message
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def send_progress(
|
async def send_progress(
|
||||||
|
|||||||
@@ -190,7 +190,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: (result) => {
|
onResult: (result) => {
|
||||||
@@ -236,7 +237,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(33 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: (result) => {
|
onResult: (result) => {
|
||||||
@@ -273,7 +275,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(66 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: () => {
|
onResult: () => {
|
||||||
@@ -336,15 +339,13 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
// 世界观生成占0%-20%,职业生成占20%-30%
|
// 直接使用后端返回的进度值
|
||||||
const baseProgress = Math.floor(prog / 5);
|
setProgress(prog);
|
||||||
setProgress(baseProgress);
|
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
|
|
||||||
// 检测职业体系生成阶段 - 必须包含"职业体系"才算职业阶段
|
// 检测职业体系生成阶段
|
||||||
if (msg.includes('职业体系')) {
|
if (msg.includes('职业体系')) {
|
||||||
if (msg.includes('开始') || msg.includes('生成')) {
|
if (msg.includes('开始') || msg.includes('生成')) {
|
||||||
// 职业开始时,世界观应该已完成
|
|
||||||
setGenerationSteps(prev => ({
|
setGenerationSteps(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
worldBuilding: 'completed',
|
worldBuilding: 'completed',
|
||||||
@@ -403,8 +404,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
// 角色生成占40%-70%
|
// 直接使用后端返回的进度值
|
||||||
setProgress(40 + Math.floor(prog * 0.3));
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: (result) => {
|
onResult: (result) => {
|
||||||
@@ -437,8 +438,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
// 大纲生成占70%-100%
|
// 直接使用后端返回的进度值
|
||||||
setProgress(70 + Math.floor(prog * 0.3));
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: () => {
|
onResult: () => {
|
||||||
@@ -533,8 +534,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
const baseProgress = Math.floor(prog / 5);
|
// 直接使用后端返回的进度值
|
||||||
setProgress(baseProgress);
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
|
|
||||||
// 检测职业体系生成阶段
|
// 检测职业体系生成阶段
|
||||||
@@ -604,7 +605,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(33 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: (result) => {
|
onResult: (result) => {
|
||||||
@@ -647,7 +649,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(66 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: () => {
|
onResult: () => {
|
||||||
@@ -707,7 +710,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(33 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: (result) => {
|
onResult: (result) => {
|
||||||
@@ -746,7 +750,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
onProgress: (msg, prog) => {
|
onProgress: (msg, prog) => {
|
||||||
setProgress(66 + Math.floor(prog / 3));
|
// 直接使用后端返回的进度值
|
||||||
|
setProgress(prog);
|
||||||
setProgressMessage(msg);
|
setProgressMessage(msg);
|
||||||
},
|
},
|
||||||
onResult: () => {
|
onResult: () => {
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ interface Message {
|
|||||||
options?: string[];
|
options?: string[];
|
||||||
isMultiSelect?: boolean;
|
isMultiSelect?: boolean;
|
||||||
optionsDisabled?: boolean; // 标记选项是否已禁用
|
optionsDisabled?: boolean; // 标记选项是否已禁用
|
||||||
|
canRefine?: boolean; // 是否可以优化(用于支持多轮对话)
|
||||||
|
step?: Step; // 当前步骤(用于反馈)
|
||||||
}
|
}
|
||||||
|
|
||||||
interface WizardData {
|
interface WizardData {
|
||||||
@@ -69,6 +71,11 @@ const Inspiration: React.FC = () => {
|
|||||||
const [wizardData, setWizardData] = useState<Partial<WizardData>>({});
|
const [wizardData, setWizardData] = useState<Partial<WizardData>>({});
|
||||||
// 保存用户的原始想法,用于保持上下文一致性
|
// 保存用户的原始想法,用于保持上下文一致性
|
||||||
const [initialIdea, setInitialIdea] = useState<string>('');
|
const [initialIdea, setInitialIdea] = useState<string>('');
|
||||||
|
|
||||||
|
// 反馈相关状态
|
||||||
|
const [feedbackValue, setFeedbackValue] = useState('');
|
||||||
|
const [showFeedbackInput, setShowFeedbackInput] = useState<number | null>(null); // 当前显示反馈输入的消息索引
|
||||||
|
const [refining, setRefining] = useState(false); // 正在优化选项
|
||||||
|
|
||||||
// 生成配置
|
// 生成配置
|
||||||
const [generationConfig, setGenerationConfig] = useState<GenerationConfig | null>(null);
|
const [generationConfig, setGenerationConfig] = useState<GenerationConfig | null>(null);
|
||||||
@@ -248,6 +255,86 @@ const Inspiration: React.FC = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 处理用户反馈,重新生成选项
|
||||||
|
const handleRefineOptions = async (messageIndex: number, feedback: string) => {
|
||||||
|
if (!feedback.trim()) {
|
||||||
|
message.warning('请输入您的反馈意见');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetMessage = messages[messageIndex];
|
||||||
|
if (!targetMessage.options || !targetMessage.step) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setRefining(true);
|
||||||
|
setShowFeedbackInput(null);
|
||||||
|
setFeedbackValue('');
|
||||||
|
|
||||||
|
// 先禁用旧的选项
|
||||||
|
setMessages(prev => {
|
||||||
|
const newMessages = [...prev];
|
||||||
|
if (newMessages[messageIndex]) {
|
||||||
|
newMessages[messageIndex] = {
|
||||||
|
...newMessages[messageIndex],
|
||||||
|
optionsDisabled: true,
|
||||||
|
canRefine: false, // 同时禁用反馈功能
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return newMessages;
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 添加用户反馈消息
|
||||||
|
const feedbackMessage: Message = {
|
||||||
|
type: 'user',
|
||||||
|
content: `💭 ${feedback}`,
|
||||||
|
};
|
||||||
|
setMessages(prev => [...prev, feedbackMessage]);
|
||||||
|
|
||||||
|
const step = targetMessage.step as 'title' | 'description' | 'theme' | 'genre';
|
||||||
|
|
||||||
|
// 构建上下文
|
||||||
|
const context: any = {
|
||||||
|
initial_idea: initialIdea,
|
||||||
|
title: wizardData.title,
|
||||||
|
description: wizardData.description,
|
||||||
|
theme: wizardData.theme,
|
||||||
|
};
|
||||||
|
|
||||||
|
// 调用refine接口
|
||||||
|
const response = await inspirationApi.refineOptions({
|
||||||
|
step,
|
||||||
|
context,
|
||||||
|
feedback,
|
||||||
|
previous_options: targetMessage.options,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.error) {
|
||||||
|
message.error(response.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加新的AI消息
|
||||||
|
const aiMessage: Message = {
|
||||||
|
type: 'ai',
|
||||||
|
content: response.prompt || `根据您的反馈,我重新生成了一些${step === 'title' ? '书名' : step === 'description' ? '简介' : step === 'theme' ? '主题' : '类型'}选项:`,
|
||||||
|
options: response.options || [],
|
||||||
|
isMultiSelect: step === 'genre',
|
||||||
|
canRefine: true,
|
||||||
|
step: step,
|
||||||
|
};
|
||||||
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
|
|
||||||
|
message.success('已根据您的反馈重新生成选项');
|
||||||
|
} catch (error: any) {
|
||||||
|
console.error('优化选项失败:', error);
|
||||||
|
message.error(error.response?.data?.detail || '优化失败,请重试');
|
||||||
|
} finally {
|
||||||
|
setRefining(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// 步骤顺序
|
// 步骤顺序
|
||||||
const stepOrder: Step[] = ['idea', 'title', 'description', 'theme', 'genre', 'perspective', 'outline_mode', 'confirm'];
|
const stepOrder: Step[] = ['idea', 'title', 'description', 'theme', 'genre', 'perspective', 'outline_mode', 'confirm'];
|
||||||
|
|
||||||
@@ -297,7 +384,9 @@ const Inspiration: React.FC = () => {
|
|||||||
const aiMessage: Message = {
|
const aiMessage: Message = {
|
||||||
type: 'ai',
|
type: 'ai',
|
||||||
content: response.prompt || '请选择一个书名,或者输入你自己的:',
|
content: response.prompt || '请选择一个书名,或者输入你自己的:',
|
||||||
options: response.options
|
options: response.options,
|
||||||
|
canRefine: true,
|
||||||
|
step: 'title'
|
||||||
};
|
};
|
||||||
setMessages(prev => [...prev, aiMessage]);
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
setCurrentStep('title');
|
setCurrentStep('title');
|
||||||
@@ -497,6 +586,24 @@ const Inspiration: React.FC = () => {
|
|||||||
updatedData.genre = [input];
|
updatedData.genre = [input];
|
||||||
} else if (currentStep === 'perspective') {
|
} else if (currentStep === 'perspective') {
|
||||||
updatedData.narrative_perspective = input;
|
updatedData.narrative_perspective = input;
|
||||||
|
setWizardData(updatedData);
|
||||||
|
|
||||||
|
// 直接进入大纲模式选择
|
||||||
|
const aiMessage: Message = {
|
||||||
|
type: 'ai',
|
||||||
|
content: `很好!现在请选择你想要的大纲模式:
|
||||||
|
|
||||||
|
📋 一对一模式:传统模式,一个大纲对应一个章节,适合结构清晰、章节独立的小说。
|
||||||
|
|
||||||
|
📚 一对多模式:细化模式,一个大纲可以展开成多个章节,适合需要详细展开情节的小说。
|
||||||
|
|
||||||
|
请选择:`,
|
||||||
|
options: ['📋 一对一模式', '📚 一对多模式']
|
||||||
|
};
|
||||||
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
|
setCurrentStep('outline_mode');
|
||||||
|
setLoading(false);
|
||||||
|
return;
|
||||||
} else if (currentStep === 'outline_mode') {
|
} else if (currentStep === 'outline_mode') {
|
||||||
// 大纲模式不支持自定义输入
|
// 大纲模式不支持自定义输入
|
||||||
message.warning('请从选项中选择一个大纲模式');
|
message.warning('请从选项中选择一个大纲模式');
|
||||||
@@ -561,7 +668,16 @@ const Inspiration: React.FC = () => {
|
|||||||
const currentIndex = stepOrder.indexOf(currentStep);
|
const currentIndex = stepOrder.indexOf(currentStep);
|
||||||
const nextStep = stepOrder[currentIndex + 1];
|
const nextStep = stepOrder[currentIndex + 1];
|
||||||
|
|
||||||
if (nextStep === 'description') {
|
if (nextStep === 'perspective') {
|
||||||
|
// genre 步骤完成后,进入 perspective
|
||||||
|
const aiMessage: Message = {
|
||||||
|
type: 'ai',
|
||||||
|
content: '很好!接下来,请选择小说的叙事视角:',
|
||||||
|
options: ['第一人称', '第三人称', '全知视角']
|
||||||
|
};
|
||||||
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
|
setCurrentStep('perspective');
|
||||||
|
} else if (nextStep === 'description') {
|
||||||
const requestData = {
|
const requestData = {
|
||||||
step: 'description' as const,
|
step: 'description' as const,
|
||||||
context: {
|
context: {
|
||||||
@@ -587,7 +703,9 @@ const Inspiration: React.FC = () => {
|
|||||||
const aiMessage: Message = {
|
const aiMessage: Message = {
|
||||||
type: 'ai',
|
type: 'ai',
|
||||||
content: response.prompt || '请选择一个简介,或者输入你自己的:',
|
content: response.prompt || '请选择一个简介,或者输入你自己的:',
|
||||||
options: response.options
|
options: response.options,
|
||||||
|
canRefine: true,
|
||||||
|
step: 'description'
|
||||||
};
|
};
|
||||||
setMessages(prev => [...prev, aiMessage]);
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
setCurrentStep('description');
|
setCurrentStep('description');
|
||||||
@@ -620,7 +738,9 @@ const Inspiration: React.FC = () => {
|
|||||||
const aiMessage: Message = {
|
const aiMessage: Message = {
|
||||||
type: 'ai',
|
type: 'ai',
|
||||||
content: response.prompt || '请选择一个主题,或者输入你自己的:',
|
content: response.prompt || '请选择一个主题,或者输入你自己的:',
|
||||||
options: response.options
|
options: response.options,
|
||||||
|
canRefine: true,
|
||||||
|
step: 'theme'
|
||||||
};
|
};
|
||||||
setMessages(prev => [...prev, aiMessage]);
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
setCurrentStep('theme');
|
setCurrentStep('theme');
|
||||||
@@ -656,7 +776,9 @@ const Inspiration: React.FC = () => {
|
|||||||
type: 'ai',
|
type: 'ai',
|
||||||
content: response.prompt || '请选择类型标签(可多选):',
|
content: response.prompt || '请选择类型标签(可多选):',
|
||||||
options: response.options,
|
options: response.options,
|
||||||
isMultiSelect: true
|
isMultiSelect: true,
|
||||||
|
canRefine: true,
|
||||||
|
step: 'genre'
|
||||||
};
|
};
|
||||||
setMessages(prev => [...prev, aiMessage]);
|
setMessages(prev => [...prev, aiMessage]);
|
||||||
setCurrentStep('genre');
|
setCurrentStep('genre');
|
||||||
@@ -767,7 +889,7 @@ const Inspiration: React.FC = () => {
|
|||||||
background: msg.optionsDisabled
|
background: msg.optionsDisabled
|
||||||
? 'var(--color-bg-layout)'
|
? 'var(--color-bg-layout)'
|
||||||
: msg.isMultiSelect && selectedOptions.includes(option)
|
: msg.isMultiSelect && selectedOptions.includes(option)
|
||||||
? 'var(--color-bg-spotlight)' // Need to ensure this exists or use safe fallback
|
? 'var(--color-bg-spotlight)'
|
||||||
: 'var(--color-bg-container)',
|
: 'var(--color-bg-container)',
|
||||||
opacity: msg.optionsDisabled ? 0.6 : 1,
|
opacity: msg.optionsDisabled ? 0.6 : 1,
|
||||||
animation: 'floatIn 0.6s ease-out',
|
animation: 'floatIn 0.6s ease-out',
|
||||||
@@ -802,19 +924,72 @@ const Inspiration: React.FC = () => {
|
|||||||
确认选择 ({selectedOptions.length})
|
确认选择 ({selectedOptions.length})
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* 反馈优化区域 - 新增 */}
|
||||||
|
{msg.canRefine && !msg.optionsDisabled && !msg.isMultiSelect && (
|
||||||
|
<div style={{ marginTop: 8, paddingTop: 8, borderTop: '1px dashed var(--color-border)' }}>
|
||||||
|
{showFeedbackInput === index ? (
|
||||||
|
<Space direction="vertical" style={{ width: '100%' }} size="small">
|
||||||
|
<TextArea
|
||||||
|
value={feedbackValue}
|
||||||
|
onChange={(e) => setFeedbackValue(e.target.value)}
|
||||||
|
placeholder="例如:我想要更悲剧的主题、能不能更简短一些、偏向古风..."
|
||||||
|
autoSize={{ minRows: 2, maxRows: 3 }}
|
||||||
|
disabled={refining}
|
||||||
|
onPressEnter={(e) => {
|
||||||
|
if (!e.shiftKey && feedbackValue.trim()) {
|
||||||
|
e.preventDefault();
|
||||||
|
handleRefineOptions(index, feedbackValue);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Space style={{ width: '100%', justifyContent: 'flex-end' }}>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
onClick={() => {
|
||||||
|
setShowFeedbackInput(null);
|
||||||
|
setFeedbackValue('');
|
||||||
|
}}
|
||||||
|
disabled={refining}
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
size="small"
|
||||||
|
onClick={() => handleRefineOptions(index, feedbackValue)}
|
||||||
|
loading={refining}
|
||||||
|
disabled={!feedbackValue.trim()}
|
||||||
|
>
|
||||||
|
重新生成
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
</Space>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
onClick={() => setShowFeedbackInput(index)}
|
||||||
|
style={{ padding: 0, height: 'auto' }}
|
||||||
|
>
|
||||||
|
💡 不太满意?告诉我你的想法
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</Space>
|
</Space>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
{loading && (
|
{(loading || refining) && (
|
||||||
<div style={{
|
<div style={{
|
||||||
textAlign: 'center',
|
textAlign: 'center',
|
||||||
padding: 20,
|
padding: 20,
|
||||||
animation: 'fadeIn 0.3s ease-in'
|
animation: 'fadeIn 0.3s ease-in'
|
||||||
}}>
|
}}>
|
||||||
<Spin tip="AI思考中..." />
|
<Spin tip={refining ? "正在根据您的反馈重新生成..." : "AI思考中..."} />
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|||||||
@@ -150,10 +150,9 @@ export default function SettingsPage() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const apiProviders = [
|
const apiProviders = [
|
||||||
{ value: 'openai', label: 'OpenAl Compatible', defaultUrl: 'https://api.openai.com/v1' },
|
{ value: 'openai', label: 'OpenAI Compatible', defaultUrl: 'https://api.openai.com/v1' },
|
||||||
// { value: 'azure', label: 'Azure OpenAI', defaultUrl: 'https://YOUR-RESOURCE.openai.azure.com' },
|
// { value: 'anthropic', label: 'Anthropic (Claude)', defaultUrl: 'https://api.anthropic.com' },
|
||||||
// { value: 'anthropic', label: 'Anthropic', defaultUrl: 'https://api.anthropic.com' },
|
{ value: 'gemini', label: 'Google Gemini', defaultUrl: 'https://generativelanguage.googleapis.com/v1beta' },
|
||||||
// { value: 'custom', label: '自定义', defaultUrl: '' },
|
|
||||||
];
|
];
|
||||||
|
|
||||||
const handleProviderChange = (value: string) => {
|
const handleProviderChange = (value: string) => {
|
||||||
@@ -483,8 +482,8 @@ export default function SettingsPage() {
|
|||||||
switch (provider) {
|
switch (provider) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return 'blue';
|
return 'blue';
|
||||||
case 'anthropic':
|
// case 'anthropic':
|
||||||
return 'purple';
|
// return 'purple';
|
||||||
case 'gemini':
|
case 'gemini':
|
||||||
return 'green';
|
return 'green';
|
||||||
default:
|
default:
|
||||||
@@ -973,6 +972,26 @@ export default function SettingsPage() {
|
|||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
label={
|
||||||
|
<Space size={4}>
|
||||||
|
<span>系统提示词</span>
|
||||||
|
<Tooltip title="设置全局系统提示词,每次AI调用时都会自动使用。可用于设定AI的角色、语言风格等">
|
||||||
|
<InfoCircleOutlined style={{ color: 'var(--color-text-secondary)', fontSize: isMobile ? '12px' : '14px' }} />
|
||||||
|
</Tooltip>
|
||||||
|
</Space>
|
||||||
|
}
|
||||||
|
name="system_prompt"
|
||||||
|
>
|
||||||
|
<TextArea
|
||||||
|
rows={4}
|
||||||
|
placeholder="例如:你是一个专业的小说创作助手,请用生动、细腻的文字进行创作..."
|
||||||
|
maxLength={10000}
|
||||||
|
showCount
|
||||||
|
style={{ fontSize: isMobile ? '13px' : '14px' }}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
{/* 测试结果展示 */}
|
{/* 测试结果展示 */}
|
||||||
{showTestResult && testResult && (
|
{showTestResult && testResult && (
|
||||||
<Alert
|
<Alert
|
||||||
@@ -1247,7 +1266,7 @@ export default function SettingsPage() {
|
|||||||
>
|
>
|
||||||
<Select>
|
<Select>
|
||||||
<Select.Option value="openai">OpenAI</Select.Option>
|
<Select.Option value="openai">OpenAI</Select.Option>
|
||||||
<Select.Option value="anthropic">Anthropic (Claude)</Select.Option>
|
{/* <Select.Option value="anthropic">Anthropic (Claude)</Select.Option> */}
|
||||||
<Select.Option value="gemini">Google Gemini</Select.Option>
|
<Select.Option value="gemini">Google Gemini</Select.Option>
|
||||||
</Select>
|
</Select>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
@@ -1298,6 +1317,18 @@ export default function SettingsPage() {
|
|||||||
placeholder="2000"
|
placeholder="2000"
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="system_prompt"
|
||||||
|
label="系统提示词"
|
||||||
|
>
|
||||||
|
<TextArea
|
||||||
|
rows={3}
|
||||||
|
placeholder="例如:你是一个专业的小说创作助手...(可选)"
|
||||||
|
maxLength={10000}
|
||||||
|
showCount
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
</Form>
|
</Form>
|
||||||
</Modal>
|
</Modal>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -557,6 +557,24 @@ export const inspirationApi = {
|
|||||||
error?: string;
|
error?: string;
|
||||||
}>('/inspiration/generate-options', data),
|
}>('/inspiration/generate-options', data),
|
||||||
|
|
||||||
|
// 基于用户反馈重新生成选项(新增)
|
||||||
|
refineOptions: (data: {
|
||||||
|
step: 'title' | 'description' | 'theme' | 'genre';
|
||||||
|
context: {
|
||||||
|
initial_idea?: string;
|
||||||
|
title?: string;
|
||||||
|
description?: string;
|
||||||
|
theme?: string;
|
||||||
|
};
|
||||||
|
feedback: string;
|
||||||
|
previous_options?: string[];
|
||||||
|
}) =>
|
||||||
|
api.post<unknown, {
|
||||||
|
prompt?: string;
|
||||||
|
options: string[];
|
||||||
|
error?: string;
|
||||||
|
}>('/inspiration/refine-options', data),
|
||||||
|
|
||||||
// 智能补全缺失信息
|
// 智能补全缺失信息
|
||||||
quickGenerate: (data: {
|
quickGenerate: (data: {
|
||||||
title?: string;
|
title?: string;
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ export interface Settings {
|
|||||||
llm_model: string;
|
llm_model: string;
|
||||||
temperature: number;
|
temperature: number;
|
||||||
max_tokens: number;
|
max_tokens: number;
|
||||||
|
system_prompt?: string;
|
||||||
preferences?: string;
|
preferences?: string;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
updated_at: string;
|
updated_at: string;
|
||||||
@@ -33,6 +34,7 @@ export interface SettingsUpdate {
|
|||||||
llm_model?: string;
|
llm_model?: string;
|
||||||
temperature?: number;
|
temperature?: number;
|
||||||
max_tokens?: number;
|
max_tokens?: number;
|
||||||
|
system_prompt?: string;
|
||||||
preferences?: string;
|
preferences?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,6 +46,7 @@ export interface APIKeyPresetConfig {
|
|||||||
llm_model: string;
|
llm_model: string;
|
||||||
temperature: number;
|
temperature: number;
|
||||||
max_tokens: number;
|
max_tokens: number;
|
||||||
|
system_prompt?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface APIKeyPreset {
|
export interface APIKeyPreset {
|
||||||
|
|||||||
Reference in New Issue
Block a user