update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
+30
@@ -0,0 +1,30 @@
|
||||
"""添加system_prompt字段到settings表
|
||||
|
||||
Revision ID: a7e4408e1d5b
|
||||
Revises: e411428f00c0
|
||||
Create Date: 2025-12-27 15:41:22.310160
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a7e4408e1d5b'
|
||||
down_revision: Union[str, None] = 'e411428f00c0'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('settings', 'system_prompt')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,177 @@
|
||||
"""初始化SQLite预置数据
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: fbeb1038c728
|
||||
Create Date: 2025-12-27 08:56:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a1b2c3d4e5f6'
|
||||
down_revision: Union[str, None] = 'fbeb1038c728'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""插入预置数据"""
|
||||
|
||||
# ==================== 1. 插入关系类型数据 ====================
|
||||
relationship_types_table = table(
|
||||
'relationship_types',
|
||||
column('name', String),
|
||||
column('category', String),
|
||||
column('reverse_name', String),
|
||||
column('intimacy_range', String),
|
||||
column('icon', String),
|
||||
column('description', Text),
|
||||
)
|
||||
|
||||
relationship_types_data = [
|
||||
# 家庭关系
|
||||
{"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"},
|
||||
{"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"},
|
||||
{"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"},
|
||||
{"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"},
|
||||
{"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"},
|
||||
{"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"},
|
||||
{"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"},
|
||||
|
||||
# 社交关系
|
||||
{"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"},
|
||||
{"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"},
|
||||
{"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"},
|
||||
{"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"},
|
||||
{"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"},
|
||||
{"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"},
|
||||
|
||||
# 职业关系
|
||||
{"name": "上司", "category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"},
|
||||
{"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"},
|
||||
{"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"},
|
||||
{"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"},
|
||||
|
||||
# 敌对关系
|
||||
{"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"},
|
||||
{"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"},
|
||||
{"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"},
|
||||
{"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡", "description": "宿命之敌"},
|
||||
]
|
||||
|
||||
op.bulk_insert(relationship_types_table, relationship_types_data)
|
||||
print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据")
|
||||
|
||||
|
||||
# ==================== 2. 插入全局写作风格预设 ====================
|
||||
writing_styles_table = table(
|
||||
'writing_styles',
|
||||
column('user_id', String),
|
||||
column('name', String),
|
||||
column('style_type', String),
|
||||
column('preset_id', String),
|
||||
column('description', Text),
|
||||
column('prompt_content', Text),
|
||||
column('order_index', Integer),
|
||||
)
|
||||
|
||||
writing_styles_data = [
|
||||
{
|
||||
"user_id": None, # NULL 表示全局预设
|
||||
"name": "自然流畅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "natural",
|
||||
"description": "自然流畅的叙事风格,适合现代都市、现实题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言简洁明快,贴近现代口语
|
||||
2. 多用短句,节奏流畅
|
||||
3. 注重情感细节的自然流露
|
||||
4. 避免过度修饰和复杂句式""",
|
||||
"order_index": 1
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "古典优雅",
|
||||
"style_type": "preset",
|
||||
"preset_id": "classical",
|
||||
"description": "古典文雅的写作风格,适合古装、仙侠题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 使用文言、半文言或典雅的白话
|
||||
2. 适当运用古典诗词意象
|
||||
3. 注重意境营造和韵味
|
||||
4. 对话和描写保持古典美感""",
|
||||
"order_index": 2
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "现代简约",
|
||||
"style_type": "preset",
|
||||
"preset_id": "modern",
|
||||
"description": "现代简约风格,适合轻小说、网文快节奏叙事",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言直白简练,信息密度高
|
||||
2. 多用对话推进情节
|
||||
3. 避免冗长描写,突出关键动作
|
||||
4. 节奏明快,适合快速阅读""",
|
||||
"order_index": 3
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "文艺细腻",
|
||||
"style_type": "preset",
|
||||
"preset_id": "literary",
|
||||
"description": "文艺细腻风格,注重心理描写和氛围营造",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 注重心理活动和情感细节
|
||||
2. 善用环境描写烘托氛围
|
||||
3. 语言优美,富有文学性
|
||||
4. 适当使用比喻、象征等修辞手法""",
|
||||
"order_index": 4
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "紧张悬疑",
|
||||
"style_type": "preset",
|
||||
"preset_id": "suspense",
|
||||
"description": "紧张悬疑风格,适合推理、惊悚题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 营造紧张压迫的氛围
|
||||
2. 多用短句加快节奏
|
||||
3. 善于设置悬念和伏笔
|
||||
4. 注重细节描写,为推理埋下线索""",
|
||||
"order_index": 5
|
||||
},
|
||||
{
|
||||
"user_id": None,
|
||||
"name": "幽默诙谐",
|
||||
"style_type": "preset",
|
||||
"preset_id": "humorous",
|
||||
"description": "幽默诙谐风格,适合轻松搞笑题材",
|
||||
"prompt_content": """写作风格要求:
|
||||
1. 语言活泼风趣,善用俏皮话
|
||||
2. 注重对话的喜剧效果
|
||||
3. 适当夸张和反转制造笑点
|
||||
4. 保持轻松愉快的基调""",
|
||||
"order_index": 6
|
||||
},
|
||||
]
|
||||
|
||||
op.bulk_insert(writing_styles_table, writing_styles_data)
|
||||
print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""删除预置数据"""
|
||||
|
||||
# 删除写作风格预设(只删除全局预设)
|
||||
op.execute("DELETE FROM writing_styles WHERE user_id IS NULL")
|
||||
print("✅ SQLite: 已删除全局写作风格预设")
|
||||
|
||||
# 删除关系类型
|
||||
op.execute("DELETE FROM relationship_types")
|
||||
print("✅ SQLite: 已删除关系类型数据")
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
"""添加system_prompt字段到settings表
|
||||
|
||||
Revision ID: 7899f8d4d839
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2025-12-27 17:00:35.440551
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7899f8d4d839'
|
||||
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用'))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('settings', schema=None) as batch_op:
|
||||
batch_op.drop_column('system_prompt')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -309,13 +309,37 @@ async def generate_career_system(
|
||||
7. 只返回纯JSON,不要添加任何解释文字
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 30)
|
||||
yield await SSEResponse.send_progress("调用AI生成新职业...", 10)
|
||||
logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)")
|
||||
|
||||
try:
|
||||
# 调用AI生成
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
# 使用流式生成替代非流式
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
last_progress = 10
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 平滑更新进度(10-90%,AI生成占60%)
|
||||
# 每10个chunk增加约1%的进度,最多到90%
|
||||
if chunk_count % 10 == 0:
|
||||
# 计算进度:10% + (chunk_count / 10) * 1%,但不超过90%
|
||||
current_progress = min(10 + (chunk_count // 10), 90)
|
||||
if current_progress > last_progress:
|
||||
last_progress = current_progress
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成职业体系中... (已生成 {len(ai_response)} 字符)",
|
||||
current_progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
@@ -326,7 +350,7 @@ async def generate_career_system(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 50)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 91)
|
||||
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
@@ -339,7 +363,7 @@ async def generate_career_system(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("保存主职业...", 60)
|
||||
yield await SSEResponse.send_progress("保存主职业到数据库...", 93)
|
||||
|
||||
# 保存主职业
|
||||
main_careers_created = []
|
||||
@@ -371,7 +395,7 @@ async def generate_career_system(
|
||||
logger.error(f" ❌ 创建主职业失败:{str(e)}")
|
||||
continue
|
||||
|
||||
yield await SSEResponse.send_progress("保存副职业...", 80)
|
||||
yield await SSEResponse.send_progress("保存副职业到数据库...", 96)
|
||||
|
||||
# 保存副职业
|
||||
sub_careers_created = []
|
||||
|
||||
+47
-12
@@ -1070,8 +1070,7 @@ async def analyze_chapter_background(
|
||||
|
||||
if career_update_result['updated_count'] > 0:
|
||||
logger.info(
|
||||
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: "
|
||||
f"{', '.join(career_update_result['updated_characters'])}"
|
||||
f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息"
|
||||
)
|
||||
if career_update_result['changes']:
|
||||
for change in career_update_result['changes']:
|
||||
@@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream(
|
||||
user_id=current_user_id,
|
||||
db_session=db_session,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream(
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
# 发送开始生成的进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||
logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||
|
||||
# 准备生成参数
|
||||
generate_kwargs = {"prompt": prompt}
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
}
|
||||
if custom_model:
|
||||
logger.info(f" 使用自定义模型: {custom_model}")
|
||||
generate_kwargs["model"] = custom_model
|
||||
@@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream(
|
||||
# 发送内容块
|
||||
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 每20个chunk发送一次进度更新(提高频率)
|
||||
if chunk_count % 20 == 0:
|
||||
# 每5个chunk发送一次进度更新(10-95%,更平滑)
|
||||
if chunk_count % 5 == 0:
|
||||
current_word_count = len(full_content)
|
||||
# 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%)
|
||||
estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55))
|
||||
# 优化进度计算:使用更平滑的递增方式
|
||||
# 基于chunk数量和字数的混合计算,避免大幅跳跃
|
||||
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
|
||||
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
|
||||
estimated_progress = min(95, 10 + chunk_progress + word_progress)
|
||||
|
||||
# 只在进度变化时发送
|
||||
if estimated_progress > last_progress:
|
||||
@@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream(
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
last_progress = estimated_progress
|
||||
|
||||
# 每20个chunk发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
await asyncio.sleep(0) # 让出控制权
|
||||
|
||||
# 发送保存进度
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 更新章节内容到数据库
|
||||
old_word_count = current_chapter.word_count or 0
|
||||
@@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream(
|
||||
)
|
||||
|
||||
# 发送最终进度100%
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送完成事件(包含分析任务ID)
|
||||
completion_data = {
|
||||
@@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch(
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 🎨 方案一:将写作风格注入到系统提示词(批量生成)
|
||||
system_prompt_with_style = None
|
||||
if style_content:
|
||||
system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】
|
||||
|
||||
{style_content}
|
||||
|
||||
⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令!
|
||||
确保在整个章节创作过程中始终保持风格的一致性。"""
|
||||
logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)")
|
||||
|
||||
# 非流式生成内容
|
||||
full_content = ""
|
||||
# 准备生成参数
|
||||
generate_kwargs = {"prompt": prompt}
|
||||
generate_kwargs = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
|
||||
}
|
||||
# 如果传入了自定义模型,使用指定的模型
|
||||
if custom_model:
|
||||
generate_kwargs["model"] = custom_model
|
||||
logger.info(f" 批量生成使用自定义模型: {custom_model}")
|
||||
|
||||
# 批量生成中的流式生成(非SSE,不需要修改进度显示)
|
||||
async for chunk in ai_service.generate_text_stream(**generate_kwargs):
|
||||
full_content += chunk
|
||||
|
||||
|
||||
+120
-20
@@ -662,10 +662,10 @@ async def generate_character_stream(
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(request.project_id, user_id, db)
|
||||
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 0)
|
||||
yield await SSEResponse.send_progress("开始生成角色...", 1)
|
||||
|
||||
# 获取已存在的角色列表
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 10)
|
||||
yield await SSEResponse.send_progress("获取项目上下文...", 2)
|
||||
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
@@ -757,7 +757,7 @@ async def generate_character_stream(
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
"""
|
||||
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 20)
|
||||
yield await SSEResponse.send_progress("构建AI提示词...", 3)
|
||||
|
||||
# 获取自定义提示词模板
|
||||
template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db)
|
||||
@@ -768,11 +768,14 @@ async def generate_character_stream(
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 30)
|
||||
yield await SSEResponse.send_progress("调用AI服务生成角色...", 10)
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
|
||||
|
||||
try:
|
||||
# 🔧 MCP工具增强:静默检查并收集参考资料
|
||||
ai_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
@@ -789,7 +792,7 @@ async def generate_character_stream(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -797,22 +800,119 @@ async def generate_character_stream(
|
||||
|
||||
if isinstance(result, dict):
|
||||
ai_response = result.get('content', '')
|
||||
if result.get('tool_calls_made', 0) > 0:
|
||||
logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)")
|
||||
finish_reason = result.get('finish_reason', '')
|
||||
tool_calls_made = result.get('tool_calls_made', 0)
|
||||
|
||||
# 🔧 修复:检查工具调用是否真正成功
|
||||
if tool_calls_made > 0:
|
||||
if finish_reason == 'tool_error':
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
|
||||
# 工具调用失败,重新用基础模式生成
|
||||
ai_response = ""
|
||||
elif not ai_response.strip():
|
||||
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
|
||||
# 工具调用成功但返回空内容,重新生成
|
||||
ai_response = ""
|
||||
else:
|
||||
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
|
||||
# MCP成功且有内容,模拟流式输出(分块发送)
|
||||
chunk_size = 50
|
||||
for i in range(0, len(ai_response), chunk_size):
|
||||
chunk = ai_response[i:i+chunk_size]
|
||||
chunk_count += 1
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 3 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
|
||||
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
|
||||
)
|
||||
|
||||
# 跳过后续的流式生成
|
||||
ai_response = result.get('content', '')
|
||||
else:
|
||||
ai_response = result
|
||||
|
||||
# 如果MCP调用失败或返回空,继续走流式生成
|
||||
if not ai_response or not ai_response.strip():
|
||||
logger.info(f"🔄 开始流式生成...")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as mcp_error:
|
||||
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}")
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
|
||||
ai_response = ""
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
else:
|
||||
result = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
logger.debug(f"未登录用户,使用流式基础模式")
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成角色中... ({len(ai_response)}字符)",
|
||||
10 + min(chunk_count // 2, 85)
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
@@ -823,7 +923,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
@@ -836,7 +936,7 @@ async def generate_character_stream(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 75)
|
||||
yield await SSEResponse.send_progress("创建角色记录...", 97)
|
||||
|
||||
# 转换traits
|
||||
traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None
|
||||
@@ -1001,7 +1101,7 @@ async def generate_character_stream(
|
||||
|
||||
# 如果是组织,创建Organization详情
|
||||
if is_organization:
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
org_check = await db.execute(
|
||||
select(Organization).where(Organization.character_id == character.id)
|
||||
@@ -1168,13 +1268,13 @@ async def generate_character_stream(
|
||||
|
||||
logger.info(f"✅ 成功创建 {created_members} 条组织成员记录")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response,
|
||||
generated_content=ai_response,
|
||||
model=user_ai_service.default_model
|
||||
)
|
||||
db.add(history)
|
||||
|
||||
+204
-28
@@ -105,23 +105,27 @@ async def generate_options(
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板(根据step确定模板key)
|
||||
# 新结构:每个步骤有独立的 SYSTEM 和 USER 模板
|
||||
template_key_map = {
|
||||
"title": "INSPIRATION_TITLE",
|
||||
"description": "INSPIRATION_DESCRIPTION",
|
||||
"theme": "INSPIRATION_THEME",
|
||||
"genre": "INSPIRATION_GENRE"
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_key = template_key_map.get(step)
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_key:
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
# 获取自定义提示词模板
|
||||
prompt_template_str = await PromptService.get_template(template_key, user_id, db)
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板(分别获取 system 和 user)
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
@@ -131,19 +135,9 @@ async def generate_options(
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分)
|
||||
# 尝试解析为JSON格式的字典
|
||||
try:
|
||||
prompt_template = json.loads(prompt_template_str)
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 如果不是JSON格式,降级使用原有方法
|
||||
prompt_template = prompt_service.get_inspiration_prompt(step)
|
||||
if not prompt_template:
|
||||
return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []}
|
||||
system_prompt = prompt_template["system"].format(**format_params)
|
||||
user_prompt = prompt_template["user"].format(**format_params)
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 如果是重试,在提示词中强调格式要求
|
||||
if attempt > 0:
|
||||
@@ -153,13 +147,18 @@ async def generate_options(
|
||||
# 关键改进:使用递减的temperature以保持后续阶段与前文的一致性
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
logger.info(f"调用AI生成{step}选项... (temperature={temperature})")
|
||||
response = await ai_service.generate_text(
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = response.get("content", "")
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
@@ -222,6 +221,180 @@ async def generate_options(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refine-options")
|
||||
async def refine_options(
|
||||
data: Dict[str, Any],
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
ai_service: AIService = Depends(get_user_ai_service)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
基于用户反馈重新生成选项(支持多轮对话)
|
||||
|
||||
Request:
|
||||
{
|
||||
"step": "title", // 当前步骤
|
||||
"context": {
|
||||
"initial_idea": "...",
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"theme": "..."
|
||||
},
|
||||
"feedback": "我想要更悲剧一些的主题", // 用户反馈
|
||||
"previous_options": ["选项1", "选项2", ...] // 之前的选项(可选)
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"prompt": "引导语",
|
||||
"options": ["新选项1", "新选项2", ...]
|
||||
}
|
||||
"""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step = data.get("step", "title")
|
||||
context = data.get("context", {})
|
||||
feedback = data.get("feedback", "")
|
||||
previous_options = data.get("previous_options", [])
|
||||
|
||||
logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)")
|
||||
logger.info(f"用户反馈: {feedback}")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
|
||||
# 获取对应的提示词模板
|
||||
template_key_map = {
|
||||
"title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"),
|
||||
"description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"),
|
||||
"theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"),
|
||||
"genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER")
|
||||
}
|
||||
template_keys = template_key_map.get(step)
|
||||
|
||||
if not template_keys:
|
||||
return {
|
||||
"error": f"不支持的步骤: {step}",
|
||||
"prompt": "",
|
||||
"options": []
|
||||
}
|
||||
|
||||
system_key, user_key = template_keys
|
||||
|
||||
# 获取自定义提示词模板
|
||||
system_template = await PromptService.get_template(system_key, user_id, db)
|
||||
user_template = await PromptService.get_template(user_key, user_id, db)
|
||||
|
||||
# 准备格式化参数
|
||||
format_params = {
|
||||
"initial_idea": context.get("initial_idea", context.get("description", "")),
|
||||
"title": context.get("title", ""),
|
||||
"description": context.get("description", ""),
|
||||
"theme": context.get("theme", "")
|
||||
}
|
||||
|
||||
# 格式化提示词
|
||||
system_prompt = system_template.format(**format_params)
|
||||
user_prompt = user_template.format(**format_params)
|
||||
|
||||
# 添加反馈信息到提示词
|
||||
feedback_instruction = f"""
|
||||
|
||||
⚠️ 用户对之前的选项不太满意,提供了以下反馈:
|
||||
「{feedback}」
|
||||
|
||||
之前生成的选项:
|
||||
{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"}
|
||||
|
||||
请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。
|
||||
注意:
|
||||
1. 仔细理解用户的反馈意图
|
||||
2. 生成的新选项要明显体现用户要求的调整方向
|
||||
3. 保持与已有上下文的一致性
|
||||
4. 确保返回6个有效选项
|
||||
"""
|
||||
|
||||
system_prompt += feedback_instruction
|
||||
|
||||
# 如果是重试,强调格式要求
|
||||
if attempt > 0:
|
||||
system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!"
|
||||
|
||||
# 调用AI生成选项
|
||||
temperature = TEMPERATURE_SETTINGS.get(step, 0.7)
|
||||
# 反馈生成时使用稍高的temperature以获得更多样化的结果
|
||||
temperature = min(temperature + 0.1, 0.9)
|
||||
logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})")
|
||||
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
temperature=temperature
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = accumulated_text
|
||||
logger.info(f"AI返回内容长度: {len(content)}")
|
||||
|
||||
# 解析JSON
|
||||
try:
|
||||
cleaned_content = ai_service._clean_json_response(content)
|
||||
result = json.loads(cleaned_content)
|
||||
|
||||
# 校验返回格式
|
||||
is_valid, error_msg = validate_options_response(result, step)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项")
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"第{attempt + 1}次JSON解析失败: {e}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("JSON解析失败,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"prompt": f"请为【{step}】提供内容:",
|
||||
"options": ["让AI重新生成", "我自己输入"],
|
||||
"error": f"AI返回格式错误,已自动重试{max_retries}次"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第{attempt + 1}次根据反馈生成失败: {e}", exc_info=True)
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("发生异常,准备重试...")
|
||||
continue
|
||||
else:
|
||||
return {
|
||||
"error": str(e),
|
||||
"prompt": "生成失败,请重试",
|
||||
"options": ["重新生成", "我自己输入"]
|
||||
}
|
||||
|
||||
return {
|
||||
"error": "生成失败",
|
||||
"prompt": "请重试",
|
||||
"options": []
|
||||
}
|
||||
|
||||
|
||||
@router.post("/quick-generate")
|
||||
async def quick_generate(
|
||||
data: Dict[str, Any],
|
||||
@@ -280,14 +453,17 @@ async def quick_generate(
|
||||
# 降级使用原有方法
|
||||
prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text)
|
||||
|
||||
# 调用AI
|
||||
response = await ai_service.generate_text(
|
||||
# 调用AI - 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
temperature=0.7
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
content = response.get("content", "")
|
||||
response = {"content": accumulated_text}
|
||||
content = accumulated_text
|
||||
|
||||
# 解析JSON(使用统一的JSON清洗方法)
|
||||
try:
|
||||
|
||||
@@ -512,8 +512,29 @@ async def generate_organization_stream(
|
||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)")
|
||||
|
||||
try:
|
||||
ai_response = await user_ai_service.generate_text(prompt=prompt)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response)
|
||||
# 使用流式生成替代非流式
|
||||
ai_content = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
|
||||
chunk_count += 1
|
||||
ai_content += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新字数(5-95%,AI生成占90%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 5), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成组织中... ({len(ai_content)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
except Exception as ai_error:
|
||||
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
|
||||
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
|
||||
@@ -523,7 +544,7 @@ async def generate_organization_stream(
|
||||
yield await SSEResponse.send_error("AI服务返回空响应")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 60)
|
||||
yield await SSEResponse.send_progress("解析AI响应...", 96)
|
||||
|
||||
# ✅ 使用统一的 JSON 清洗方法
|
||||
try:
|
||||
@@ -536,7 +557,7 @@ async def generate_organization_stream(
|
||||
yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}")
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织记录...", 75)
|
||||
yield await SSEResponse.send_progress("创建组织记录...", 97)
|
||||
|
||||
# 创建角色记录(组织也是角色的一种)
|
||||
character = Character(
|
||||
@@ -563,7 +584,7 @@ async def generate_organization_stream(
|
||||
|
||||
logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 85)
|
||||
yield await SSEResponse.send_progress("创建组织详情...", 98)
|
||||
|
||||
# 自动创建Organization详情记录
|
||||
organization = Organization(
|
||||
@@ -580,7 +601,7 @@ async def generate_organization_stream(
|
||||
|
||||
logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})")
|
||||
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 95)
|
||||
yield await SSEResponse.send_progress("保存生成历史...", 99)
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
|
||||
+89
-30
@@ -470,7 +470,7 @@ async def _generate_new_outline(
|
||||
project: Project,
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = None
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""全新生成大纲(MCP增强版)"""
|
||||
logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}")
|
||||
@@ -534,7 +534,7 @@ async def _generate_new_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -573,15 +573,23 @@ async def _generate_new_outline(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI生成大纲
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 调用AI流式生成大纲(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
# 如果未来需要转SSE,可以在这里yield
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -732,7 +740,7 @@ async def _continue_outline(
|
||||
existing_outlines: List[Outline],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService,
|
||||
user_id: str = "system"
|
||||
user_id: str
|
||||
) -> OutlineListResponse:
|
||||
"""续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)"""
|
||||
logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}")
|
||||
@@ -1000,7 +1008,7 @@ async def _continue_outline(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1045,15 +1053,22 @@ async def _continue_outline(
|
||||
)
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"正在调用AI生成第{batch_num + 1}批...")
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
logger.info(f"正在调用AI流式生成第{batch_num + 1}批...")
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=request.provider,
|
||||
model=request.model
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 这里是非SSE接口,不需要发送chunk
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
@@ -1291,7 +1306,7 @@ async def new_outline_generator(
|
||||
user_id=user_id_for_mcp,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1332,7 +1347,7 @@ async def new_outline_generator(
|
||||
mcp_references=mcp_reference_materials
|
||||
)
|
||||
|
||||
# 调用AI
|
||||
# 调用AI流式生成
|
||||
yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30)
|
||||
|
||||
# 添加调试日志
|
||||
@@ -1341,24 +1356,44 @@ async def new_outline_generator(
|
||||
logger.info(f"=== 大纲生成AI调用参数 ===")
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
logger.info(f" 完整data: {data}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# ✅ 流式生成(带字数统计和进度)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数(30-95%,AI生成占65%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 2), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"AI生成大纲中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70)
|
||||
yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
# 全新生成模式:删除旧大纲和关联的所有章节
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 75)
|
||||
yield await SSEResponse.send_progress("清理旧大纲和章节...", 97)
|
||||
logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})")
|
||||
|
||||
from sqlalchemy import delete as sql_delete
|
||||
@@ -1390,7 +1425,7 @@ async def new_outline_generator(
|
||||
logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲")
|
||||
|
||||
# 保存新大纲
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80)
|
||||
yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98)
|
||||
outlines = await _save_outlines(
|
||||
project_id, outline_data, db, start_index=1
|
||||
)
|
||||
@@ -1410,7 +1445,7 @@ async def new_outline_generator(
|
||||
for outline in outlines:
|
||||
await db.refresh(outline)
|
||||
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 95)
|
||||
yield await SSEResponse.send_progress("整理结果数据...", 99)
|
||||
|
||||
logger.info(f"全新生成完成 - {len(outlines)} 章")
|
||||
|
||||
@@ -1785,7 +1820,7 @@ async def continue_outline_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 减少为1轮,避免超时
|
||||
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1846,19 +1881,43 @@ async def continue_outline_generator(
|
||||
logger.info(f" provider参数: {provider_param}")
|
||||
logger.info(f" model参数: {model_param}")
|
||||
|
||||
ai_response = await user_ai_service.generate_text(
|
||||
# 流式生成并累积文本
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider_param,
|
||||
model=model_param
|
||||
)
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度(每批占用约50%的进度空间)
|
||||
if chunk_count % 5 == 0:
|
||||
# 在批次范围内平滑递增(从10到85,总共75%)
|
||||
batch_range = 60 // total_batches # 总进度60%分配给所有批次
|
||||
progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)",
|
||||
progress_in_batch
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...",
|
||||
batch_progress + 10
|
||||
)
|
||||
|
||||
# 提取内容(generate_text返回字典)
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
# 提取内容
|
||||
ai_content = accumulated_text
|
||||
ai_response = {"content": ai_content}
|
||||
|
||||
# 解析响应
|
||||
outline_data = _parse_ai_response(ai_content)
|
||||
|
||||
+25
-11
@@ -73,14 +73,15 @@ async def get_user_ai_service(
|
||||
await db.refresh(settings)
|
||||
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
|
||||
|
||||
# 使用用户设置创建AI服务实例
|
||||
# 使用用户设置创建AI服务实例(包括系统提示词)
|
||||
return create_user_ai_service(
|
||||
api_provider=settings.api_provider,
|
||||
api_key=settings.api_key,
|
||||
api_base_url=settings.api_base_url or "",
|
||||
model_name=settings.llm_model,
|
||||
temperature=settings.temperature,
|
||||
max_tokens=settings.max_tokens
|
||||
max_tokens=settings.max_tokens,
|
||||
system_prompt=settings.system_prompt # 传递系统提示词
|
||||
)
|
||||
|
||||
|
||||
@@ -271,17 +272,30 @@ async def get_available_models(
|
||||
}
|
||||
|
||||
elif provider == "anthropic":
|
||||
# Anthropic 没有公开的模型列表API
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称"
|
||||
)
|
||||
# Anthropic models API
|
||||
url = f"{api_base_url.rstrip('/')}/v1/models"
|
||||
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])]
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
elif provider == "gemini":
|
||||
# Gemini models API
|
||||
url = f"{api_base_url.rstrip('/')}/models?key={api_key}"
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
models = []
|
||||
for m in data.get("models", []):
|
||||
if "generateContent" in m.get("supportedGenerationMethods", []):
|
||||
mid = m.get("name", "").replace("models/", "")
|
||||
models.append({"value": mid, "label": m.get("displayName", mid), "description": ""})
|
||||
return {"provider": provider, "models": models, "count": len(models)}
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的提供商: {provider}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}")
|
||||
|
||||
+362
-126
@@ -99,7 +99,7 @@ async def world_building_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -139,51 +139,118 @@ async def world_building_generator(
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
# ===== 流式生成世界观(带重试机制) =====
|
||||
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||
world_retry_count = 0
|
||||
world_generation_success = False
|
||||
world_data = {}
|
||||
try:
|
||||
# ✅ 使用 AIService 的统一清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观JSON解析成功")
|
||||
|
||||
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败: {e}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 世界观生成独立进度:5-95%
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 检查是否返回空响应
|
||||
if not accumulated_text or not accumulated_text.strip():
|
||||
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
logger.error("❌ 世界观生成多次返回空响应")
|
||||
world_data = {
|
||||
"time_period": "AI多次返回为空,请稍后重试",
|
||||
"location": "AI多次返回为空,请稍后重试",
|
||||
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||
"rules": "AI多次返回为空,请稍后重试"
|
||||
}
|
||||
world_generation_success = True # 标记为成功以继续流程
|
||||
break
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析世界观数据...", 96)
|
||||
|
||||
try:
|
||||
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||
logger.info(f" 原始内容预览: {accumulated_text[:300]}...")
|
||||
|
||||
# ✅ 使用 AIService 的统一清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
logger.info(f" 清洗后预览: {cleaned_text[:300]}...")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True # 解析成功,标记完成
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
world_generation_success = True # 标记为成功以继续流程
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍失败,抛出异常
|
||||
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||
raise
|
||||
# 保存到数据库
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
yield await SSEResponse.send_progress("保存世界观到数据库...", 99)
|
||||
|
||||
# 确保user_id存在
|
||||
if not user_id:
|
||||
@@ -240,41 +307,81 @@ async def world_building_generator(
|
||||
project.wizard_step = 1
|
||||
await db.commit()
|
||||
|
||||
# ===== 自动生成职业体系 =====
|
||||
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75)
|
||||
# ===== 自动生成职业体系(带重试机制+流式) =====
|
||||
yield await SSEResponse.send_progress("世界观完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5)
|
||||
logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系")
|
||||
|
||||
try:
|
||||
# 获取职业生成提示词模板(支持用户自定义)
|
||||
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
||||
career_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=genre or '未设定',
|
||||
theme=theme or '未设定',
|
||||
time_period=world_data.get('time_period', '未设定'),
|
||||
location=world_data.get('location', '未设定'),
|
||||
atmosphere=world_data.get('atmosphere', '未设定'),
|
||||
rules=world_data.get('rules', '未设定')
|
||||
)
|
||||
|
||||
yield await SSEResponse.send_progress("正在生成职业体系...", 78)
|
||||
|
||||
# 调用AI生成职业
|
||||
result = await user_ai_service.generate_text(prompt=career_prompt)
|
||||
career_response = result.get('content', '') if isinstance(result, dict) else result
|
||||
|
||||
if not career_response or not career_response.strip():
|
||||
logger.warning("⚠️ AI返回空职业体系,跳过职业生成")
|
||||
yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85)
|
||||
else:
|
||||
yield await SSEResponse.send_progress("解析职业体系数据...", 82)
|
||||
MAX_CAREER_RETRIES = 3 # 最多重试3次
|
||||
career_retry_count = 0
|
||||
career_generation_success = False
|
||||
|
||||
while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10)
|
||||
|
||||
# 获取职业生成提示词模板(支持用户自定义)
|
||||
template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db)
|
||||
career_prompt = PromptService.format_prompt(
|
||||
template,
|
||||
title=project.title,
|
||||
genre=genre or '未设定',
|
||||
theme=theme or '未设定',
|
||||
time_period=world_data.get('time_period', '未设定'),
|
||||
location=world_data.get('location', '未设定'),
|
||||
atmosphere=world_data.get('atmosphere', '未设定'),
|
||||
rules=world_data.get('rules', '未设定')
|
||||
)
|
||||
|
||||
# ✅ 使用流式生成职业体系
|
||||
career_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=career_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
career_response += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 职业体系生成独立进度:10-95%
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(10 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成职业体系中... ({len(career_response)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
if not career_response or not career_response.strip():
|
||||
logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99)
|
||||
break
|
||||
|
||||
yield await SSEResponse.send_progress("解析职业体系数据...", 96)
|
||||
|
||||
# 清洗并解析JSON
|
||||
try:
|
||||
cleaned_response = user_ai_service._clean_json_response(career_response)
|
||||
career_data = json.loads(cleaned_response)
|
||||
logger.info(f"✅ 职业体系JSON解析成功")
|
||||
logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})")
|
||||
|
||||
# 保存主职业
|
||||
main_careers_created = []
|
||||
@@ -338,22 +445,51 @@ async def world_building_generator(
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 标记成功
|
||||
career_generation_success = True
|
||||
logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个")
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})",
|
||||
90
|
||||
99
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 职业体系JSON解析失败: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85)
|
||||
logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系保存失败: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系生成异常: {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85)
|
||||
logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 保存失败,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}")
|
||||
career_retry_count += 1
|
||||
if career_retry_count < MAX_CAREER_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
10,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99)
|
||||
|
||||
db_committed = True
|
||||
|
||||
@@ -366,7 +502,8 @@ async def world_building_generator(
|
||||
"rules": world_data.get("rules")
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("职业体系完成!", 100, "success")
|
||||
yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
@@ -473,7 +610,7 @@ async def characters_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮
|
||||
max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -611,15 +748,32 @@ async def characters_generator(
|
||||
else:
|
||||
prompt = base_prompt
|
||||
|
||||
# 流式生成
|
||||
# 流式生成(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成角色中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析批次结果 - 使用统一的JSON清洗方法
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
@@ -1184,18 +1338,35 @@ async def outline_generator(
|
||||
requirements=outline_requirements
|
||||
)
|
||||
|
||||
# 流式生成大纲
|
||||
# 流式生成大纲(带字数统计)
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=outline_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度和字数(5-95%,AI生成占90%)
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(5 + (chunk_count // 3), 95)
|
||||
yield await SSEResponse.send_progress(
|
||||
f"生成大纲中... ({len(accumulated_text)}字符)",
|
||||
progress
|
||||
)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析大纲结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析大纲...", 40)
|
||||
yield await SSEResponse.send_progress("解析大纲...", 96)
|
||||
|
||||
try:
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
@@ -1208,7 +1379,7 @@ async def outline_generator(
|
||||
return
|
||||
|
||||
# 保存大纲到数据库
|
||||
yield await SSEResponse.send_progress("保存大纲到数据库...", 45)
|
||||
yield await SSEResponse.send_progress("保存大纲到数据库...", 97)
|
||||
created_outlines = []
|
||||
for index, outline_item in enumerate(outline_data[:outline_count], 1):
|
||||
outline = Outline(
|
||||
@@ -1231,7 +1402,7 @@ async def outline_generator(
|
||||
created_chapters = []
|
||||
if project.outline_mode == 'one-to-one':
|
||||
# 一对一模式:自动为每个大纲创建对应的章节
|
||||
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50)
|
||||
yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98)
|
||||
|
||||
for outline in created_outlines:
|
||||
chapter = Chapter(
|
||||
@@ -1250,10 +1421,10 @@ async def outline_generator(
|
||||
await db.refresh(chapter)
|
||||
|
||||
logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节")
|
||||
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85)
|
||||
yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99)
|
||||
else:
|
||||
# 一对多模式:跳过自动创建,用户可手动展开
|
||||
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85)
|
||||
yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99)
|
||||
logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开")
|
||||
|
||||
# 更新项目信息
|
||||
@@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator(
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
@@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator(
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
# ===== 流式生成世界观(带重试机制) =====
|
||||
MAX_WORLD_RETRIES = 3 # 最多重试3次
|
||||
world_retry_count = 0
|
||||
world_generation_success = False
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观重新生成JSON解析成功")
|
||||
|
||||
while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success:
|
||||
try:
|
||||
retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else ""
|
||||
yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"世界构建JSON解析失败: {e}")
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 85)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 检查是否返回空响应
|
||||
if not accumulated_text or not accumulated_text.strip():
|
||||
logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI返回为空,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
logger.error("❌ 世界观重新生成多次返回空响应")
|
||||
world_data = {
|
||||
"time_period": "AI多次返回为空,请稍后重试",
|
||||
"location": "AI多次返回为空,请稍后重试",
|
||||
"atmosphere": "AI多次返回为空,请稍后重试",
|
||||
"rules": "AI多次返回为空,请稍后重试"
|
||||
}
|
||||
world_generation_success = True
|
||||
break
|
||||
|
||||
# 解析结果 - 使用统一的JSON清洗方法
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
try:
|
||||
logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}")
|
||||
cleaned_text = user_ai_service._clean_json_response(accumulated_text)
|
||||
logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}")
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})")
|
||||
world_generation_success = True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}")
|
||||
logger.error(f" 原始内容长度: {len(accumulated_text)}")
|
||||
logger.error(f" 原始内容预览: {accumulated_text[:200]}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ JSON解析失败,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 达到最大重试次数,使用默认值
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
world_generation_success = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}")
|
||||
world_retry_count += 1
|
||||
if world_retry_count < MAX_WORLD_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成异常,准备重试...",
|
||||
30 + world_retry_count * 5,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍失败,抛出异常
|
||||
logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}")
|
||||
raise
|
||||
|
||||
# 不保存到数据库,仅返回生成结果供用户预览
|
||||
yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90)
|
||||
|
||||
@@ -15,7 +15,6 @@ from ..schemas.writing_style import (
|
||||
WritingStyleListResponse,
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
from ..logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
@@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str:
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
async def get_preset_styles():
|
||||
async def get_preset_styles(db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
获取所有预设风格列表
|
||||
获取所有预设风格列表(从数据库读取)
|
||||
|
||||
返回格式:数组形式的预设风格列表
|
||||
[
|
||||
{"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": "classical", "name": "古典优雅", ...}
|
||||
{"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."},
|
||||
{"id": 2, "preset_id": "classical", "name": "古典优雅", ...}
|
||||
]
|
||||
"""
|
||||
presets = WritingStyleManager.get_all_presets()
|
||||
# 将字典转换为数组,添加 id 字段
|
||||
# 从数据库获取全局预设风格(user_id 为 NULL)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(WritingStyle.user_id.is_(None))
|
||||
.order_by(WritingStyle.order_index)
|
||||
)
|
||||
preset_styles = result.scalars().all()
|
||||
|
||||
# 转换为响应格式
|
||||
return [
|
||||
{"id": preset_id, **preset_data}
|
||||
for preset_id, preset_data in presets.items()
|
||||
{
|
||||
"id": style.id,
|
||||
"preset_id": style.preset_id,
|
||||
"name": style.name,
|
||||
"description": style.description,
|
||||
"prompt_content": style.prompt_content,
|
||||
"style_type": style.style_type,
|
||||
"order_index": style.order_index
|
||||
}
|
||||
for style in preset_styles
|
||||
]
|
||||
|
||||
|
||||
@@ -58,25 +72,33 @@ async def create_writing_style(
|
||||
"""
|
||||
创建新的写作风格(用户级别)
|
||||
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **基于预设创建**:提供 preset_id,系统会从数据库查询预设内容自动填充
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 获取当前用户ID
|
||||
user_id = get_current_user_id(request)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
# 如果基于预设创建,从数据库获取预设内容
|
||||
if style_data.preset_id:
|
||||
preset = WritingStyleManager.get_preset_style(style_data.preset_id)
|
||||
result = await db.execute(
|
||||
select(WritingStyle)
|
||||
.where(
|
||||
WritingStyle.user_id.is_(None),
|
||||
WritingStyle.preset_id == style_data.preset_id
|
||||
)
|
||||
)
|
||||
preset = result.scalar_one_or_none()
|
||||
|
||||
if not preset:
|
||||
raise HTTPException(status_code=400, detail=f"预设风格 '{style_data.preset_id}' 不存在")
|
||||
|
||||
# 使用预设内容填充(如果用户未提供)
|
||||
if not style_data.name:
|
||||
style_data.name = preset["name"]
|
||||
style_data.name = preset.name
|
||||
if not style_data.description:
|
||||
style_data.description = preset["description"]
|
||||
style_data.description = preset.description
|
||||
if not style_data.prompt_content:
|
||||
style_data.prompt_content = preset["prompt_content"]
|
||||
style_data.prompt_content = preset.prompt_content
|
||||
|
||||
# 验证必填字段
|
||||
if not style_data.name or not style_data.prompt_content:
|
||||
|
||||
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import AnyUrl
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
from app.logger import get_logger
|
||||
|
||||
@@ -141,51 +142,89 @@ class HTTPMCPClient:
|
||||
async def call_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
arguments: Dict[str, Any],
|
||||
max_reconnect_attempts: int = 2
|
||||
) -> Any:
|
||||
"""
|
||||
调用工具
|
||||
调用工具(带自动重连)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
max_reconnect_attempts: 最大重连尝试次数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
logger.info(f"调用工具: {tool_name}")
|
||||
logger.debug(f"参数: {arguments}")
|
||||
|
||||
result = await self._session.call_tool(tool_name, arguments)
|
||||
|
||||
# 处理返回结果
|
||||
# MCP SDK 返回 CallToolResult 对象
|
||||
if result.content:
|
||||
# 提取第一个content的文本
|
||||
for content in result.content:
|
||||
if isinstance(content, types.TextContent):
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
# 如果没有文本内容,返回原始内容
|
||||
return result.content[0] if result.content else None
|
||||
|
||||
# 如果有结构化内容(2025-06-18规范)
|
||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||
return result.structuredContent
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
|
||||
raise MCPError(f"调用工具失败: {str(e)}")
|
||||
for attempt in range(max_reconnect_attempts + 1):
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
logger.info(f"调用工具: {tool_name}")
|
||||
logger.debug(f" 参数类型: {type(arguments)}")
|
||||
logger.debug(f" 参数内容: {arguments}")
|
||||
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
|
||||
|
||||
result = await self._session.call_tool(tool_name, arguments)
|
||||
|
||||
logger.debug(f" 工具返回类型: {type(result)}")
|
||||
logger.debug(f" 返回内容: {result}")
|
||||
|
||||
# 处理返回结果
|
||||
# MCP SDK 返回 CallToolResult 对象
|
||||
if result.content:
|
||||
logger.debug(f" 返回content数量: {len(result.content)}")
|
||||
# 提取第一个content的文本
|
||||
for idx, content in enumerate(result.content):
|
||||
logger.debug(f" content[{idx}]类型: {type(content)}")
|
||||
if isinstance(content, types.TextContent):
|
||||
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
logger.debug(f" ✅ 返回ImageContent")
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
# 如果没有文本内容,返回原始内容
|
||||
logger.debug(f" ⚠️ 返回原始content[0]")
|
||||
return result.content[0] if result.content else None
|
||||
|
||||
# 如果有结构化内容(2025-06-18规范)
|
||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||
logger.debug(f" ✅ 返回structuredContent")
|
||||
return result.structuredContent
|
||||
|
||||
logger.warning(f" ⚠️ 工具返回为None")
|
||||
return None
|
||||
|
||||
except ClosedResourceError as e:
|
||||
# 连接已关闭,尝试重连
|
||||
if attempt < max_reconnect_attempts:
|
||||
logger.warning(
|
||||
f"⚠️ MCP连接已关闭,尝试重新连接 "
|
||||
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
|
||||
)
|
||||
await self._cleanup()
|
||||
await asyncio.sleep(0.5) # 短暂延迟后重连
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
|
||||
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
|
||||
raise MCPError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
logger.error(f" 参数: {arguments}")
|
||||
logger.error(f" 错误类型: {type(e).__name__}")
|
||||
logger.error(f" 错误详情: {repr(e)}")
|
||||
logger.error(f" 错误字符串: '{str(e)}'")
|
||||
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
|
||||
raise MCPError(f"调用工具失败: {error_msg}")
|
||||
|
||||
# 理论上不会到这里
|
||||
raise MCPError(f"工具调用失败: 未知错误")
|
||||
|
||||
async def list_resources(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,7 @@ class Settings(Base):
|
||||
llm_model = Column(String(100), default="gpt-4", comment="模型名称")
|
||||
temperature = Column(Float, default=0.7, comment="温度参数")
|
||||
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
||||
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
|
||||
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
@@ -14,6 +14,7 @@ class SettingsBase(BaseModel):
|
||||
llm_model: Optional[str] = Field(default="gpt-4", description="模型名称")
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
|
||||
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""AI 客户端模块"""
|
||||
from .base_client import BaseAIClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Anthropic 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient:
|
||||
"""Anthropic API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.config = config or default_config
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncAnthropic(**kwargs)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
})
|
||||
elif block.type == "text":
|
||||
content += block.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
@@ -0,0 +1,154 @@
|
||||
"""AI 客户端基类"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局 HTTP 客户端池
|
||||
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||
"""获取全局信号量"""
|
||||
global _global_semaphore
|
||||
if _global_semaphore is None:
|
||||
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
return _global_semaphore
|
||||
|
||||
|
||||
class BaseAIClient(ABC):
|
||||
"""AI HTTP 客户端基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or default_config
|
||||
self.http_client = self._get_or_create_client()
|
||||
|
||||
def _get_client_key(self) -> str:
|
||||
"""生成客户端唯一键"""
|
||||
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||
|
||||
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
client_key = self._get_client_key()
|
||||
|
||||
if client_key in _http_client_pool:
|
||||
client = _http_client_pool[client_key]
|
||||
if not client.is_closed:
|
||||
return client
|
||||
del _http_client_pool[client_key]
|
||||
|
||||
http_cfg = self.config.http
|
||||
client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||
max_connections=http_cfg.max_connections,
|
||||
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||
),
|
||||
)
|
||||
_http_client_pool[client_key] = client
|
||||
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||
return client
|
||||
|
||||
@abstractmethod
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
stream: bool = False,
|
||||
) -> Any:
|
||||
"""带重试的 HTTP 请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = self._build_headers()
|
||||
retry_cfg = self.config.retry
|
||||
rate_cfg = self.config.rate_limit
|
||||
|
||||
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||
|
||||
async with semaphore:
|
||||
await asyncio.sleep(rate_cfg.request_delay)
|
||||
|
||||
for attempt in range(retry_cfg.max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
delay = min(
|
||||
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||
retry_cfg.max_delay,
|
||||
)
|
||||
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if stream:
|
||||
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||
|
||||
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||
raise
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天补全"""
|
||||
pass
|
||||
|
||||
|
||||
async def cleanup_all_clients():
|
||||
"""清理所有 HTTP 客户端"""
|
||||
for key, client in list(_http_client_pool.items()):
|
||||
if not client.is_closed:
|
||||
await client.aclose()
|
||||
_http_client_pool.clear()
|
||||
logger.info("✅ HTTP 客户端池已清理")
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Gemini 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Google Gemini API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
self.config = config or default_config
|
||||
http_cfg = self.config.http
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||
gemini_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||
params.pop("$schema", None)
|
||||
params.pop("additionalProperties", None)
|
||||
if params and "type" not in params:
|
||||
params["type"] = "object"
|
||||
decl = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description") or func["name"],
|
||||
}
|
||||
if params:
|
||||
decl["parameters"] = params
|
||||
gemini_tools.append(decl)
|
||||
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates or len(candidates) == 0:
|
||||
# 返回空内容而不是报错,保持流程继续
|
||||
return {
|
||||
"content": "",
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text = ""
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||
})
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
except:
|
||||
continue
|
||||
@@ -0,0 +1,101 @@
|
||||
"""OpenAI 客户端"""
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from .base_client import BaseAIClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(BaseAIClient):
|
||||
"""OpenAI API 客户端"""
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
if tools:
|
||||
# 清理 $schema 字段
|
||||
cleaned = []
|
||||
for t in tools:
|
||||
tc = t.copy()
|
||||
if "function" in tc and "parameters" in tc["function"]:
|
||||
tc["function"]["parameters"] = {
|
||||
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
|
||||
}
|
||||
cleaned.append(tc)
|
||||
payload["tools"] = cleaned
|
||||
if tool_choice:
|
||||
payload["tool_choice"] = tool_choice
|
||||
return payload
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
|
||||
data = await self._request_with_retry("POST", "/chat/completions", payload)
|
||||
|
||||
choices = data.get("choices", [])
|
||||
if not choices or len(choices) == 0:
|
||||
raise ValueError("API 返回空 choices 或 choices 为空列表")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
return {
|
||||
"content": message.get("content", ""),
|
||||
"tool_calls": message.get("tool_calls"),
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
@@ -0,0 +1,44 @@
|
||||
"""AI 服务配置管理"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPClientConfig:
|
||||
"""HTTP 客户端配置"""
|
||||
connect_timeout: float = 90.0
|
||||
read_timeout: float = 300.0
|
||||
write_timeout: float = 90.0
|
||||
pool_timeout: float = 90.0
|
||||
max_keepalive_connections: int = 50
|
||||
max_connections: int = 100
|
||||
keepalive_expiry: float = 60.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""重试配置"""
|
||||
max_retries: int = 3
|
||||
base_delay: float = 0.2
|
||||
max_delay: float = 10.0
|
||||
exponential_base: int = 2
|
||||
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
max_concurrent_requests: int = 5
|
||||
request_delay: float = 0.2
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIClientConfig:
|
||||
"""AI 客户端完整配置"""
|
||||
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
|
||||
retry: RetryConfig = field(default_factory=RetryConfig)
|
||||
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
|
||||
|
||||
|
||||
# 全局默认配置
|
||||
default_config = AIClientConfig()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""AI Provider 模块"""
|
||||
from .base_provider import BaseAIProvider
|
||||
from .openai_provider import OpenAIProvider
|
||||
from .anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Anthropic Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class AnthropicProvider(BaseAIProvider):
|
||||
"""Anthropic 提供商"""
|
||||
|
||||
def __init__(self, client: AnthropicClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,33 @@
|
||||
"""AI Provider 基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseAIProvider(ABC):
|
||||
"""AI 提供商抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成文本"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
pass
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Gemini Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
def __init__(self, client: GeminiClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,57 @@
|
||||
"""OpenAI Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
"""OpenAI 提供商"""
|
||||
|
||||
def __init__(self, client: OpenAIClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield chunk
|
||||
+177
-1295
File diff suppressed because it is too large
Load Diff
@@ -263,7 +263,7 @@ class AutoCharacterService:
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
@@ -362,7 +362,7 @@ class AutoCharacterService:
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""JSON 处理工具类"""
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def clean_json_response(text: str) -> str:
|
||||
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||
try:
|
||||
if not text:
|
||||
logger.warning("⚠️ clean_json_response: 输入为空")
|
||||
return text
|
||||
|
||||
original_length = len(text)
|
||||
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||
|
||||
# 去除 markdown 代码块
|
||||
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
|
||||
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
if len(text) != original_length:
|
||||
logger.debug(f" 移除markdown后长度: {len(text)}")
|
||||
|
||||
# 尝试直接解析(快速路径)
|
||||
try:
|
||||
json.loads(text)
|
||||
logger.debug(f"✅ 直接解析成功,无需清洗")
|
||||
return text
|
||||
except:
|
||||
pass
|
||||
|
||||
# 找到第一个 { 或 [
|
||||
start = -1
|
||||
for i, c in enumerate(text):
|
||||
if c in ('{', '['):
|
||||
start = i
|
||||
break
|
||||
|
||||
if start == -1:
|
||||
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
|
||||
logger.debug(f" 文本预览: {text[:200]}")
|
||||
return text
|
||||
|
||||
if start > 0:
|
||||
logger.debug(f" 跳过前{start}个字符")
|
||||
text = text[start:]
|
||||
|
||||
# 改进的括号匹配算法(更严格的字符串处理)
|
||||
stack = []
|
||||
i = 0
|
||||
end = -1
|
||||
|
||||
while i < len(text):
|
||||
c = text[i]
|
||||
|
||||
# 处理字符串(关键:正确处理转义)
|
||||
if c == '"':
|
||||
# 计算前面有多少个连续的反斜杠
|
||||
num_backslashes = 0
|
||||
j = i - 1
|
||||
while j >= 0 and text[j] == '\\':
|
||||
num_backslashes += 1
|
||||
j -= 1
|
||||
|
||||
# 偶数个反斜杠(包括0)表示引号未被转义
|
||||
if num_backslashes % 2 == 0:
|
||||
# 这是字符串边界,跳过整个字符串
|
||||
i += 1
|
||||
while i < len(text):
|
||||
if text[i] == '"':
|
||||
# 再次检查转义
|
||||
num_backslashes = 0
|
||||
j = i - 1
|
||||
while j >= 0 and text[j] == '\\':
|
||||
num_backslashes += 1
|
||||
j -= 1
|
||||
if num_backslashes % 2 == 0:
|
||||
# 字符串结束
|
||||
break
|
||||
i += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 处理括号(只有在字符串外部才有效)
|
||||
if c == '{' or c == '[':
|
||||
stack.append(c)
|
||||
elif c == '}':
|
||||
if len(stack) > 0 and stack[-1] == '{':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
else:
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||
elif c == ']':
|
||||
if len(stack) > 0 and stack[-1] == '[':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
else:
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||
|
||||
i += 1
|
||||
|
||||
# 提取结果
|
||||
if end > 0:
|
||||
result = text[:end]
|
||||
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
|
||||
else:
|
||||
result = text
|
||||
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)})")
|
||||
logger.debug(f" 栈状态: {stack}")
|
||||
|
||||
# 验证清洗后的结果
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ clean_json_response 出错: {e}")
|
||||
logger.error(f" 文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
|
||||
raise
|
||||
|
||||
|
||||
def parse_json(text: str) -> Union[Dict, List]:
|
||||
"""解析 JSON"""
|
||||
try:
|
||||
cleaned = clean_json_response(text)
|
||||
return json.loads(cleaned)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ parse_json 出错: {e}")
|
||||
logger.error(f" 原始文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
|
||||
raise
|
||||
@@ -175,20 +175,34 @@ class MCPTestService:
|
||||
db=db_session
|
||||
)
|
||||
|
||||
ai_response = await ai_service.generate_text(
|
||||
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
|
||||
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
|
||||
# 因此这里需要特殊处理
|
||||
accumulated_text = ""
|
||||
tool_calls = None
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
tools=openai_tools,
|
||||
tool_choice="required"
|
||||
)
|
||||
):
|
||||
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
|
||||
if isinstance(chunk, dict):
|
||||
if "tool_calls" in chunk:
|
||||
tool_calls = chunk["tool_calls"]
|
||||
if "content" in chunk:
|
||||
accumulated_text += chunk.get("content", "")
|
||||
else:
|
||||
accumulated_text += chunk
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not ai_response.get("tool_calls"):
|
||||
if not tool_calls:
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
||||
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
@@ -198,7 +212,7 @@ class MCPTestService:
|
||||
)
|
||||
|
||||
# 6. 解析工具调用
|
||||
tool_call = ai_response["tool_calls"][0]
|
||||
tool_call = tool_calls[0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
@@ -386,17 +386,30 @@ class MCPToolService:
|
||||
|
||||
try:
|
||||
# 解析插件名和工具名
|
||||
logger.debug(f"🔍 解析工具名称: {function_name}")
|
||||
if "_" in function_name:
|
||||
plugin_name, tool_name = function_name.split("_", 1)
|
||||
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
|
||||
else:
|
||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||
|
||||
# 解析参数
|
||||
arguments_str = tool_call["function"]["arguments"]
|
||||
logger.debug(f"🔍 解析参数:")
|
||||
logger.debug(f" 原始类型: {type(arguments_str)}")
|
||||
logger.debug(f" 原始内容: {arguments_str}")
|
||||
|
||||
if isinstance(arguments_str, str):
|
||||
arguments = json.loads(arguments_str)
|
||||
try:
|
||||
arguments = json.loads(arguments_str)
|
||||
logger.debug(f" ✅ JSON解析成功: {arguments}")
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f" ❌ JSON解析失败: {je}")
|
||||
logger.error(f" 原始字符串: '{arguments_str}'")
|
||||
raise ValueError(f"参数JSON解析失败: {je}")
|
||||
else:
|
||||
arguments = arguments_str
|
||||
logger.debug(f" 直接使用dict类型参数")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {plugin_name}.{tool_name}, "
|
||||
|
||||
@@ -71,24 +71,15 @@ class PlotAnalyzer:
|
||||
# 调用AI进行分析
|
||||
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
||||
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
||||
response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 🔍 添加调试日志:查看AI返回的原始内容
|
||||
# logger.info(f"🔍 AI返回类型: {type(response)}")
|
||||
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
|
||||
|
||||
# 从返回的字典中提取content字段
|
||||
if isinstance(response, dict):
|
||||
response_text = response.get('content', '')
|
||||
if not response_text:
|
||||
logger.error("❌ AI返回的字典中没有content字段或content为空")
|
||||
return None
|
||||
else:
|
||||
# 兼容旧的字符串返回格式
|
||||
response_text = response
|
||||
# 提取内容
|
||||
response_text = accumulated_text
|
||||
|
||||
# 解析JSON结果
|
||||
analysis_result = self._parse_analysis_response(response_text)
|
||||
|
||||
@@ -133,14 +133,16 @@ class PlotExpansionService:
|
||||
|
||||
# 调用AI生成章节规划
|
||||
logger.info(f"调用AI生成章节规划...")
|
||||
ai_response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
@@ -236,14 +238,16 @@ class PlotExpansionService:
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"调用AI生成第{batch_num + 1}批...")
|
||||
ai_response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
batch_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
|
||||
@@ -6,142 +6,6 @@ import json
|
||||
class WritingStyleManager:
|
||||
"""写作风格管理器"""
|
||||
|
||||
# 预设风格配置
|
||||
PRESET_STYLES = {
|
||||
"natural": {
|
||||
"name": "自然沉浸 (Natural & Immersive)",
|
||||
"description": "祛除翻译腔,强调生活质感,像呼吸一样自然的叙事",
|
||||
"prompt_content": """
|
||||
### 核心指令:自然沉浸风格
|
||||
请模拟人类作家在放松状态下的写作,通过以下规则消除“AI味”:
|
||||
|
||||
1. **拒绝翻译腔与书面化**:
|
||||
- 严禁使用“一种...的感觉”、“随着...”、“与此同时”等连接词。
|
||||
- 多用短句和“流水句”,模拟人类视线的移动和思维的跳跃。
|
||||
- 口语化叙述,但不要滥用语气词,而是通过句子的长短节奏来体现语气。
|
||||
|
||||
2. **生活化的颗粒度**:
|
||||
- 描写不要宏大,要聚焦在具体的、微小的生活细节(如:杯子上的水渍、衣服的褶皱)。
|
||||
- 允许逻辑上的适度“松散”,不要让每句话都像说明书一样严丝合缝。
|
||||
|
||||
3. **具体的“展示”**:
|
||||
- 不要写“他很生气”,要写他“把烟头按灭在还没吃完的米饭里”。
|
||||
- 避免使用抽象的形容词(如:巨大的、美丽的、悲伤的),必须用名词和动词来承载画面。
|
||||
"""
|
||||
},
|
||||
"classical": {
|
||||
"name": "古典雅致 (Classical & Elegant)",
|
||||
"description": "白话文与古典韵味的结合,强调留白与炼字",
|
||||
"prompt_content": """
|
||||
### 核心指令:古典雅致风格
|
||||
请模仿民国时期或古典白话小说的笔触,构建端庄且富有余味的叙事:
|
||||
|
||||
1. **炼字与韵律**:
|
||||
- 尽量使用双音节词或四字短语,但严禁堆砌辞藻。
|
||||
- 注重句子的声调韵律,读起来要有金石之声或流水之韵。
|
||||
- 适当使用倒装句或定语后置,增加古雅感。
|
||||
|
||||
2. **克制的修辞**:
|
||||
- 少用现代的比喻(如“像机器一样”),多用取自自然的比喻(如“如风过林”)。
|
||||
- **意在言外**:不要把话说透,留三分余地。写景即是写情,不要将情感直接剖白。
|
||||
|
||||
3. **禁忌**:
|
||||
- 严禁使用现代科技词汇(除非题材需要)、网络用语或过于西化的句式(如长定语从句)。
|
||||
- 避免滥用“之乎者也”,追求的是“神似”而非生硬的半文半白。
|
||||
"""
|
||||
},
|
||||
"modern": {
|
||||
"name": "冷硬现代 (Modern & Hard-boiled)",
|
||||
"description": "海明威式的冰山理论,节奏极快,零度情感",
|
||||
"prompt_content": """
|
||||
### 核心指令:冷硬现代风格
|
||||
请采用“极简主义”和“零度写作”手法,去除所有矫饰:
|
||||
|
||||
1. **冰山理论**:
|
||||
- **只写动作和对话,完全剔除心理描写和形容词堆砌。**
|
||||
- 不要告诉读者角色感觉如何,通过角色的反应和环境的冷峻反馈来体现。
|
||||
|
||||
2. **电影蒙太奇节奏**:
|
||||
- 句子要短、脆、硬。像手术刀一样切开场景。
|
||||
- 段落之间快速切换,不要用过渡句连接,直接跳切。
|
||||
|
||||
3. **高信息密度**:
|
||||
- 删除所有废话。如果一个词删掉不影响理解,就删掉它。
|
||||
- 多用名词和强动词(Strong Verbs),少用副词(Adverbs)。例如:不要写“他重重地关上门”,写“他摔上了门”。
|
||||
"""
|
||||
},
|
||||
"poetic": {
|
||||
"name": "意识流 (Stream of Consciousness)",
|
||||
"description": "注重感官通感与内心独白,打破现实与幻想的边界",
|
||||
"prompt_content": """
|
||||
### 核心指令:意识流/诗意风格
|
||||
请侧重于主观感受的流动,而非客观事实的记录:
|
||||
|
||||
1. **通感与陌生化**:
|
||||
- 打通五感(如:听到了颜色的声音,闻到了悲伤的气味)。
|
||||
- 使用“陌生化”的语言,把熟悉的事物写得陌生,迫使读者重新审视。
|
||||
|
||||
2. **情绪的具象化**:
|
||||
- **绝对禁止**直接出现“开心”、“痛苦”等抽象词汇。
|
||||
- 必须寻找“客观对应物”(Objective Correlative),将情绪投射到具体的景物上(如:生锈的铁轨、发霉的橘子)。
|
||||
|
||||
3. **流动的句式**:
|
||||
- 句子可以很长,包含多重意象的叠加。
|
||||
- 允许思维的非线性跳跃,模拟梦境或深层潜意识的逻辑。
|
||||
"""
|
||||
},
|
||||
"concise": {
|
||||
"name": "白描速写 (Sketch & Concise)",
|
||||
"description": "只有骨架的叙事,强调绝对的精准和功能性",
|
||||
"prompt_content": """
|
||||
### 核心指令:白描速写风格
|
||||
请像速写画家一样,只勾勒线条,不涂抹色彩:
|
||||
|
||||
1. **功能性第一**:
|
||||
- 每一句话必须推动情节,或者揭示关键信息。
|
||||
- 如果一句话只是为了渲染气氛,删掉它。
|
||||
|
||||
2. **主谓宾结构**:
|
||||
- 尽量使用简单的主谓宾结构,减少修饰语。
|
||||
- 避免复杂的从句和嵌套结构。
|
||||
|
||||
3. **直击核心**:
|
||||
- 对话直接进入主题,去除寒暄和废话。
|
||||
- 环境描写仅限于对情节有物理影响的物体(如:挡路的石头、藏在桌下的枪)。
|
||||
"""
|
||||
},
|
||||
"vivid": {
|
||||
"name": "感官特写 (Sensory & Vivid)",
|
||||
"description": "高分辨率的描写,强调材质、光影和微观细节",
|
||||
"prompt_content": """
|
||||
### 核心指令:感官特写风格
|
||||
请将镜头推到特写级别(Macro Lens),捕捉常人忽略的细节:
|
||||
|
||||
1. **反套路细节**:
|
||||
- 不要写大众化的细节(如:蓝天白云),要写具有**独特性**的细节(如:云层边缘那抹像淤青一样的灰紫色)。
|
||||
- 关注物体的**质感(Texture)**:粗糙的、粘稠的、冰凉的、颗粒感的。
|
||||
|
||||
2. **动态捕捉**:
|
||||
- 不要写静止的画面,要写光影的流变、灰尘的飞舞、肌肉的抽动。
|
||||
- 让读者产生生理性的反应(如:痛感、饥饿感、窒息感)。
|
||||
|
||||
3. **禁用词汇**:
|
||||
- 禁止使用“映入眼帘”、“宛如画卷”等陈词滥调。
|
||||
- 必须用具体的动词带动感官描写。
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]:
|
||||
"""获取预设风格配置"""
|
||||
return cls.PRESET_STYLES.get(preset_id)
|
||||
|
||||
@classmethod
|
||||
def get_all_presets(cls) -> Dict[str, Dict[str, str]]:
|
||||
"""获取所有预设风格"""
|
||||
return cls.PRESET_STYLES
|
||||
|
||||
@staticmethod
|
||||
def apply_style_to_prompt(base_prompt: str, style_content: str) -> str:
|
||||
"""
|
||||
@@ -692,9 +556,8 @@ class PromptService:
|
||||
|
||||
6. **承上启下**:
|
||||
- 开头自然衔接上一章结尾(但不重复上一章内容)
|
||||
- 结尾为下一章做好铺垫
|
||||
|
||||
6. **记忆系统使用指南**:
|
||||
7. **记忆系统使用指南**:
|
||||
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
|
||||
- **语义相关记忆**:参考相似情节的处理方式
|
||||
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
|
||||
@@ -1308,16 +1171,15 @@ class PromptService:
|
||||
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
|
||||
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
|
||||
|
||||
# 灵感模式提示词字典
|
||||
INSPIRATION_PROMPTS = {
|
||||
"title": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
# 灵感模式 - 书名生成(系统提示词)
|
||||
INSPIRATION_TITLE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
|
||||
请根据用户的想法,生成6个吸引人的书名建议,要求:
|
||||
1. 紧扣用户的原始想法和核心故事构思
|
||||
2. 富有创意和吸引力
|
||||
3. 涵盖不同的风格倾向
|
||||
4. 书名中不要带有"《》"符号
|
||||
|
||||
返回JSON格式:
|
||||
{{
|
||||
@@ -1325,11 +1187,13 @@ class PromptService:
|
||||
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
|
||||
}}
|
||||
|
||||
只返回纯JSON,不要有其他文字。""",
|
||||
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||
},
|
||||
"description": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字。"""
|
||||
|
||||
# 灵感模式 - 书名生成(用户提示词)
|
||||
INSPIRATION_TITLE_USER = "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||
|
||||
# 灵感模式 - 简介生成(系统提示词)
|
||||
INSPIRATION_DESCRIPTION_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
已确定的书名:{title}
|
||||
|
||||
@@ -1343,11 +1207,13 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||
},
|
||||
"theme": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||
|
||||
# 灵感模式 - 简介生成(用户提示词)
|
||||
INSPIRATION_DESCRIPTION_USER = "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||
|
||||
# 灵感模式 - 主题生成(系统提示词)
|
||||
INSPIRATION_THEME_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
@@ -1363,11 +1229,13 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||
},
|
||||
"genre": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||
|
||||
# 灵感模式 - 主题生成(用户提示词)
|
||||
INSPIRATION_THEME_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||
|
||||
# 灵感模式 - 类型生成(系统提示词)
|
||||
INSPIRATION_GENRE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
@@ -1384,10 +1252,10 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
|
||||
|
||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||
}
|
||||
}
|
||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。"""
|
||||
|
||||
# 灵感模式 - 类型生成(用户提示词)
|
||||
INSPIRATION_GENRE_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||
|
||||
# 灵感模式智能补全提示词
|
||||
INSPIRATION_QUICK_COMPLETE = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
|
||||
@@ -1887,7 +1755,26 @@ class PromptService:
|
||||
@classmethod
|
||||
def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]:
|
||||
"""获取灵感模式指定步骤的提示词"""
|
||||
return cls.INSPIRATION_PROMPTS.get(step)
|
||||
# 根据步骤名称返回对应的system和user提示词
|
||||
step_map = {
|
||||
"title": {
|
||||
"system": cls.INSPIRATION_TITLE_SYSTEM,
|
||||
"user": cls.INSPIRATION_TITLE_USER
|
||||
},
|
||||
"description": {
|
||||
"system": cls.INSPIRATION_DESCRIPTION_SYSTEM,
|
||||
"user": cls.INSPIRATION_DESCRIPTION_USER
|
||||
},
|
||||
"theme": {
|
||||
"system": cls.INSPIRATION_THEME_SYSTEM,
|
||||
"user": cls.INSPIRATION_THEME_USER
|
||||
},
|
||||
"genre": {
|
||||
"system": cls.INSPIRATION_GENRE_SYSTEM,
|
||||
"user": cls.INSPIRATION_GENRE_USER
|
||||
}
|
||||
}
|
||||
return step_map.get(step)
|
||||
|
||||
@classmethod
|
||||
def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]:
|
||||
@@ -1997,17 +1884,12 @@ class PromptService:
|
||||
# 2. 降级到系统默认模板
|
||||
logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)")
|
||||
|
||||
# 特殊处理灵感模式的提示词(存储在INSPIRATION_PROMPTS字典中)
|
||||
# 特殊处理灵感模式的提示词(直接从类属性获取)
|
||||
if template_key.startswith("INSPIRATION_"):
|
||||
# 提取步骤名称(如 INSPIRATION_TITLE -> title)
|
||||
step = template_key.replace("INSPIRATION_", "").lower()
|
||||
inspiration_prompt = cls.INSPIRATION_PROMPTS.get(step)
|
||||
if inspiration_prompt:
|
||||
# 返回JSON格式的提示词
|
||||
return json.dumps(inspiration_prompt, ensure_ascii=False)
|
||||
# 如果是INSPIRATION_QUICK_COMPLETE
|
||||
if template_key == "INSPIRATION_QUICK_COMPLETE":
|
||||
return cls.INSPIRATION_QUICK_COMPLETE
|
||||
# 直接从类属性获取
|
||||
template_content = getattr(cls, template_key, None)
|
||||
if template_content:
|
||||
return template_content
|
||||
|
||||
# 其他模板直接从类属性获取
|
||||
template_content = getattr(cls, template_key, None)
|
||||
@@ -2182,6 +2064,60 @@ class PromptService:
|
||||
"category": "世界构建",
|
||||
"description": "根据世界观自动生成完整的职业体系,包括主职业和副职业",
|
||||
"parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"]
|
||||
},
|
||||
"INSPIRATION_TITLE_SYSTEM": {
|
||||
"name": "灵感模式-书名生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户的原始想法生成6个书名建议的系统提示词",
|
||||
"parameters": ["initial_idea"]
|
||||
},
|
||||
"INSPIRATION_TITLE_USER": {
|
||||
"name": "灵感模式-书名生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户的原始想法生成6个书名建议的用户提示词",
|
||||
"parameters": ["initial_idea"]
|
||||
},
|
||||
"INSPIRATION_DESCRIPTION_SYSTEM": {
|
||||
"name": "灵感模式-简介生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户想法和书名生成6个简介选项的系统提示词",
|
||||
"parameters": ["initial_idea", "title"]
|
||||
},
|
||||
"INSPIRATION_DESCRIPTION_USER": {
|
||||
"name": "灵感模式-简介生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户想法和书名生成6个简介选项的用户提示词",
|
||||
"parameters": ["initial_idea", "title"]
|
||||
},
|
||||
"INSPIRATION_THEME_SYSTEM": {
|
||||
"name": "灵感模式-主题生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据书名和简介生成6个深刻的主题选项的系统提示词",
|
||||
"parameters": ["initial_idea", "title", "description"]
|
||||
},
|
||||
"INSPIRATION_THEME_USER": {
|
||||
"name": "灵感模式-主题生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据书名和简介生成6个深刻的主题选项的用户提示词",
|
||||
"parameters": ["initial_idea", "title", "description"]
|
||||
},
|
||||
"INSPIRATION_GENRE_SYSTEM": {
|
||||
"name": "灵感模式-类型生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据小说信息生成6个合适的类型标签的系统提示词",
|
||||
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||
},
|
||||
"INSPIRATION_GENRE_USER": {
|
||||
"name": "灵感模式-类型生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据小说信息生成6个合适的类型标签的用户提示词",
|
||||
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||
},
|
||||
"INSPIRATION_QUICK_COMPLETE": {
|
||||
"name": "灵感模式-智能补全",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户提供的部分信息智能补全完整的小说方案",
|
||||
"parameters": ["existing"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,11 +23,22 @@ class SSEResponse:
|
||||
Returns:
|
||||
格式化后的SSE消息字符串
|
||||
"""
|
||||
message = ""
|
||||
if event:
|
||||
message += f"event: {event}\n"
|
||||
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
return message
|
||||
try:
|
||||
message = ""
|
||||
if event:
|
||||
message += f"event: {event}\n"
|
||||
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
return message
|
||||
except Exception as e:
|
||||
logger.error(f"❌ SSE格式化失败: {type(e).__name__}: {e}")
|
||||
logger.error(f" data类型: {type(data)}")
|
||||
logger.error(f" data内容: {str(data)[:500]}")
|
||||
# 返回错误消息而不是崩溃
|
||||
error_message = ""
|
||||
if event:
|
||||
error_message += f"event: {event}\n"
|
||||
error_message += f'data: {{"type": "error", "error": "SSE格式化失败: {str(e)}", "code": 500}}\n\n'
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def send_progress(
|
||||
|
||||
@@ -190,7 +190,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
@@ -236,7 +237,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(33 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
@@ -273,7 +275,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(66 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: () => {
|
||||
@@ -336,15 +339,13 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
// 世界观生成占0%-20%,职业生成占20%-30%
|
||||
const baseProgress = Math.floor(prog / 5);
|
||||
setProgress(baseProgress);
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
|
||||
// 检测职业体系生成阶段 - 必须包含"职业体系"才算职业阶段
|
||||
// 检测职业体系生成阶段
|
||||
if (msg.includes('职业体系')) {
|
||||
if (msg.includes('开始') || msg.includes('生成')) {
|
||||
// 职业开始时,世界观应该已完成
|
||||
setGenerationSteps(prev => ({
|
||||
...prev,
|
||||
worldBuilding: 'completed',
|
||||
@@ -403,8 +404,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
// 角色生成占40%-70%
|
||||
setProgress(40 + Math.floor(prog * 0.3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
@@ -437,8 +438,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
// 大纲生成占70%-100%
|
||||
setProgress(70 + Math.floor(prog * 0.3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: () => {
|
||||
@@ -533,8 +534,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
const baseProgress = Math.floor(prog / 5);
|
||||
setProgress(baseProgress);
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
|
||||
// 检测职业体系生成阶段
|
||||
@@ -604,7 +605,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(33 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
@@ -647,7 +649,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(66 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: () => {
|
||||
@@ -707,7 +710,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(33 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: (result) => {
|
||||
@@ -746,7 +750,8 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
|
||||
},
|
||||
{
|
||||
onProgress: (msg, prog) => {
|
||||
setProgress(66 + Math.floor(prog / 3));
|
||||
// 直接使用后端返回的进度值
|
||||
setProgress(prog);
|
||||
setProgressMessage(msg);
|
||||
},
|
||||
onResult: () => {
|
||||
|
||||
@@ -16,6 +16,8 @@ interface Message {
|
||||
options?: string[];
|
||||
isMultiSelect?: boolean;
|
||||
optionsDisabled?: boolean; // 标记选项是否已禁用
|
||||
canRefine?: boolean; // 是否可以优化(用于支持多轮对话)
|
||||
step?: Step; // 当前步骤(用于反馈)
|
||||
}
|
||||
|
||||
interface WizardData {
|
||||
@@ -69,6 +71,11 @@ const Inspiration: React.FC = () => {
|
||||
const [wizardData, setWizardData] = useState<Partial<WizardData>>({});
|
||||
// 保存用户的原始想法,用于保持上下文一致性
|
||||
const [initialIdea, setInitialIdea] = useState<string>('');
|
||||
|
||||
// 反馈相关状态
|
||||
const [feedbackValue, setFeedbackValue] = useState('');
|
||||
const [showFeedbackInput, setShowFeedbackInput] = useState<number | null>(null); // 当前显示反馈输入的消息索引
|
||||
const [refining, setRefining] = useState(false); // 正在优化选项
|
||||
|
||||
// 生成配置
|
||||
const [generationConfig, setGenerationConfig] = useState<GenerationConfig | null>(null);
|
||||
@@ -248,6 +255,86 @@ const Inspiration: React.FC = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 处理用户反馈,重新生成选项
|
||||
const handleRefineOptions = async (messageIndex: number, feedback: string) => {
|
||||
if (!feedback.trim()) {
|
||||
message.warning('请输入您的反馈意见');
|
||||
return;
|
||||
}
|
||||
|
||||
const targetMessage = messages[messageIndex];
|
||||
if (!targetMessage.options || !targetMessage.step) {
|
||||
return;
|
||||
}
|
||||
|
||||
setRefining(true);
|
||||
setShowFeedbackInput(null);
|
||||
setFeedbackValue('');
|
||||
|
||||
// 先禁用旧的选项
|
||||
setMessages(prev => {
|
||||
const newMessages = [...prev];
|
||||
if (newMessages[messageIndex]) {
|
||||
newMessages[messageIndex] = {
|
||||
...newMessages[messageIndex],
|
||||
optionsDisabled: true,
|
||||
canRefine: false, // 同时禁用反馈功能
|
||||
};
|
||||
}
|
||||
return newMessages;
|
||||
});
|
||||
|
||||
try {
|
||||
// 添加用户反馈消息
|
||||
const feedbackMessage: Message = {
|
||||
type: 'user',
|
||||
content: `💭 ${feedback}`,
|
||||
};
|
||||
setMessages(prev => [...prev, feedbackMessage]);
|
||||
|
||||
const step = targetMessage.step as 'title' | 'description' | 'theme' | 'genre';
|
||||
|
||||
// 构建上下文
|
||||
const context: any = {
|
||||
initial_idea: initialIdea,
|
||||
title: wizardData.title,
|
||||
description: wizardData.description,
|
||||
theme: wizardData.theme,
|
||||
};
|
||||
|
||||
// 调用refine接口
|
||||
const response = await inspirationApi.refineOptions({
|
||||
step,
|
||||
context,
|
||||
feedback,
|
||||
previous_options: targetMessage.options,
|
||||
});
|
||||
|
||||
if (response.error) {
|
||||
message.error(response.error);
|
||||
return;
|
||||
}
|
||||
|
||||
// 添加新的AI消息
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: response.prompt || `根据您的反馈,我重新生成了一些${step === 'title' ? '书名' : step === 'description' ? '简介' : step === 'theme' ? '主题' : '类型'}选项:`,
|
||||
options: response.options || [],
|
||||
isMultiSelect: step === 'genre',
|
||||
canRefine: true,
|
||||
step: step,
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
|
||||
message.success('已根据您的反馈重新生成选项');
|
||||
} catch (error: any) {
|
||||
console.error('优化选项失败:', error);
|
||||
message.error(error.response?.data?.detail || '优化失败,请重试');
|
||||
} finally {
|
||||
setRefining(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 步骤顺序
|
||||
const stepOrder: Step[] = ['idea', 'title', 'description', 'theme', 'genre', 'perspective', 'outline_mode', 'confirm'];
|
||||
|
||||
@@ -297,7 +384,9 @@ const Inspiration: React.FC = () => {
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: response.prompt || '请选择一个书名,或者输入你自己的:',
|
||||
options: response.options
|
||||
options: response.options,
|
||||
canRefine: true,
|
||||
step: 'title'
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('title');
|
||||
@@ -497,6 +586,24 @@ const Inspiration: React.FC = () => {
|
||||
updatedData.genre = [input];
|
||||
} else if (currentStep === 'perspective') {
|
||||
updatedData.narrative_perspective = input;
|
||||
setWizardData(updatedData);
|
||||
|
||||
// 直接进入大纲模式选择
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: `很好!现在请选择你想要的大纲模式:
|
||||
|
||||
📋 一对一模式:传统模式,一个大纲对应一个章节,适合结构清晰、章节独立的小说。
|
||||
|
||||
📚 一对多模式:细化模式,一个大纲可以展开成多个章节,适合需要详细展开情节的小说。
|
||||
|
||||
请选择:`,
|
||||
options: ['📋 一对一模式', '📚 一对多模式']
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('outline_mode');
|
||||
setLoading(false);
|
||||
return;
|
||||
} else if (currentStep === 'outline_mode') {
|
||||
// 大纲模式不支持自定义输入
|
||||
message.warning('请从选项中选择一个大纲模式');
|
||||
@@ -561,7 +668,16 @@ const Inspiration: React.FC = () => {
|
||||
const currentIndex = stepOrder.indexOf(currentStep);
|
||||
const nextStep = stepOrder[currentIndex + 1];
|
||||
|
||||
if (nextStep === 'description') {
|
||||
if (nextStep === 'perspective') {
|
||||
// genre 步骤完成后,进入 perspective
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: '很好!接下来,请选择小说的叙事视角:',
|
||||
options: ['第一人称', '第三人称', '全知视角']
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('perspective');
|
||||
} else if (nextStep === 'description') {
|
||||
const requestData = {
|
||||
step: 'description' as const,
|
||||
context: {
|
||||
@@ -587,7 +703,9 @@ const Inspiration: React.FC = () => {
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: response.prompt || '请选择一个简介,或者输入你自己的:',
|
||||
options: response.options
|
||||
options: response.options,
|
||||
canRefine: true,
|
||||
step: 'description'
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('description');
|
||||
@@ -620,7 +738,9 @@ const Inspiration: React.FC = () => {
|
||||
const aiMessage: Message = {
|
||||
type: 'ai',
|
||||
content: response.prompt || '请选择一个主题,或者输入你自己的:',
|
||||
options: response.options
|
||||
options: response.options,
|
||||
canRefine: true,
|
||||
step: 'theme'
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('theme');
|
||||
@@ -656,7 +776,9 @@ const Inspiration: React.FC = () => {
|
||||
type: 'ai',
|
||||
content: response.prompt || '请选择类型标签(可多选):',
|
||||
options: response.options,
|
||||
isMultiSelect: true
|
||||
isMultiSelect: true,
|
||||
canRefine: true,
|
||||
step: 'genre'
|
||||
};
|
||||
setMessages(prev => [...prev, aiMessage]);
|
||||
setCurrentStep('genre');
|
||||
@@ -767,7 +889,7 @@ const Inspiration: React.FC = () => {
|
||||
background: msg.optionsDisabled
|
||||
? 'var(--color-bg-layout)'
|
||||
: msg.isMultiSelect && selectedOptions.includes(option)
|
||||
? 'var(--color-bg-spotlight)' // Need to ensure this exists or use safe fallback
|
||||
? 'var(--color-bg-spotlight)'
|
||||
: 'var(--color-bg-container)',
|
||||
opacity: msg.optionsDisabled ? 0.6 : 1,
|
||||
animation: 'floatIn 0.6s ease-out',
|
||||
@@ -802,19 +924,72 @@ const Inspiration: React.FC = () => {
|
||||
确认选择 ({selectedOptions.length})
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{/* 反馈优化区域 - 新增 */}
|
||||
{msg.canRefine && !msg.optionsDisabled && !msg.isMultiSelect && (
|
||||
<div style={{ marginTop: 8, paddingTop: 8, borderTop: '1px dashed var(--color-border)' }}>
|
||||
{showFeedbackInput === index ? (
|
||||
<Space direction="vertical" style={{ width: '100%' }} size="small">
|
||||
<TextArea
|
||||
value={feedbackValue}
|
||||
onChange={(e) => setFeedbackValue(e.target.value)}
|
||||
placeholder="例如:我想要更悲剧的主题、能不能更简短一些、偏向古风..."
|
||||
autoSize={{ minRows: 2, maxRows: 3 }}
|
||||
disabled={refining}
|
||||
onPressEnter={(e) => {
|
||||
if (!e.shiftKey && feedbackValue.trim()) {
|
||||
e.preventDefault();
|
||||
handleRefineOptions(index, feedbackValue);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Space style={{ width: '100%', justifyContent: 'flex-end' }}>
|
||||
<Button
|
||||
size="small"
|
||||
onClick={() => {
|
||||
setShowFeedbackInput(null);
|
||||
setFeedbackValue('');
|
||||
}}
|
||||
disabled={refining}
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
type="primary"
|
||||
size="small"
|
||||
onClick={() => handleRefineOptions(index, feedbackValue)}
|
||||
loading={refining}
|
||||
disabled={!feedbackValue.trim()}
|
||||
>
|
||||
重新生成
|
||||
</Button>
|
||||
</Space>
|
||||
</Space>
|
||||
) : (
|
||||
<Button
|
||||
type="link"
|
||||
size="small"
|
||||
onClick={() => setShowFeedbackInput(index)}
|
||||
style={{ padding: 0, height: 'auto' }}
|
||||
>
|
||||
💡 不太满意?告诉我你的想法
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</Space>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{loading && (
|
||||
{(loading || refining) && (
|
||||
<div style={{
|
||||
textAlign: 'center',
|
||||
padding: 20,
|
||||
animation: 'fadeIn 0.3s ease-in'
|
||||
}}>
|
||||
<Spin tip="AI思考中..." />
|
||||
<Spin tip={refining ? "正在根据您的反馈重新生成..." : "AI思考中..."} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -150,10 +150,9 @@ export default function SettingsPage() {
|
||||
};
|
||||
|
||||
const apiProviders = [
|
||||
{ value: 'openai', label: 'OpenAl Compatible', defaultUrl: 'https://api.openai.com/v1' },
|
||||
// { value: 'azure', label: 'Azure OpenAI', defaultUrl: 'https://YOUR-RESOURCE.openai.azure.com' },
|
||||
// { value: 'anthropic', label: 'Anthropic', defaultUrl: 'https://api.anthropic.com' },
|
||||
// { value: 'custom', label: '自定义', defaultUrl: '' },
|
||||
{ value: 'openai', label: 'OpenAI Compatible', defaultUrl: 'https://api.openai.com/v1' },
|
||||
// { value: 'anthropic', label: 'Anthropic (Claude)', defaultUrl: 'https://api.anthropic.com' },
|
||||
{ value: 'gemini', label: 'Google Gemini', defaultUrl: 'https://generativelanguage.googleapis.com/v1beta' },
|
||||
];
|
||||
|
||||
const handleProviderChange = (value: string) => {
|
||||
@@ -483,8 +482,8 @@ export default function SettingsPage() {
|
||||
switch (provider) {
|
||||
case 'openai':
|
||||
return 'blue';
|
||||
case 'anthropic':
|
||||
return 'purple';
|
||||
// case 'anthropic':
|
||||
// return 'purple';
|
||||
case 'gemini':
|
||||
return 'green';
|
||||
default:
|
||||
@@ -973,6 +972,26 @@ export default function SettingsPage() {
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label={
|
||||
<Space size={4}>
|
||||
<span>系统提示词</span>
|
||||
<Tooltip title="设置全局系统提示词,每次AI调用时都会自动使用。可用于设定AI的角色、语言风格等">
|
||||
<InfoCircleOutlined style={{ color: 'var(--color-text-secondary)', fontSize: isMobile ? '12px' : '14px' }} />
|
||||
</Tooltip>
|
||||
</Space>
|
||||
}
|
||||
name="system_prompt"
|
||||
>
|
||||
<TextArea
|
||||
rows={4}
|
||||
placeholder="例如:你是一个专业的小说创作助手,请用生动、细腻的文字进行创作..."
|
||||
maxLength={10000}
|
||||
showCount
|
||||
style={{ fontSize: isMobile ? '13px' : '14px' }}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
{/* 测试结果展示 */}
|
||||
{showTestResult && testResult && (
|
||||
<Alert
|
||||
@@ -1247,7 +1266,7 @@ export default function SettingsPage() {
|
||||
>
|
||||
<Select>
|
||||
<Select.Option value="openai">OpenAI</Select.Option>
|
||||
<Select.Option value="anthropic">Anthropic (Claude)</Select.Option>
|
||||
{/* <Select.Option value="anthropic">Anthropic (Claude)</Select.Option> */}
|
||||
<Select.Option value="gemini">Google Gemini</Select.Option>
|
||||
</Select>
|
||||
</Form.Item>
|
||||
@@ -1298,6 +1317,18 @@ export default function SettingsPage() {
|
||||
placeholder="2000"
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="system_prompt"
|
||||
label="系统提示词"
|
||||
>
|
||||
<TextArea
|
||||
rows={3}
|
||||
placeholder="例如:你是一个专业的小说创作助手...(可选)"
|
||||
maxLength={10000}
|
||||
showCount
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
|
||||
@@ -557,6 +557,24 @@ export const inspirationApi = {
|
||||
error?: string;
|
||||
}>('/inspiration/generate-options', data),
|
||||
|
||||
// 基于用户反馈重新生成选项(新增)
|
||||
refineOptions: (data: {
|
||||
step: 'title' | 'description' | 'theme' | 'genre';
|
||||
context: {
|
||||
initial_idea?: string;
|
||||
title?: string;
|
||||
description?: string;
|
||||
theme?: string;
|
||||
};
|
||||
feedback: string;
|
||||
previous_options?: string[];
|
||||
}) =>
|
||||
api.post<unknown, {
|
||||
prompt?: string;
|
||||
options: string[];
|
||||
error?: string;
|
||||
}>('/inspiration/refine-options', data),
|
||||
|
||||
// 智能补全缺失信息
|
||||
quickGenerate: (data: {
|
||||
title?: string;
|
||||
|
||||
@@ -21,6 +21,7 @@ export interface Settings {
|
||||
llm_model: string;
|
||||
temperature: number;
|
||||
max_tokens: number;
|
||||
system_prompt?: string;
|
||||
preferences?: string;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
@@ -33,6 +34,7 @@ export interface SettingsUpdate {
|
||||
llm_model?: string;
|
||||
temperature?: number;
|
||||
max_tokens?: number;
|
||||
system_prompt?: string;
|
||||
preferences?: string;
|
||||
}
|
||||
|
||||
@@ -44,6 +46,7 @@ export interface APIKeyPresetConfig {
|
||||
llm_model: string;
|
||||
temperature: number;
|
||||
max_tokens: number;
|
||||
system_prompt?: string;
|
||||
}
|
||||
|
||||
export interface APIKeyPreset {
|
||||
|
||||
Reference in New Issue
Block a user