From 89848e2258f96f39d5225c39f78caced7577eacc Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Sun, 28 Dec 2025 19:35:23 +0800 Subject: [PATCH] =?UTF-8?q?update:1.=E4=BC=98=E5=8C=96=20AI=20=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E7=94=9F=E6=88=90=E5=92=8C=E8=BF=9B=E5=BA=A6=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E7=B3=BB=E7=BB=9F=202.=E6=96=B0=E5=A2=9E=E5=86=99?= =?UTF-8?q?=E4=BD=9C=E9=A3=8E=E6=A0=BC=E7=B3=BB=E7=BB=9F=E6=8F=90=E7=A4=BA?= =?UTF-8?q?=E8=AF=8D=E6=94=AF=E6=8C=81=203.=E7=81=B5=E6=84=9F=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=BC=BA,=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E7=81=B5=E6=84=9F=E9=87=8D=E5=86=99=204.=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E9=A1=B5=E9=9D=A2=E5=8A=9F=E8=83=BD=E6=89=A9=E5=B1=95?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9EGemini=E9=80=82=E9=85=8D=E5=99=A8=20?= =?UTF-8?q?5.=E6=8F=90=E7=A4=BA=E8=AF=8D=E6=A8=A1=E6=9D=BF=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E4=BC=98=E5=8C=96=EF=BC=8C=E8=B0=83=E6=95=B4=E7=81=B5?= =?UTF-8?q?=E6=84=9F=E6=A8=A1=E5=BC=8F=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...e1d5b_添加system_prompt字段到settings表.py | 30 + ..._0856_a1b2c3d4e5f6_初始化sqlite预置数据.py | 177 ++ ...4d839_添加system_prompt字段到settings表.py | 34 + backend/app/api/careers.py | 38 +- backend/app/api/chapters.py | 59 +- backend/app/api/characters.py | 140 +- backend/app/api/inspiration.py | 232 ++- backend/app/api/organizations.py | 33 +- backend/app/api/outlines.py | 119 +- backend/app/api/settings.py | 36 +- backend/app/api/wizard_stream.py | 488 ++++-- backend/app/api/writing_styles.py | 52 +- backend/app/mcp/http_client.py | 109 +- backend/app/models/settings.py | 1 + backend/app/schemas/settings.py | 1 + backend/app/services/ai_clients/__init__.py | 6 + .../services/ai_clients/anthropic_client.py | 86 + .../app/services/ai_clients/base_client.py | 154 ++ .../app/services/ai_clients/gemini_client.py | 141 ++ .../app/services/ai_clients/openai_client.py | 101 ++ backend/app/services/ai_config.py | 
44 + backend/app/services/ai_providers/__init__.py | 6 + .../ai_providers/anthropic_provider.py | 51 + .../services/ai_providers/base_provider.py | 33 + .../services/ai_providers/gemini_provider.py | 48 + .../services/ai_providers/openai_provider.py | 57 + backend/app/services/ai_service.py | 1472 ++--------------- .../app/services/auto_character_service.py | 4 +- backend/app/services/json_helper.py | 149 ++ backend/app/services/mcp_test_service.py | 24 +- backend/app/services/mcp_tool_service.py | 15 +- backend/app/services/plot_analyzer.py | 21 +- .../app/services/plot_expansion_service.py | 16 +- backend/app/services/prompt_service.py | 280 ++-- backend/app/utils/sse_response.py | 21 +- .../src/components/AIProjectGenerator.tsx | 41 +- frontend/src/pages/Inspiration.tsx | 191 ++- frontend/src/pages/Settings.tsx | 45 +- frontend/src/services/api.ts | 18 + frontend/src/types/index.ts | 3 + 40 files changed, 2752 insertions(+), 1824 deletions(-) create mode 100644 backend/alembic/postgres/versions/20251227_1541_a7e4408e1d5b_添加system_prompt字段到settings表.py create mode 100644 backend/alembic/sqlite/versions/20251227_0856_a1b2c3d4e5f6_初始化sqlite预置数据.py create mode 100644 backend/alembic/sqlite/versions/20251227_1700_7899f8d4d839_添加system_prompt字段到settings表.py create mode 100644 backend/app/services/ai_clients/__init__.py create mode 100644 backend/app/services/ai_clients/anthropic_client.py create mode 100644 backend/app/services/ai_clients/base_client.py create mode 100644 backend/app/services/ai_clients/gemini_client.py create mode 100644 backend/app/services/ai_clients/openai_client.py create mode 100644 backend/app/services/ai_config.py create mode 100644 backend/app/services/ai_providers/__init__.py create mode 100644 backend/app/services/ai_providers/anthropic_provider.py create mode 100644 backend/app/services/ai_providers/base_provider.py create mode 100644 backend/app/services/ai_providers/gemini_provider.py create mode 100644 
backend/app/services/ai_providers/openai_provider.py create mode 100644 backend/app/services/json_helper.py diff --git a/backend/alembic/postgres/versions/20251227_1541_a7e4408e1d5b_添加system_prompt字段到settings表.py b/backend/alembic/postgres/versions/20251227_1541_a7e4408e1d5b_添加system_prompt字段到settings表.py new file mode 100644 index 0000000..ab133fd --- /dev/null +++ b/backend/alembic/postgres/versions/20251227_1541_a7e4408e1d5b_添加system_prompt字段到settings表.py @@ -0,0 +1,30 @@ +"""添加system_prompt字段到settings表 + +Revision ID: a7e4408e1d5b +Revises: e411428f00c0 +Create Date: 2025-12-27 15:41:22.310160 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a7e4408e1d5b' +down_revision: Union[str, None] = 'e411428f00c0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('settings', sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('settings', 'system_prompt') + # ### end Alembic commands ### \ No newline at end of file diff --git a/backend/alembic/sqlite/versions/20251227_0856_a1b2c3d4e5f6_初始化sqlite预置数据.py b/backend/alembic/sqlite/versions/20251227_0856_a1b2c3d4e5f6_初始化sqlite预置数据.py new file mode 100644 index 0000000..87a4a5e --- /dev/null +++ b/backend/alembic/sqlite/versions/20251227_0856_a1b2c3d4e5f6_初始化sqlite预置数据.py @@ -0,0 +1,177 @@ +"""初始化SQLite预置数据 + +Revision ID: a1b2c3d4e5f6 +Revises: fbeb1038c728 +Create Date: 2025-12-27 08:56:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import table, column, String, Integer, Text + + +# revision identifiers, used by Alembic. +revision: str = 'a1b2c3d4e5f6' +down_revision: Union[str, None] = 'fbeb1038c728' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """插入预置数据""" + + # ==================== 1. 
插入关系类型数据 ==================== + relationship_types_table = table( + 'relationship_types', + column('name', String), + column('category', String), + column('reverse_name', String), + column('intimacy_range', String), + column('icon', String), + column('description', Text), + ) + + relationship_types_data = [ + # 家庭关系 + {"name": "父亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👨", "description": "父子/父女关系"}, + {"name": "母亲", "category": "family", "reverse_name": "子女", "intimacy_range": "high", "icon": "👩", "description": "母子/母女关系"}, + {"name": "兄弟", "category": "family", "reverse_name": "兄弟", "intimacy_range": "high", "icon": "👬", "description": "兄弟关系"}, + {"name": "姐妹", "category": "family", "reverse_name": "姐妹", "intimacy_range": "high", "icon": "👭", "description": "姐妹关系"}, + {"name": "子女", "category": "family", "reverse_name": "父母", "intimacy_range": "high", "icon": "👶", "description": "子女关系"}, + {"name": "配偶", "category": "family", "reverse_name": "配偶", "intimacy_range": "high", "icon": "💑", "description": "夫妻关系"}, + {"name": "恋人", "category": "family", "reverse_name": "恋人", "intimacy_range": "high", "icon": "💕", "description": "恋爱关系"}, + + # 社交关系 + {"name": "师父", "category": "social", "reverse_name": "徒弟", "intimacy_range": "high", "icon": "🎓", "description": "师徒关系(师父视角)"}, + {"name": "徒弟", "category": "social", "reverse_name": "师父", "intimacy_range": "high", "icon": "📚", "description": "师徒关系(徒弟视角)"}, + {"name": "朋友", "category": "social", "reverse_name": "朋友", "intimacy_range": "medium", "icon": "🤝", "description": "朋友关系"}, + {"name": "同学", "category": "social", "reverse_name": "同学", "intimacy_range": "medium", "icon": "🎒", "description": "同学关系"}, + {"name": "邻居", "category": "social", "reverse_name": "邻居", "intimacy_range": "low", "icon": "🏘️", "description": "邻居关系"}, + {"name": "知己", "category": "social", "reverse_name": "知己", "intimacy_range": "high", "icon": "💙", "description": "知心好友"}, + + # 职业关系 + {"name": "上司", 
"category": "professional", "reverse_name": "下属", "intimacy_range": "low", "icon": "👔", "description": "上下级关系(上司视角)"}, + {"name": "下属", "category": "professional", "reverse_name": "上司", "intimacy_range": "low", "icon": "💼", "description": "上下级关系(下属视角)"}, + {"name": "同事", "category": "professional", "reverse_name": "同事", "intimacy_range": "medium", "icon": "🤵", "description": "同事关系"}, + {"name": "合作伙伴", "category": "professional", "reverse_name": "合作伙伴", "intimacy_range": "medium", "icon": "🤜🤛", "description": "合作关系"}, + + # 敌对关系 + {"name": "敌人", "category": "hostile", "reverse_name": "敌人", "intimacy_range": "low", "icon": "⚔️", "description": "敌对关系"}, + {"name": "仇人", "category": "hostile", "reverse_name": "仇人", "intimacy_range": "low", "icon": "💢", "description": "仇恨关系"}, + {"name": "竞争对手", "category": "hostile", "reverse_name": "竞争对手", "intimacy_range": "low", "icon": "🎯", "description": "竞争关系"}, + {"name": "宿敌", "category": "hostile", "reverse_name": "宿敌", "intimacy_range": "low", "icon": "⚡", "description": "宿命之敌"}, + ] + + op.bulk_insert(relationship_types_table, relationship_types_data) + print(f"✅ SQLite: 已插入 {len(relationship_types_data)} 条关系类型数据") + + + # ==================== 2. 插入全局写作风格预设 ==================== + writing_styles_table = table( + 'writing_styles', + column('user_id', String), + column('name', String), + column('style_type', String), + column('preset_id', String), + column('description', Text), + column('prompt_content', Text), + column('order_index', Integer), + ) + + writing_styles_data = [ + { + "user_id": None, # NULL 表示全局预设 + "name": "自然流畅", + "style_type": "preset", + "preset_id": "natural", + "description": "自然流畅的叙事风格,适合现代都市、现实题材", + "prompt_content": """写作风格要求: +1. 语言简洁明快,贴近现代口语 +2. 多用短句,节奏流畅 +3. 注重情感细节的自然流露 +4. 避免过度修饰和复杂句式""", + "order_index": 1 + }, + { + "user_id": None, + "name": "古典优雅", + "style_type": "preset", + "preset_id": "classical", + "description": "古典文雅的写作风格,适合古装、仙侠题材", + "prompt_content": """写作风格要求: +1. 
使用文言、半文言或典雅的白话 +2. 适当运用古典诗词意象 +3. 注重意境营造和韵味 +4. 对话和描写保持古典美感""", + "order_index": 2 + }, + { + "user_id": None, + "name": "现代简约", + "style_type": "preset", + "preset_id": "modern", + "description": "现代简约风格,适合轻小说、网文快节奏叙事", + "prompt_content": """写作风格要求: +1. 语言直白简练,信息密度高 +2. 多用对话推进情节 +3. 避免冗长描写,突出关键动作 +4. 节奏明快,适合快速阅读""", + "order_index": 3 + }, + { + "user_id": None, + "name": "文艺细腻", + "style_type": "preset", + "preset_id": "literary", + "description": "文艺细腻风格,注重心理描写和氛围营造", + "prompt_content": """写作风格要求: +1. 注重心理活动和情感细节 +2. 善用环境描写烘托氛围 +3. 语言优美,富有文学性 +4. 适当使用比喻、象征等修辞手法""", + "order_index": 4 + }, + { + "user_id": None, + "name": "紧张悬疑", + "style_type": "preset", + "preset_id": "suspense", + "description": "紧张悬疑风格,适合推理、惊悚题材", + "prompt_content": """写作风格要求: +1. 营造紧张压迫的氛围 +2. 多用短句加快节奏 +3. 善于设置悬念和伏笔 +4. 注重细节描写,为推理埋下线索""", + "order_index": 5 + }, + { + "user_id": None, + "name": "幽默诙谐", + "style_type": "preset", + "preset_id": "humorous", + "description": "幽默诙谐风格,适合轻松搞笑题材", + "prompt_content": """写作风格要求: +1. 语言活泼风趣,善用俏皮话 +2. 注重对话的喜剧效果 +3. 适当夸张和反转制造笑点 +4. 
保持轻松愉快的基调""", + "order_index": 6 + }, + ] + + op.bulk_insert(writing_styles_table, writing_styles_data) + print(f"✅ SQLite: 已插入 {len(writing_styles_data)} 条全局写作风格预设") + + +def downgrade() -> None: + """删除预置数据""" + + # 删除写作风格预设(只删除全局预设) + op.execute("DELETE FROM writing_styles WHERE user_id IS NULL") + print("✅ SQLite: 已删除全局写作风格预设") + + # 删除关系类型 + op.execute("DELETE FROM relationship_types") + print("✅ SQLite: 已删除关系类型数据") \ No newline at end of file diff --git a/backend/alembic/sqlite/versions/20251227_1700_7899f8d4d839_添加system_prompt字段到settings表.py b/backend/alembic/sqlite/versions/20251227_1700_7899f8d4d839_添加system_prompt字段到settings表.py new file mode 100644 index 0000000..31af021 --- /dev/null +++ b/backend/alembic/sqlite/versions/20251227_1700_7899f8d4d839_添加system_prompt字段到settings表.py @@ -0,0 +1,34 @@ +"""添加system_prompt字段到settings表 + +Revision ID: 7899f8d4d839 +Revises: a1b2c3d4e5f6 +Create Date: 2025-12-27 17:00:35.440551 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '7899f8d4d839' +down_revision: Union[str, None] = 'a1b2c3d4e5f6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('settings', schema=None) as batch_op: + batch_op.add_column(sa.Column('system_prompt', sa.Text(), nullable=True, comment='系统级别提示词,每次AI调用都会使用')) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('settings', schema=None) as batch_op: + batch_op.drop_column('system_prompt') + + # ### end Alembic commands ### \ No newline at end of file diff --git a/backend/app/api/careers.py b/backend/app/api/careers.py index 1b0d6e7..a11d4af 100644 --- a/backend/app/api/careers.py +++ b/backend/app/api/careers.py @@ -309,13 +309,37 @@ async def generate_career_system( 7. 只返回纯JSON,不要添加任何解释文字 """ - yield await SSEResponse.send_progress("调用AI生成新职业...", 30) + yield await SSEResponse.send_progress("调用AI生成新职业...", 10) logger.info(f"🎯 开始为项目 {project_id} 生成新职业(增量式,已有{len(existing_careers)}个职业)") try: - # 调用AI生成 - result = await user_ai_service.generate_text(prompt=prompt) - ai_response = result.get('content', '') if isinstance(result, dict) else result + # 使用流式生成替代非流式 + ai_response = "" + chunk_count = 0 + last_progress = 10 + + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 平滑更新进度(10-90%,AI生成占60%) + # 每10个chunk增加约1%的进度,最多到90% + if chunk_count % 10 == 0: + # 计算进度:10% + (chunk_count / 10) * 1%,但不超过90% + current_progress = min(10 + (chunk_count // 10), 90) + if current_progress > last_progress: + last_progress = current_progress + yield await SSEResponse.send_progress( + f"AI生成职业体系中... 
(已生成 {len(ai_response)} 字符)", + current_progress + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() except Exception as ai_error: logger.error(f"❌ AI服务调用异常:{str(ai_error)}") @@ -326,7 +350,7 @@ async def generate_career_system( yield await SSEResponse.send_error("AI服务返回空响应") return - yield await SSEResponse.send_progress("解析AI响应...", 50) + yield await SSEResponse.send_progress("解析AI响应...", 91) # 清洗并解析JSON try: @@ -339,7 +363,7 @@ async def generate_career_system( yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}") return - yield await SSEResponse.send_progress("保存主职业...", 60) + yield await SSEResponse.send_progress("保存主职业到数据库...", 93) # 保存主职业 main_careers_created = [] @@ -371,7 +395,7 @@ async def generate_career_system( logger.error(f" ❌ 创建主职业失败:{str(e)}") continue - yield await SSEResponse.send_progress("保存副职业...", 80) + yield await SSEResponse.send_progress("保存副职业到数据库...", 96) # 保存副职业 sub_careers_created = [] diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index 123705d..6c040c3 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -1070,8 +1070,7 @@ async def analyze_chapter_background( if career_update_result['updated_count'] > 0: logger.info( - f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息: " - f"{', '.join(career_update_result['updated_characters'])}" + f"✅ 更新了 {career_update_result['updated_count']} 个角色的职业信息" ) if career_update_result['changes']: for change in career_update_result['changes']: @@ -1445,7 +1444,7 @@ async def generate_chapter_content_stream( user_id=current_user_id, db_session=db_session, enable_mcp=True, - max_tool_rounds=1, # ✅ 减少为1轮,避免超时 + max_tool_rounds=2, # ✅ 减少为1轮,避免超时 tool_choice="auto", provider=None, model=None @@ -1596,10 +1595,24 @@ async def generate_chapter_content_stream( logger.info(f"开始AI流式创作章节 {chapter_id}") # 发送开始生成的进度 - yield f"data: {json.dumps({'type': 'progress', 'progress': 35, 'message': '开始AI创作...', 'status': 
'processing'}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n" + + # 🎨 方案一:将写作风格注入到系统提示词(最高优先级) + system_prompt_with_style = None + if style_content: + system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】 + +{style_content} + +⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令! +确保在整个章节创作过程中始终保持风格的一致性。""" + logger.info(f"✅ 已将写作风格注入系统提示词({len(style_content)}字符)") # 准备生成参数 - generate_kwargs = {"prompt": prompt} + generate_kwargs = { + "prompt": prompt, + "system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格 + } if custom_model: logger.info(f" 使用自定义模型: {custom_model}") generate_kwargs["model"] = custom_model @@ -1618,11 +1631,14 @@ async def generate_chapter_content_stream( # 发送内容块 yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n" - # 每20个chunk发送一次进度更新(提高频率) - if chunk_count % 20 == 0: + # 每5个chunk发送一次进度更新(10-95%,更平滑) + if chunk_count % 5 == 0: current_word_count = len(full_content) - # 根据目标字数估算进度(40%起步,最高95%,为后续保存留5%) - estimated_progress = min(95, 40 + int((current_word_count / target_word_count) * 55)) + # 优化进度计算:使用更平滑的递增方式 + # 基于chunk数量和字数的混合计算,避免大幅跳跃 + chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40% + word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45% + estimated_progress = min(95, 10 + chunk_progress + word_progress) # 只在进度变化时发送 if estimated_progress > last_progress: @@ -1636,10 +1652,14 @@ async def generate_chapter_content_stream( yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n" last_progress = estimated_progress + # 每20个chunk发送心跳 + if chunk_count % 20 == 0: + yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n" + await asyncio.sleep(0) # 让出控制权 # 发送保存进度 - yield f"data: {json.dumps({'type': 'progress', 'progress': 98, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 
'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n" # 更新章节内容到数据库 old_word_count = current_chapter.word_count or 0 @@ -1696,7 +1716,7 @@ async def generate_chapter_content_stream( ) # 发送最终进度100% - yield f"data: {json.dumps({'type': 'progress', 'progress': 100, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n" # 发送完成事件(包含分析任务ID) completion_data = { @@ -2880,15 +2900,30 @@ async def generate_single_chapter_for_batch( else: prompt = base_prompt + # 🎨 方案一:将写作风格注入到系统提示词(批量生成) + system_prompt_with_style = None + if style_content: + system_prompt_with_style = f"""【🎨 写作风格要求 - 最高优先级】 + +{style_content} + +⚠️ 请严格遵循上述写作风格要求进行创作,这是最重要的指令! +确保在整个章节创作过程中始终保持风格的一致性。""" + logger.info(f"✅ 批量生成 - 已将写作风格注入系统提示词({len(style_content)}字符)") + # 非流式生成内容 full_content = "" # 准备生成参数 - generate_kwargs = {"prompt": prompt} + generate_kwargs = { + "prompt": prompt, + "system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格 + } # 如果传入了自定义模型,使用指定的模型 if custom_model: generate_kwargs["model"] = custom_model logger.info(f" 批量生成使用自定义模型: {custom_model}") + # 批量生成中的流式生成(非SSE,不需要修改进度显示) async for chunk in ai_service.generate_text_stream(**generate_kwargs): full_content += chunk diff --git a/backend/app/api/characters.py b/backend/app/api/characters.py index 598efc4..1357f85 100644 --- a/backend/app/api/characters.py +++ b/backend/app/api/characters.py @@ -662,10 +662,10 @@ async def generate_character_stream( user_id = getattr(http_request.state, 'user_id', None) project = await verify_project_access(request.project_id, user_id, db) - yield await SSEResponse.send_progress("开始生成角色...", 0) + yield await SSEResponse.send_progress("开始生成角色...", 1) # 获取已存在的角色列表 - yield await SSEResponse.send_progress("获取项目上下文...", 10) + yield await 
SSEResponse.send_progress("获取项目上下文...", 2) existing_chars_result = await db.execute( select(Character) @@ -757,7 +757,7 @@ async def generate_character_stream( - 其他要求:{request.requirements or '无'} """ - yield await SSEResponse.send_progress("构建AI提示词...", 20) + yield await SSEResponse.send_progress("构建AI提示词...", 3) # 获取自定义提示词模板 template = await PromptService.get_template("SINGLE_CHARACTER_GENERATION", user_id, db) @@ -768,11 +768,14 @@ async def generate_character_stream( user_input=user_input ) - yield await SSEResponse.send_progress("调用AI服务生成角色...", 30) + yield await SSEResponse.send_progress("调用AI服务生成角色...", 10) logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)") try: # 🔧 MCP工具增强:静默检查并收集参考资料 + ai_response = "" + chunk_count = 0 + if user_id: try: from app.services.mcp_tool_service import mcp_tool_service @@ -789,7 +792,7 @@ async def generate_character_stream( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, # 减少为1轮,避免超时 + max_tool_rounds=2, tool_choice="auto", provider=None, model=None @@ -797,22 +800,119 @@ async def generate_character_stream( if isinstance(result, dict): ai_response = result.get('content', '') - if result.get('tool_calls_made', 0) > 0: - logger.info(f"✅ MCP工具调用成功({result['tool_calls_made']}次)") + finish_reason = result.get('finish_reason', '') + tool_calls_made = result.get('tool_calls_made', 0) + + # 🔧 修复:检查工具调用是否真正成功 + if tool_calls_made > 0: + if finish_reason == 'tool_error': + logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式") + # 工具调用失败,重新用基础模式生成 + ai_response = "" + elif not ai_response.strip(): + logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式") + # 工具调用成功但返回空内容,重新生成 + ai_response = "" + else: + logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}") + # MCP成功且有内容,模拟流式输出(分块发送) + chunk_size = 50 + for i in range(0, len(ai_response), chunk_size): + chunk = ai_response[i:i+chunk_size] + chunk_count += 1 + yield await SSEResponse.send_chunk(chunk) + + if chunk_count % 3 == 0: + yield await 
SSEResponse.send_progress( + f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)", + 10 + min(85 * (i+len(chunk)) // len(ai_response), 85) + ) + + # 跳过后续的流式生成 + ai_response = result.get('content', '') else: ai_response = result + + # 如果MCP调用失败或返回空,继续走流式生成 + if not ai_response or not ai_response.strip(): + logger.info(f"🔄 开始流式生成...") + ai_response = "" + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度 + if chunk_count % 5 == 0: + yield await SSEResponse.send_progress( + f"AI生成角色中... ({len(ai_response)}字符)", + 10 + min(chunk_count // 2, 85) + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() else: - logger.debug(f"用户 {user_id} 未启用MCP工具,使用基础模式") - result = await user_ai_service.generate_text(prompt=prompt) - ai_response = result.get('content', '') if isinstance(result, dict) else result + logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式") + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度 + if chunk_count % 5 == 0: + yield await SSEResponse.send_progress( + f"AI生成角色中... 
({len(ai_response)}字符)", + 10 + min(chunk_count // 2, 85) + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() except Exception as mcp_error: - logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(mcp_error)}") - result = await user_ai_service.generate_text(prompt=prompt) - ai_response = result.get('content', '') if isinstance(result, dict) else result + logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}") + ai_response = "" + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度 + if chunk_count % 5 == 0: + yield await SSEResponse.send_progress( + f"AI生成角色中... ({len(ai_response)}字符)", + 10 + min(chunk_count // 2, 85) + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() else: - result = await user_ai_service.generate_text(prompt=prompt) - ai_response = result.get('content', '') if isinstance(result, dict) else result + logger.debug(f"未登录用户,使用流式基础模式") + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度 + if chunk_count % 5 == 0: + yield await SSEResponse.send_progress( + f"AI生成角色中... 
({len(ai_response)}字符)", + 10 + min(chunk_count // 2, 85) + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() except Exception as ai_error: logger.error(f"❌ AI服务调用异常:{str(ai_error)}") @@ -823,7 +923,7 @@ async def generate_character_stream( yield await SSEResponse.send_error("AI服务返回空响应") return - yield await SSEResponse.send_progress("解析AI响应...", 60) + yield await SSEResponse.send_progress("解析AI响应...", 96) # ✅ 使用统一的 JSON 清洗方法 try: @@ -836,7 +936,7 @@ async def generate_character_stream( yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}") return - yield await SSEResponse.send_progress("创建角色记录...", 75) + yield await SSEResponse.send_progress("创建角色记录...", 97) # 转换traits traits_json = json.dumps(character_data.get("traits", []), ensure_ascii=False) if character_data.get("traits") else None @@ -1001,7 +1101,7 @@ async def generate_character_stream( # 如果是组织,创建Organization详情 if is_organization: - yield await SSEResponse.send_progress("创建组织详情...", 85) + yield await SSEResponse.send_progress("创建组织详情...", 98) org_check = await db.execute( select(Organization).where(Organization.character_id == character.id) @@ -1168,13 +1268,13 @@ async def generate_character_stream( logger.info(f"✅ 成功创建 {created_members} 条组织成员记录") - yield await SSEResponse.send_progress("保存生成历史...", 95) + yield await SSEResponse.send_progress("保存生成历史...", 99) # 记录生成历史 history = GenerationHistory( project_id=request.project_id, prompt=prompt, - generated_content=json.dumps(result, ensure_ascii=False) if isinstance(result, dict) else ai_response, + generated_content=ai_response, model=user_ai_service.default_model ) db.add(history) diff --git a/backend/app/api/inspiration.py b/backend/app/api/inspiration.py index 67ec412..456d674 100644 --- a/backend/app/api/inspiration.py +++ b/backend/app/api/inspiration.py @@ -105,23 +105,27 @@ async def generate_options( user_id = getattr(http_request.state, 'user_id', None) # 获取对应的提示词模板(根据step确定模板key) + # 新结构:每个步骤有独立的 SYSTEM 和 
USER 模板 template_key_map = { - "title": "INSPIRATION_TITLE", - "description": "INSPIRATION_DESCRIPTION", - "theme": "INSPIRATION_THEME", - "genre": "INSPIRATION_GENRE" + "title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"), + "description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"), + "theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"), + "genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER") } - template_key = template_key_map.get(step) + template_keys = template_key_map.get(step) - if not template_key: + if not template_keys: return { "error": f"不支持的步骤: {step}", "prompt": "", "options": [] } - # 获取自定义提示词模板 - prompt_template_str = await PromptService.get_template(template_key, user_id, db) + system_key, user_key = template_keys + + # 获取自定义提示词模板(分别获取 system 和 user) + system_template = await PromptService.get_template(system_key, user_id, db) + user_template = await PromptService.get_template(user_key, user_id, db) # 准备格式化参数 format_params = { @@ -131,19 +135,9 @@ async def generate_options( "theme": context.get("theme", "") } - # 格式化提示词(灵感模式的模板是特殊格式,包含system和user两部分) - # 尝试解析为JSON格式的字典 - try: - prompt_template = json.loads(prompt_template_str) - system_prompt = prompt_template["system"].format(**format_params) - user_prompt = prompt_template["user"].format(**format_params) - except (json.JSONDecodeError, KeyError): - # 如果不是JSON格式,降级使用原有方法 - prompt_template = prompt_service.get_inspiration_prompt(step) - if not prompt_template: - return {"error": f"无法获取提示词模板: {step}", "prompt": "", "options": []} - system_prompt = prompt_template["system"].format(**format_params) - user_prompt = prompt_template["user"].format(**format_params) + # 格式化提示词 + system_prompt = system_template.format(**format_params) + user_prompt = user_template.format(**format_params) # 如果是重试,在提示词中强调格式要求 if attempt > 0: @@ -153,13 +147,18 @@ async def generate_options( # 关键改进:使用递减的temperature以保持后续阶段与前文的一致性 temperature = 
TEMPERATURE_SETTINGS.get(step, 0.7) logger.info(f"调用AI生成{step}选项... (temperature={temperature})") - response = await ai_service.generate_text( + + # 流式生成并累积文本 + accumulated_text = "" + async for chunk in ai_service.generate_text_stream( prompt=user_prompt, system_prompt=system_prompt, temperature=temperature - ) + ): + accumulated_text += chunk - content = response.get("content", "") + response = {"content": accumulated_text} + content = accumulated_text logger.info(f"AI返回内容长度: {len(content)}") # 解析JSON(使用统一的JSON清洗方法) @@ -222,6 +221,180 @@ async def generate_options( } +@router.post("/refine-options") +async def refine_options( + data: Dict[str, Any], + http_request: Request, + db: AsyncSession = Depends(get_db), + ai_service: AIService = Depends(get_user_ai_service) +) -> Dict[str, Any]: + """ + 基于用户反馈重新生成选项(支持多轮对话) + + Request: + { + "step": "title", // 当前步骤 + "context": { + "initial_idea": "...", + "title": "...", + "description": "...", + "theme": "..." + }, + "feedback": "我想要更悲剧一些的主题", // 用户反馈 + "previous_options": ["选项1", "选项2", ...] // 之前的选项(可选) + } + + Response: + { + "prompt": "引导语", + "options": ["新选项1", "新选项2", ...] 
+ } + """ + max_retries = 3 + + for attempt in range(max_retries): + try: + step = data.get("step", "title") + context = data.get("context", {}) + feedback = data.get("feedback", "") + previous_options = data.get("previous_options", []) + + logger.info(f"灵感模式:根据反馈重新生成{step}阶段的选项(第{attempt + 1}次尝试)") + logger.info(f"用户反馈: {feedback}") + + # 获取用户ID + user_id = getattr(http_request.state, 'user_id', None) + + # 获取对应的提示词模板 + template_key_map = { + "title": ("INSPIRATION_TITLE_SYSTEM", "INSPIRATION_TITLE_USER"), + "description": ("INSPIRATION_DESCRIPTION_SYSTEM", "INSPIRATION_DESCRIPTION_USER"), + "theme": ("INSPIRATION_THEME_SYSTEM", "INSPIRATION_THEME_USER"), + "genre": ("INSPIRATION_GENRE_SYSTEM", "INSPIRATION_GENRE_USER") + } + template_keys = template_key_map.get(step) + + if not template_keys: + return { + "error": f"不支持的步骤: {step}", + "prompt": "", + "options": [] + } + + system_key, user_key = template_keys + + # 获取自定义提示词模板 + system_template = await PromptService.get_template(system_key, user_id, db) + user_template = await PromptService.get_template(user_key, user_id, db) + + # 准备格式化参数 + format_params = { + "initial_idea": context.get("initial_idea", context.get("description", "")), + "title": context.get("title", ""), + "description": context.get("description", ""), + "theme": context.get("theme", "") + } + + # 格式化提示词 + system_prompt = system_template.format(**format_params) + user_prompt = user_template.format(**format_params) + + # 添加反馈信息到提示词 + feedback_instruction = f""" + +⚠️ 用户对之前的选项不太满意,提供了以下反馈: +「{feedback}」 + +之前生成的选项: +{chr(10).join([f"- {opt}" for opt in previous_options]) if previous_options else "(无)"} + +请根据用户的反馈调整生成策略,提供更符合用户期望的新选项。 +注意: +1. 仔细理解用户的反馈意图 +2. 生成的新选项要明显体现用户要求的调整方向 +3. 保持与已有上下文的一致性 +4. 确保返回6个有效选项 +""" + + system_prompt += feedback_instruction + + # 如果是重试,强调格式要求 + if attempt > 0: + system_prompt += f"\n\n⚠️ 这是第{attempt + 1}次生成,请务必严格按照JSON格式返回!" 
+ + # 调用AI生成选项 + temperature = TEMPERATURE_SETTINGS.get(step, 0.7) + # 反馈生成时使用稍高的temperature以获得更多样化的结果 + temperature = min(temperature + 0.1, 0.9) + logger.info(f"调用AI根据反馈生成{step}选项... (temperature={temperature})") + + # 流式生成并累积文本 + accumulated_text = "" + async for chunk in ai_service.generate_text_stream( + prompt=user_prompt, + system_prompt=system_prompt, + temperature=temperature + ): + accumulated_text += chunk + + content = accumulated_text + logger.info(f"AI返回内容长度: {len(content)}") + + # 解析JSON + try: + cleaned_content = ai_service._clean_json_response(content) + result = json.loads(cleaned_content) + + # 校验返回格式 + is_valid, error_msg = validate_options_response(result, step) + + if not is_valid: + logger.warning(f"⚠️ 第{attempt + 1}次生成格式校验失败: {error_msg}") + if attempt < max_retries - 1: + logger.info("准备重试...") + continue + else: + return { + "prompt": f"请为【{step}】提供内容:", + "options": ["让AI重新生成", "我自己输入"], + "error": f"AI生成格式错误({error_msg}),已自动重试{max_retries}次" + } + + logger.info(f"✅ 第{attempt + 1}次根据反馈成功生成{len(result.get('options', []))}个有效选项") + return result + + except json.JSONDecodeError as e: + logger.error(f"第{attempt + 1}次JSON解析失败: {e}") + + if attempt < max_retries - 1: + logger.info("JSON解析失败,准备重试...") + continue + else: + return { + "prompt": f"请为【{step}】提供内容:", + "options": ["让AI重新生成", "我自己输入"], + "error": f"AI返回格式错误,已自动重试{max_retries}次" + } + + except Exception as e: + logger.error(f"第{attempt + 1}次根据反馈生成失败: {e}", exc_info=True) + if attempt < max_retries - 1: + logger.info("发生异常,准备重试...") + continue + else: + return { + "error": str(e), + "prompt": "生成失败,请重试", + "options": ["重新生成", "我自己输入"] + } + + return { + "error": "生成失败", + "prompt": "请重试", + "options": [] + } + + @router.post("/quick-generate") async def quick_generate( data: Dict[str, Any], @@ -280,14 +453,17 @@ async def quick_generate( # 降级使用原有方法 prompts = prompt_service.get_inspiration_quick_complete_prompt(existing=existing_text) - # 调用AI - response = await ai_service.generate_text( 
+ # 调用AI - 流式生成并累积文本 + accumulated_text = "" + async for chunk in ai_service.generate_text_stream( prompt=prompts["user"], system_prompt=prompts["system"], temperature=0.7 - ) + ): + accumulated_text += chunk - content = response.get("content", "") + response = {"content": accumulated_text} + content = accumulated_text # 解析JSON(使用统一的JSON清洗方法) try: diff --git a/backend/app/api/organizations.py b/backend/app/api/organizations.py index 8a5cdb8..02c4cfc 100644 --- a/backend/app/api/organizations.py +++ b/backend/app/api/organizations.py @@ -512,8 +512,29 @@ async def generate_organization_stream( logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织(SSE流式)") try: - ai_response = await user_ai_service.generate_text(prompt=prompt) - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else str(ai_response) + # 使用流式生成替代非流式 + ai_content = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + chunk_count += 1 + ai_content += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新字数(5-95%,AI生成占90%) + if chunk_count % 5 == 0: + progress = min(5 + (chunk_count // 5), 95) + yield await SSEResponse.send_progress( + f"AI生成组织中... 
({len(ai_content)}字符)", + progress + ) + + # 心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() + except Exception as ai_error: logger.error(f"❌ AI服务调用异常:{str(ai_error)}") yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}") @@ -523,7 +544,7 @@ async def generate_organization_stream( yield await SSEResponse.send_error("AI服务返回空响应") return - yield await SSEResponse.send_progress("解析AI响应...", 60) + yield await SSEResponse.send_progress("解析AI响应...", 96) # ✅ 使用统一的 JSON 清洗方法 try: @@ -536,7 +557,7 @@ async def generate_organization_stream( yield await SSEResponse.send_error(f"AI返回的内容无法解析为JSON:{str(e)}") return - yield await SSEResponse.send_progress("创建组织记录...", 75) + yield await SSEResponse.send_progress("创建组织记录...", 97) # 创建角色记录(组织也是角色的一种) character = Character( @@ -563,7 +584,7 @@ async def generate_organization_stream( logger.info(f"✅ 组织角色创建成功:{character.name} (ID: {character.id})") - yield await SSEResponse.send_progress("创建组织详情...", 85) + yield await SSEResponse.send_progress("创建组织详情...", 98) # 自动创建Organization详情记录 organization = Organization( @@ -580,7 +601,7 @@ async def generate_organization_stream( logger.info(f"✅ 组织详情创建成功:{character.name} (Org ID: {organization.id})") - yield await SSEResponse.send_progress("保存生成历史...", 95) + yield await SSEResponse.send_progress("保存生成历史...", 99) # 记录生成历史 history = GenerationHistory( diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index 9f1bd3f..6a922dd 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -470,7 +470,7 @@ async def _generate_new_outline( project: Project, db: AsyncSession, user_ai_service: AIService, - user_id: str = None + user_id: str ) -> OutlineListResponse: """全新生成大纲(MCP增强版)""" logger.info(f"全新生成大纲 - 项目: {project.id}, enable_mcp: {request.enable_mcp}") @@ -534,7 +534,7 @@ async def _generate_new_outline( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, # ✅ 减少为1轮,避免超时 + max_tool_rounds=2, 
tool_choice="auto", provider=None, model=None @@ -573,15 +573,23 @@ async def _generate_new_outline( mcp_references=mcp_reference_materials ) - # 调用AI生成大纲 - ai_response = await user_ai_service.generate_text( + # 调用AI流式生成大纲(带字数统计) + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=request.provider, model=request.model - ) + ): + chunk_count += 1 + accumulated_text += chunk + + # 这里是非SSE接口,不需要发送chunk + # 如果未来需要转SSE,可以在这里yield - # 提取内容(generate_text返回字典) - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + ai_content = accumulated_text + ai_response = {"content": ai_content} # 解析响应 outline_data = _parse_ai_response(ai_content) @@ -732,7 +740,7 @@ async def _continue_outline( existing_outlines: List[Outline], db: AsyncSession, user_ai_service: AIService, - user_id: str = "system" + user_id: str ) -> OutlineListResponse: """续写大纲 - 分批生成,每批5章(记忆+MCP+自动角色引入增强版)""" logger.info(f"续写大纲 - 项目: {project.id}, 已有: {len(existing_outlines)} 章, enable_mcp: {request.enable_mcp}, enable_auto_characters: {request.enable_auto_characters}") @@ -1000,7 +1008,7 @@ async def _continue_outline( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, # ✅ 减少为1轮,避免超时 + max_tool_rounds=2, # ✅ 减少为1轮,避免超时 tool_choice="auto", provider=None, model=None @@ -1045,15 +1053,22 @@ async def _continue_outline( ) # 调用AI生成当前批次 - logger.info(f"正在调用AI生成第{batch_num + 1}批...") - ai_response = await user_ai_service.generate_text( + logger.info(f"正在调用AI流式生成第{batch_num + 1}批...") + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=request.provider, model=request.model - ) + ): + chunk_count += 1 + accumulated_text += chunk + + # 这里是非SSE接口,不需要发送chunk - # 提取内容(generate_text返回字典) - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + ai_content = accumulated_text + 
ai_response = {"content": ai_content} # 解析响应 outline_data = _parse_ai_response(ai_content) @@ -1291,7 +1306,7 @@ async def new_outline_generator( user_id=user_id_for_mcp, db_session=db, enable_mcp=True, - max_tool_rounds=1, # ✅ 减少为1轮,避免超时 + max_tool_rounds=2, # ✅ 减少为1轮,避免超时 tool_choice="auto", provider=None, model=None @@ -1332,7 +1347,7 @@ async def new_outline_generator( mcp_references=mcp_reference_materials ) - # 调用AI + # 调用AI流式生成 yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30) # 添加调试日志 @@ -1341,24 +1356,44 @@ async def new_outline_generator( logger.info(f"=== 大纲生成AI调用参数 ===") logger.info(f" provider参数: {provider_param}") logger.info(f" model参数: {model_param}") - logger.info(f" 完整data: {data}") - ai_response = await user_ai_service.generate_text( + # ✅ 流式生成(带字数统计和进度) + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider_param, model=model_param - ) + ): + chunk_count += 1 + accumulated_text += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度和字数(30-95%,AI生成占65%) + if chunk_count % 5 == 0: + progress = min(30 + (chunk_count // 2), 95) + yield await SSEResponse.send_progress( + f"AI生成大纲中... 
({len(accumulated_text)}字符)", + progress + ) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() - yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 70) + yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96) - # 提取内容(generate_text返回字典) - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + ai_content = accumulated_text + ai_response = {"content": ai_content} # 解析响应 outline_data = _parse_ai_response(ai_content) # 全新生成模式:删除旧大纲和关联的所有章节 - yield await SSEResponse.send_progress("清理旧大纲和章节...", 75) + yield await SSEResponse.send_progress("清理旧大纲和章节...", 97) logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})") from sqlalchemy import delete as sql_delete @@ -1390,7 +1425,7 @@ async def new_outline_generator( logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲") # 保存新大纲 - yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 80) + yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98) outlines = await _save_outlines( project_id, outline_data, db, start_index=1 ) @@ -1410,7 +1445,7 @@ async def new_outline_generator( for outline in outlines: await db.refresh(outline) - yield await SSEResponse.send_progress("整理结果数据...", 95) + yield await SSEResponse.send_progress("整理结果数据...", 99) logger.info(f"全新生成完成 - {len(outlines)} 章") @@ -1785,7 +1820,7 @@ async def continue_outline_generator( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, # ✅ 减少为1轮,避免超时 + max_tool_rounds=2, # ✅ 减少为1轮,避免超时 tool_choice="auto", provider=None, model=None @@ -1846,19 +1881,43 @@ async def continue_outline_generator( logger.info(f" provider参数: {provider_param}") logger.info(f" model参数: {model_param}") - ai_response = await user_ai_service.generate_text( + # 流式生成并累积文本 + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider_param, model=model_param - ) + ): + chunk_count += 1 + 
accumulated_text += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度(每批占用约50%的进度空间) + if chunk_count % 5 == 0: + # 在批次范围内平滑递增(从10到85,总共75%) + batch_range = 60 // total_batches # 总进度60%分配给所有批次 + progress_in_batch = batch_progress + 5 + min((chunk_count // 2), batch_range - 5) + yield await SSEResponse.send_progress( + f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)", + progress_in_batch + ) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() yield await SSEResponse.send_progress( f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...", batch_progress + 10 ) - # 提取内容(generate_text返回字典) - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + # 提取内容 + ai_content = accumulated_text + ai_response = {"content": ai_content} # 解析响应 outline_data = _parse_ai_response(ai_content) diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index a992b9e..9b33b88 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -73,14 +73,15 @@ async def get_user_ai_service( await db.refresh(settings) logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库") - # 使用用户设置创建AI服务实例 + # 使用用户设置创建AI服务实例(包括系统提示词) return create_user_ai_service( api_provider=settings.api_provider, api_key=settings.api_key, api_base_url=settings.api_base_url or "", model_name=settings.llm_model, temperature=settings.temperature, - max_tokens=settings.max_tokens + max_tokens=settings.max_tokens, + system_prompt=settings.system_prompt # 传递系统提示词 ) @@ -271,17 +272,30 @@ async def get_available_models( } elif provider == "anthropic": - # Anthropic 没有公开的模型列表API - raise HTTPException( - status_code=400, - detail="Anthropic 不支持自动获取模型列表,请手动输入模型名称" - ) + # Anthropic models API + url = f"{api_base_url.rstrip('/')}/v1/models" + headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"} + response = await client.get(url, headers=headers) + response.raise_for_status() 
+ data = response.json() + models = [{"value": m["id"], "label": m["id"], "description": m.get("display_name", "")} for m in data.get("data", [])] + return {"provider": provider, "models": models, "count": len(models)} + + elif provider == "gemini": + # Gemini models API + url = f"{api_base_url.rstrip('/')}/models?key={api_key}" + response = await client.get(url) + response.raise_for_status() + data = response.json() + models = [] + for m in data.get("models", []): + if "generateContent" in m.get("supportedGenerationMethods", []): + mid = m.get("name", "").replace("models/", "") + models.append({"value": mid, "label": m.get("displayName", mid), "description": ""}) + return {"provider": provider, "models": models, "count": len(models)} else: - raise HTTPException( - status_code=400, - detail=f"不支持的提供商: {provider}" - ) + raise HTTPException(status_code=400, detail=f"不支持的提供商: {provider}") except httpx.HTTPStatusError as e: logger.error(f"获取模型列表失败 (HTTP {e.response.status_code}): {e.response.text}") diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 999fdea..ea1277d 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -99,7 +99,7 @@ async def world_building_generator( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, + max_tool_rounds=2, tool_choice="auto", provider=None, model=None @@ -139,51 +139,118 @@ async def world_building_generator( final_prompt = base_prompt yield await SSEResponse.send_progress("正在调用AI生成...", 30) - # 流式生成世界观 - accumulated_text = "" - chunk_count = 0 - - async for chunk in user_ai_service.generate_text_stream( - prompt=final_prompt, - provider=provider, - model=model - ): - chunk_count += 1 - accumulated_text += chunk - - # 发送内容块 - yield await SSEResponse.send_chunk(chunk) - - # 定期更新进度 - if chunk_count % 5 == 0: - progress = min(30 + (chunk_count // 5), 70) - yield await SSEResponse.send_progress(f"生成中... 
({len(accumulated_text)}字符)", progress) - - # 每20个块发送心跳 - if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() - - # 解析结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析AI返回结果...", 80) - + # ===== 流式生成世界观(带重试机制) ===== + MAX_WORLD_RETRIES = 3 # 最多重试3次 + world_retry_count = 0 + world_generation_success = False world_data = {} - try: - # ✅ 使用 AIService 的统一清洗方法 - cleaned_text = user_ai_service._clean_json_response(accumulated_text) - world_data = json.loads(cleaned_text) - logger.info(f"✅ 世界观JSON解析成功") + + while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success: + try: + retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else "" + yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 30 + world_retry_count * 5) + + # 流式生成世界观 + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( + prompt=final_prompt, + provider=provider, + model=model + ): + chunk_count += 1 + accumulated_text += chunk - except json.JSONDecodeError as e: - logger.error(f"❌ 世界构建JSON解析失败: {e}") - logger.error(f" 原始内容预览: {accumulated_text[:200]}") - world_data = { - "time_period": "AI返回格式错误,请重试", - "location": "AI返回格式错误,请重试", - "atmosphere": "AI返回格式错误,请重试", - "rules": "AI返回格式错误,请重试" - } + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 世界观生成独立进度:5-95% + if chunk_count % 5 == 0: + progress = min(5 + (chunk_count // 3), 95) + yield await SSEResponse.send_progress(f"世界观生成中... 
({len(accumulated_text)}字符)", progress) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() + + # 检查是否返回空响应 + if not accumulated_text or not accumulated_text.strip(): + logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ AI返回为空,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 达到最大重试次数,使用默认值 + logger.error("❌ 世界观生成多次返回空响应") + world_data = { + "time_period": "AI多次返回为空,请稍后重试", + "location": "AI多次返回为空,请稍后重试", + "atmosphere": "AI多次返回为空,请稍后重试", + "rules": "AI多次返回为空,请稍后重试" + } + world_generation_success = True # 标记为成功以继续流程 + break + + # 解析结果 - 使用统一的JSON清洗方法 + yield await SSEResponse.send_progress("解析世界观数据...", 96) + + try: + logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}") + logger.info(f" 原始内容预览: {accumulated_text[:300]}...") + + # ✅ 使用 AIService 的统一清洗方法 + cleaned_text = user_ai_service._clean_json_response(accumulated_text) + logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}") + logger.info(f" 清洗后预览: {cleaned_text[:300]}...") + + world_data = json.loads(cleaned_text) + logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") + world_generation_success = True # 解析成功,标记完成 + + except json.JSONDecodeError as e: + logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}") + logger.error(f" 原始内容长度: {len(accumulated_text)}") + logger.error(f" 原始内容预览: {accumulated_text[:200]}") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ JSON解析失败,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 达到最大重试次数,使用默认值 + world_data = { + "time_period": "AI返回格式错误,请重试", + "location": "AI返回格式错误,请重试", + "atmosphere": "AI返回格式错误,请重试", + "rules": "AI返回格式错误,请重试" + } + world_generation_success = True # 标记为成功以继续流程 + + except Exception as e: + logger.error(f"❌ 
世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ 生成异常,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 最后一次重试仍失败,抛出异常 + logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}") + raise # 保存到数据库 - yield await SSEResponse.send_progress("保存到数据库...", 90) + yield await SSEResponse.send_progress("保存世界观到数据库...", 99) # 确保user_id存在 if not user_id: @@ -240,41 +307,81 @@ async def world_building_generator( project.wizard_step = 1 await db.commit() - # ===== 自动生成职业体系 ===== - yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 75) + # ===== 自动生成职业体系(带重试机制+流式) ===== + yield await SSEResponse.send_progress("世界观完成!", 100, "success") + yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5) logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系") - try: - # 获取职业生成提示词模板(支持用户自定义) - template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db) - career_prompt = PromptService.format_prompt( - template, - title=project.title, - genre=genre or '未设定', - theme=theme or '未设定', - time_period=world_data.get('time_period', '未设定'), - location=world_data.get('location', '未设定'), - atmosphere=world_data.get('atmosphere', '未设定'), - rules=world_data.get('rules', '未设定') - ) - - yield await SSEResponse.send_progress("正在生成职业体系...", 78) - - # 调用AI生成职业 - result = await user_ai_service.generate_text(prompt=career_prompt) - career_response = result.get('content', '') if isinstance(result, dict) else result - - if not career_response or not career_response.strip(): - logger.warning("⚠️ AI返回空职业体系,跳过职业生成") - yield await SSEResponse.send_progress("职业体系生成跳过(AI返回为空)", 85) - else: - yield await SSEResponse.send_progress("解析职业体系数据...", 82) + MAX_CAREER_RETRIES = 3 # 最多重试3次 + career_retry_count = 0 + career_generation_success = False + + while career_retry_count 
< MAX_CAREER_RETRIES and not career_generation_success: + try: + retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else "" + yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10) + + # 获取职业生成提示词模板(支持用户自定义) + template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db) + career_prompt = PromptService.format_prompt( + template, + title=project.title, + genre=genre or '未设定', + theme=theme or '未设定', + time_period=world_data.get('time_period', '未设定'), + location=world_data.get('location', '未设定'), + atmosphere=world_data.get('atmosphere', '未设定'), + rules=world_data.get('rules', '未设定') + ) + + # ✅ 使用流式生成职业体系 + career_response = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( + prompt=career_prompt, + provider=provider, + model=model + ): + chunk_count += 1 + career_response += chunk + + # 发送内容块 + yield await SSEResponse.send_chunk(chunk) + + # 职业体系生成独立进度:10-95% + if chunk_count % 5 == 0: + progress = min(10 + (chunk_count // 3), 95) + yield await SSEResponse.send_progress( + f"生成职业体系中... 
({len(career_response)}字符)", + progress + ) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() + + if not career_response or not career_response.strip(): + logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})") + career_retry_count += 1 + if career_retry_count < MAX_CAREER_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ AI返回为空,准备重试...", + 10, + "warning" + ) + continue + else: + yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99) + break + + yield await SSEResponse.send_progress("解析职业体系数据...", 96) # 清洗并解析JSON try: cleaned_response = user_ai_service._clean_json_response(career_response) career_data = json.loads(cleaned_response) - logger.info(f"✅ 职业体系JSON解析成功") + logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})") # 保存主职业 main_careers_created = [] @@ -338,22 +445,51 @@ async def world_building_generator( await db.commit() + # 标记成功 + career_generation_success = True logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个") yield await SSEResponse.send_progress( f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})", - 90 + 99 ) except json.JSONDecodeError as e: - logger.error(f"❌ 职业体系JSON解析失败: {e}") - yield await SSEResponse.send_progress("⚠️ 职业体系解析失败,已跳过", 85) + logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") + career_retry_count += 1 + if career_retry_count < MAX_CAREER_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ JSON解析失败,准备重试...", + 10, + "warning" + ) + continue + else: + yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99) except Exception as e: - logger.error(f"❌ 职业体系保存失败: {e}") - yield await SSEResponse.send_progress("⚠️ 职业体系保存失败,已跳过", 85) - - except Exception as e: - logger.error(f"❌ 职业体系生成异常: {e}") - yield await SSEResponse.send_progress("⚠️ 职业体系生成失败,已跳过(不影响项目创建)", 85) + logger.error(f"❌ 
职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") + career_retry_count += 1 + if career_retry_count < MAX_CAREER_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ 保存失败,准备重试...", + 10, + "warning" + ) + continue + else: + yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99) + + except Exception as e: + logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") + career_retry_count += 1 + if career_retry_count < MAX_CAREER_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ 生成异常,准备重试...", + 10, + "warning" + ) + continue + else: + yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99) db_committed = True @@ -366,7 +502,8 @@ async def world_building_generator( "rules": world_data.get("rules") }) - yield await SSEResponse.send_progress("完成!", 100, "success") + yield await SSEResponse.send_progress("职业体系完成!", 100, "success") + yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success") yield await SSEResponse.send_done() except GeneratorExit: @@ -473,7 +610,7 @@ async def characters_generator( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, # ✅ 优化: 从2轮减少到1轮 + max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮 tool_choice="auto", provider=None, model=None @@ -611,15 +748,32 @@ async def characters_generator( else: prompt = base_prompt - # 流式生成 + # 流式生成(带字数统计) accumulated_text = "" + chunk_count = 0 + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model ): + chunk_count += 1 accumulated_text += chunk + + # 发送内容块 yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度和字数 + if chunk_count % 5 == 0: + progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15) + yield await SSEResponse.send_progress( + f"生成角色中... 
({len(accumulated_text)}字符)", + progress + ) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() # 解析批次结果 - 使用统一的JSON清洗方法 cleaned_text = user_ai_service._clean_json_response(accumulated_text) @@ -1184,18 +1338,35 @@ async def outline_generator( requirements=outline_requirements ) - # 流式生成大纲 + # 流式生成大纲(带字数统计) accumulated_text = "" + chunk_count = 0 + async for chunk in user_ai_service.generate_text_stream( prompt=outline_prompt, provider=provider, model=model ): + chunk_count += 1 accumulated_text += chunk + + # 发送内容块 yield await SSEResponse.send_chunk(chunk) + + # 定期更新进度和字数(5-95%,AI生成占90%) + if chunk_count % 5 == 0: + progress = min(5 + (chunk_count // 3), 95) + yield await SSEResponse.send_progress( + f"生成大纲中... ({len(accumulated_text)}字符)", + progress + ) + + # 每20个块发送心跳 + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() # 解析大纲结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析大纲...", 40) + yield await SSEResponse.send_progress("解析大纲...", 96) try: cleaned_text = user_ai_service._clean_json_response(accumulated_text) @@ -1208,7 +1379,7 @@ async def outline_generator( return # 保存大纲到数据库 - yield await SSEResponse.send_progress("保存大纲到数据库...", 45) + yield await SSEResponse.send_progress("保存大纲到数据库...", 97) created_outlines = [] for index, outline_item in enumerate(outline_data[:outline_count], 1): outline = Outline( @@ -1231,7 +1402,7 @@ async def outline_generator( created_chapters = [] if project.outline_mode == 'one-to-one': # 一对一模式:自动为每个大纲创建对应的章节 - yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 50) + yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98) for outline in created_outlines: chapter = Chapter( @@ -1250,10 +1421,10 @@ async def outline_generator( await db.refresh(chapter) logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节") - yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 85) + yield await 
SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99) else: # 一对多模式:跳过自动创建,用户可手动展开 - yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 85) + yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99) logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开") # 更新项目信息 @@ -1396,7 +1567,7 @@ async def world_building_regenerate_generator( user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1, + max_tool_rounds=2, tool_choice="auto", provider=None, model=None @@ -1433,44 +1604,109 @@ async def world_building_regenerate_generator( final_prompt = base_prompt yield await SSEResponse.send_progress("正在调用AI生成...", 30) - # 流式生成世界观 - accumulated_text = "" - chunk_count = 0 - - async for chunk in user_ai_service.generate_text_stream( - prompt=final_prompt, - provider=provider, - model=model - ): - chunk_count += 1 - accumulated_text += chunk - - yield await SSEResponse.send_chunk(chunk) - - if chunk_count % 5 == 0: - progress = min(30 + (chunk_count // 5), 70) - yield await SSEResponse.send_progress(f"生成中... 
({len(accumulated_text)}字符)", progress) - - if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() - - # 解析结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析AI返回结果...", 80) - + # ===== 流式生成世界观(带重试机制) ===== + MAX_WORLD_RETRIES = 3 # 最多重试3次 + world_retry_count = 0 + world_generation_success = False world_data = {} - try: - cleaned_text = user_ai_service._clean_json_response(accumulated_text) - world_data = json.loads(cleaned_text) - logger.info(f"✅ 世界观重新生成JSON解析成功") + + while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success: + try: + retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else "" + yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 30 + world_retry_count * 5) + + # 流式生成世界观 + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( + prompt=final_prompt, + provider=provider, + model=model + ): + chunk_count += 1 + accumulated_text += chunk - except json.JSONDecodeError as e: - logger.error(f"世界构建JSON解析失败: {e}") - world_data = { - "time_period": "AI返回格式错误,请重试", - "location": "AI返回格式错误,请重试", - "atmosphere": "AI返回格式错误,请重试", - "rules": "AI返回格式错误,请重试" - } + yield await SSEResponse.send_chunk(chunk) + + if chunk_count % 5 == 0: + progress = min(30 + (chunk_count // 5), 85) + yield await SSEResponse.send_progress(f"生成中... 
({len(accumulated_text)}字符)", progress) + + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() + + # 检查是否返回空响应 + if not accumulated_text or not accumulated_text.strip(): + logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ AI返回为空,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 达到最大重试次数,使用默认值 + logger.error("❌ 世界观重新生成多次返回空响应") + world_data = { + "time_period": "AI多次返回为空,请稍后重试", + "location": "AI多次返回为空,请稍后重试", + "atmosphere": "AI多次返回为空,请稍后重试", + "rules": "AI多次返回为空,请稍后重试" + } + world_generation_success = True + break + + # 解析结果 - 使用统一的JSON清洗方法 + yield await SSEResponse.send_progress("解析AI返回结果...", 80) + + try: + logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}") + cleaned_text = user_ai_service._clean_json_response(accumulated_text) + logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}") + + world_data = json.loads(cleaned_text) + logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") + world_generation_success = True + + except json.JSONDecodeError as e: + logger.error(f"❌ 世界构建JSON解析失败(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {e}") + logger.error(f" 原始内容长度: {len(accumulated_text)}") + logger.error(f" 原始内容预览: {accumulated_text[:200]}") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await SSEResponse.send_progress( + f"⚠️ JSON解析失败,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 达到最大重试次数,使用默认值 + world_data = { + "time_period": "AI返回格式错误,请重试", + "location": "AI返回格式错误,请重试", + "atmosphere": "AI返回格式错误,请重试", + "rules": "AI返回格式错误,请重试" + } + world_generation_success = True + + except Exception as e: + logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}") + world_retry_count += 1 + if world_retry_count < MAX_WORLD_RETRIES: + yield await 
SSEResponse.send_progress( + f"⚠️ 生成异常,准备重试...", + 30 + world_retry_count * 5, + "warning" + ) + continue + else: + # 最后一次重试仍失败,抛出异常 + logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}") + raise # 不保存到数据库,仅返回生成结果供用户预览 yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90) diff --git a/backend/app/api/writing_styles.py b/backend/app/api/writing_styles.py index 6295d6e..d275ac7 100644 --- a/backend/app/api/writing_styles.py +++ b/backend/app/api/writing_styles.py @@ -15,7 +15,6 @@ from ..schemas.writing_style import ( WritingStyleListResponse, SetDefaultStyleRequest ) -from ..services.prompt_service import WritingStyleManager from ..logger import get_logger router = APIRouter(prefix="/writing-styles", tags=["writing-styles"]) @@ -31,21 +30,36 @@ def get_current_user_id(request: Request) -> str: @router.get("/presets/list", response_model=List[dict]) -async def get_preset_styles(): +async def get_preset_styles(db: AsyncSession = Depends(get_db)): """ - 获取所有预设风格列表 + 获取所有预设风格列表(从数据库读取) 返回格式:数组形式的预设风格列表 [ - {"id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."}, - {"id": "classical", "name": "古典优雅", ...} + {"id": 1, "preset_id": "natural", "name": "自然流畅", "description": "...", "prompt_content": "..."}, + {"id": 2, "preset_id": "classical", "name": "古典优雅", ...} ] """ - presets = WritingStyleManager.get_all_presets() - # 将字典转换为数组,添加 id 字段 + # 从数据库获取全局预设风格(user_id 为 NULL) + result = await db.execute( + select(WritingStyle) + .where(WritingStyle.user_id.is_(None)) + .order_by(WritingStyle.order_index) + ) + preset_styles = result.scalars().all() + + # 转换为响应格式 return [ - {"id": preset_id, **preset_data} - for preset_id, preset_data in presets.items() + { + "id": style.id, + "preset_id": style.preset_id, + "name": style.name, + "description": style.description, + "prompt_content": style.prompt_content, + "style_type": style.style_type, + "order_index": style.order_index + } + for style in 
async def call_tool(
    self,
    tool_name: str,
    arguments: Dict[str, Any],
    max_reconnect_attempts: int = 2
) -> Any:
    """
    Call an MCP tool, reconnecting automatically if the session was closed.

    Args:
        tool_name: Name of the tool to invoke.
        arguments: Arguments forwarded to the tool.
        max_reconnect_attempts: Number of reconnects to try after a
            ``ClosedResourceError`` before giving up.

    Returns:
        The tool result, unpacked by ``_extract_tool_result``: the text of
        the first text item, an image dict for the first image item, the
        first raw content item, the structured content, or ``None``.

    Raises:
        MCPError: If the call fails or all reconnect attempts are used up.
            The triggering exception is chained as ``__cause__``.
    """
    for attempt in range(max_reconnect_attempts + 1):
        try:
            await self._ensure_connected()

            logger.info(f"调用工具: {tool_name}")
            logger.debug(f"  参数类型: {type(arguments)}")
            logger.debug(f"  参数内容: {arguments}")
            logger.debug(f"  会话状态: initialized={self._initialized}, session={self._session is not None}")

            result = await self._session.call_tool(tool_name, arguments)

            logger.debug(f"  工具返回类型: {type(result)}")
            logger.debug(f"  返回内容: {result}")

            return self._extract_tool_result(result)

        except ClosedResourceError as e:
            # The transport was closed underneath us: drop the stale session
            # and let the next iteration re-establish it.
            if attempt >= max_reconnect_attempts:
                logger.error("❌ MCP连接重连失败,已达最大重试次数")
                error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
                raise MCPError(error_msg) from e

            logger.warning(
                f"⚠️ MCP连接已关闭,尝试重新连接 "
                f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
            )
            await self._cleanup()
            await asyncio.sleep(0.5)  # brief pause before reconnecting

        except Exception as e:
            logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
            logger.error(f"  参数: {arguments}")
            logger.error(f"  错误类型: {type(e).__name__}")
            logger.error(f"  错误详情: {repr(e)}")
            logger.error(f"  错误字符串: '{str(e)}'")
            # Some exceptions stringify to "", so fall back to repr / type name.
            error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
            raise MCPError(f"调用工具失败: {error_msg}") from e

    # Only reachable when max_reconnect_attempts is negative (empty range).
    raise MCPError("工具调用失败: 未知错误")

@staticmethod
def _extract_tool_result(result: Any) -> Any:
    """Unpack an MCP SDK ``CallToolResult`` into a plain return value."""
    if result.content:
        logger.debug(f"  返回content数量: {len(result.content)}")
        # The first text or image item wins, whichever appears first.
        for idx, content in enumerate(result.content):
            logger.debug(f"  content[{idx}]类型: {type(content)}")
            if isinstance(content, types.TextContent):
                logger.debug(f"  ✅ 返回TextContent: {content.text[:100]}")
                return content.text
            if isinstance(content, types.ImageContent):
                logger.debug("  ✅ 返回ImageContent")
                return {
                    "type": "image",
                    "data": content.data,
                    "mimeType": content.mimeType
                }
        # Neither text nor image present: hand back the raw first item.
        logger.debug("  ⚠️ 返回原始content[0]")
        return result.content[0]

    # Structured content (2025-06-18 spec) is only consulted when the
    # regular content list is empty.
    if hasattr(result, 'structuredContent') and result.structuredContent:
        logger.debug("  ✅ 返回structuredContent")
        return result.structuredContent

    logger.warning("  ⚠️ 工具返回为None")
    return None
class AnthropicClient:
    """Async client for the Anthropic Messages API."""

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        self.config = config if config else default_config
        # Only forward base_url when one was supplied.
        sdk_options = {"api_key": api_key}
        if base_url:
            sdk_options["base_url"] = base_url
        self.client = AsyncAnthropic(**sdk_options)

    @staticmethod
    def _request_args(
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments shared by both request kinds."""
        request: Dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": messages,
        }
        if system_prompt:
            request["system"] = system_prompt
        return request

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run one non-streaming completion.

        Returns:
            A dict with ``content`` (all text blocks concatenated),
            ``tool_calls`` (list of function-style entries, or ``None`` when
            the model requested no tool) and ``finish_reason`` (the
            response's ``stop_reason``).
        """
        request = self._request_args(messages, model, temperature, max_tokens, system_prompt)
        if tools:
            request["tools"] = tools
            # Map the generic choice names onto Anthropic's tool_choice types;
            # any other value leaves tool_choice unset.
            if tool_choice == "required":
                request["tool_choice"] = {"type": "any"}
            elif tool_choice == "auto":
                request["tool_choice"] = {"type": "auto"}

        response = await self.client.messages.create(**request)

        requested_tools = [
            {
                "id": block.id,
                "type": "function",
                "function": {"name": block.name, "arguments": block.input},
            }
            for block in response.content
            if block.type == "tool_use"
        ]
        text = "".join(block.text for block in response.content if block.type == "text")

        return {
            "content": text,
            "tool_calls": requested_tools or None,
            "finish_reason": response.stop_reason,
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding text fragments as they arrive."""
        request = self._request_args(messages, model, temperature, max_tokens, system_prompt)

        async with self.client.messages.stream(**request) as stream:
            async for fragment in stream.text_stream:
                yield fragment
def _get_or_create_client(self) -> httpx.AsyncClient:
    """Return the pooled HTTP client for this client key, creating it if needed."""
    pool_key = self._get_client_key()

    pooled = _http_client_pool.get(pool_key)
    if pooled is not None:
        if not pooled.is_closed:
            return pooled
        # A closed client is useless: evict it and build a replacement below.
        del _http_client_pool[pool_key]

    settings = self.config.http
    timeout = httpx.Timeout(
        connect=settings.connect_timeout,
        read=settings.read_timeout,
        write=settings.write_timeout,
        pool=settings.pool_timeout,
    )
    limits = httpx.Limits(
        max_keepalive_connections=settings.max_keepalive_connections,
        max_connections=settings.max_connections,
        keepalive_expiry=settings.keepalive_expiry,
    )

    fresh = httpx.AsyncClient(timeout=timeout, limits=limits)
    _http_client_pool[pool_key] = fresh
    logger.info(f"✅ 创建 HTTP 客户端: {pool_key}")
    return fresh
class GeminiClient:
    """Google Gemini REST API client (generateContent / streamGenerateContent)."""

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: "Optional[AIClientConfig]" = None):
        self.api_key = api_key
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
        self.config = config or default_config
        http_cfg = self.config.http
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout
            )
        )

    async def aclose(self) -> None:
        """Close the HTTP client owned by this instance (safe to call twice)."""
        if not self.client.is_closed:
            await self.client.aclose()

    def _convert_tools_to_gemini(self, tools: list) -> list:
        """
        Convert OpenAI-style tool definitions to Gemini function declarations.

        Entries whose ``type`` is not ``"function"`` are skipped.  The
        top-level ``$schema`` and ``additionalProperties`` keys are removed
        from a copy of the parameters; the caller's dicts are not modified.

        Returns:
            ``[{"functionDeclarations": [...]}]``, or ``[]`` when no function
            tool was found.
        """
        declarations = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            params = func.get("parameters", {}).copy() if func.get("parameters") else {}
            params.pop("$schema", None)
            params.pop("additionalProperties", None)
            if params and "type" not in params:
                params["type"] = "object"
            decl = {
                "name": func["name"],
                # An empty/missing description falls back to the function name.
                "description": func.get("description") or func["name"],
            }
            if params:
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}] if declarations else []

    @staticmethod
    def _build_payload(
        messages: list,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Build the request body shared by the streaming and non-streaming calls."""
        # Every role other than "user" is sent as "model".
        contents = [
            {
                "role": "user" if msg["role"] == "user" else "model",
                "parts": [{"text": msg["content"]}],
            }
            for msg in messages
        ]
        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
        }
        if system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
        return payload

    @staticmethod
    def _first_candidate_parts(data: Any) -> list:
        """Return the parts of the first candidate, or ``[]`` for any other shape."""
        if not isinstance(data, dict):
            return []
        candidates = data.get("candidates") or []
        if not candidates or not isinstance(candidates[0], dict):
            return []
        # "content" may be absent or null (e.g. a blocked response).
        content = candidates[0].get("content") or {}
        parts = content.get("parts") if isinstance(content, dict) else None
        return [part for part in (parts or []) if isinstance(part, dict)]

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run one non-streaming completion.

        ``tool_choice`` is accepted for signature parity with the other
        clients but is not forwarded to the API.

        Returns:
            A dict with ``content`` (all text parts concatenated),
            ``tool_calls`` (function-style entries or ``None``) and
            ``finish_reason`` (``"tool_calls"`` or ``"stop"``).  A response
            without candidates yields empty content instead of raising, so
            the calling flow can continue.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        # NOTE(review): the API key is sent as a query parameter, so it is part
        # of the request URL — confirm that URL never reaches logs or error text.
        url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"

        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        if tools:
            payload["tools"] = self._convert_tools_to_gemini(tools)

        response = await self.client.post(url, json=payload)
        response.raise_for_status()
        data = response.json()

        text = ""
        tool_calls = []
        for part in self._first_candidate_parts(data):
            if "text" in part:
                text += part["text"]
            elif "functionCall" in part:
                fc = part["functionCall"]
                tool_calls.append({
                    "id": f"call_{fc['name']}",
                    "type": "function",
                    "function": {"name": fc["name"], "arguments": fc.get("args", {})}
                })

        return {
            "content": text,
            "tool_calls": tool_calls if tool_calls else None,
            "finish_reason": "tool_calls" if tool_calls else "stop"
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Stream a completion over SSE, yielding text fragments as they arrive.

        Lines that are not ``data: `` events, or whose payload is not valid
        JSON, are skipped.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        import json  # function-scope import, done once instead of per SSE line

        url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)

        async with self.client.stream("POST", url, json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                # Only the decode is guarded; the yield stays outside the try so
                # generator close/cancellation is never swallowed.
                try:
                    data = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue
                # Emit every text part, matching chat_completion's concatenation.
                for part in self._first_candidate_parts(data):
                    text = part.get("text", "")
                    if text:
                        yield text
+"""OpenAI 客户端""" +import json +from typing import Any, AsyncGenerator, Dict, Optional + +from app.logger import get_logger +from .base_client import BaseAIClient + +logger = get_logger(__name__) + + +class OpenAIClient(BaseAIClient): + """OpenAI API 客户端""" + + def _build_headers(self) -> Dict[str, str]: + return { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def _build_payload( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + stream: bool = False, + ) -> Dict[str, Any]: + payload = { + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + } + if stream: + payload["stream"] = True + if tools: + # 清理 $schema 字段 + cleaned = [] + for t in tools: + tc = t.copy() + if "function" in tc and "parameters" in tc["function"]: + tc["function"]["parameters"] = { + k: v for k, v in tc["function"]["parameters"].items() if k != "$schema" + } + cleaned.append(tc) + payload["tools"] = cleaned + if tool_choice: + payload["tool_choice"] = tool_choice + return payload + + async def chat_completion( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + ) -> Dict[str, Any]: + payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice) + data = await self._request_with_retry("POST", "/chat/completions", payload) + + choices = data.get("choices", []) + if not choices or len(choices) == 0: + raise ValueError("API 返回空 choices 或 choices 为空列表") + + choice = choices[0] + message = choice.get("message", {}) + return { + "content": message.get("content", ""), + "tool_calls": message.get("tool_calls"), + "finish_reason": choice.get("finish_reason"), + } + + async def chat_completion_stream( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + ) -> 
@dataclass
class HTTPClientConfig:
    """Timeout and connection-pool settings for the HTTP client."""

    # Timeouts
    connect_timeout: float = 90.0
    read_timeout: float = 300.0
    write_timeout: float = 90.0
    pool_timeout: float = 90.0
    # Connection pool limits
    max_keepalive_connections: int = 50
    max_connections: int = 100
    keepalive_expiry: float = 60.0


@dataclass
class RetryConfig:
    """Retry and backoff settings."""

    max_retries: int = 3
    base_delay: float = 0.2
    max_delay: float = 10.0
    exponential_base: int = 2
    # HTTP status codes for which a retry is pointless.
    non_retryable_status_codes: tuple = (401, 403, 404)


@dataclass
class RateLimitConfig:
    """Rate-limiting settings."""

    max_concurrent_requests: int = 5
    request_delay: float = 0.2


@dataclass
class AIClientConfig:
    """Complete AI client configuration: HTTP, retry and rate-limit sections."""

    # Each instance gets its own section objects.
    http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
    retry: RetryConfig = field(default_factory=RetryConfig)
    rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)


# Global default configuration
default_config = AIClientConfig()
class AnthropicProvider(BaseAIProvider):
    """Provider adapter that sends single-prompt requests to an AnthropicClient."""

    def __init__(self, client: AnthropicClient):
        self.client = client

    async def generate(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send ``prompt`` as a single user message and return the client's result dict."""
        conversation = [{"role": "user", "content": prompt}]
        request = {
            "messages": conversation,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "system_prompt": system_prompt,
            "tools": tools,
            "tool_choice": tool_choice,
        }
        return await self.client.chat_completion(**request)

    async def generate_stream(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Send ``prompt`` as a single user message and yield text fragments."""
        fragments = self.client.chat_completion_stream(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
        )
        async for fragment in fragments:
            yield fragment
class BaseAIProvider(ABC):
    """Abstract base class that every AI provider adapter implements."""

    @abstractmethod
    async def generate(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate a completion for ``prompt`` and return it as a result dict."""

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Generate a completion for ``prompt`` as a stream of text fragments."""
class OpenAIProvider(BaseAIProvider):
    """Provider adapter that sends single-prompt requests to an OpenAIClient."""

    def __init__(self, client: OpenAIClient):
        self.client = client

    @staticmethod
    def _compose_messages(prompt: str, system_prompt: Optional[str]) -> List[Dict]:
        """Return the chat messages: an optional system message, then the user prompt."""
        user_message = {"role": "user", "content": prompt}
        if not system_prompt:
            return [user_message]
        return [{"role": "system", "content": system_prompt}, user_message]

    async def generate(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming completion and return the client's result dict."""
        conversation = self._compose_messages(prompt, system_prompt)
        return await self.client.chat_completion(
            messages=conversation,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            tools=tools,
            tool_choice=tool_choice,
        )

    async def generate_stream(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Run a streaming completion, yielding text fragments as they arrive."""
        fragments = self.client.chat_completion_stream(
            messages=self._compose_messages(prompt, system_prompt),
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        async for fragment in fragments:
            yield fragment
- """生成HTTP客户端的唯一键 - - Args: - provider: 提供商名称 - base_url: API基础URL - api_key: API密钥(用于区分不同用户) - - Returns: - 客户端唯一键 - """ - # 使用API密钥的哈希值(安全性)+ 提供商 + base_url 作为键 - key_hash = hashlib.md5(api_key.encode()).hexdigest()[:8] - url_part = base_url or "default" - return f"{provider}_{url_part}_{key_hash}" - - -def _get_or_create_http_client( - provider: str, - base_url: Optional[str], - api_key: str -) -> httpx.AsyncClient: - """获取或创建HTTP客户端(复用连接) - - Args: - provider: 提供商名称 - base_url: API基础URL - api_key: API密钥 - - Returns: - httpx.AsyncClient实例 - """ - global _http_client_pool - - client_key = _get_client_key(provider, base_url, api_key) - - # 检查是否已存在 - if client_key in _http_client_pool: - client = _http_client_pool[client_key] - # 检查客户端是否仍然有效 - if not client.is_closed: - logger.debug(f"♻️ 复用HTTP客户端: {client_key}") - return client - else: - # 客户端已关闭,从池中移除 - logger.warning(f"⚠️ HTTP客户端已关闭,重新创建: {client_key}") - del _http_client_pool[client_key] - - # 创建新客户端 - limits = httpx.Limits( - max_keepalive_connections=50, # 最大保持连接数 - max_connections=100, # 最大总连接数 - keepalive_expiry=30.0 # 保持连接30秒 - ) - - client = httpx.AsyncClient( - timeout=httpx.Timeout( - connect=90.0, # 连接超时 - read=300.0, # 读取超时 - write=90.0, # 写入超时 - pool=90.0 # 连接池超时 - ), - limits=limits, - headers={ - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" - } - ) - - # 添加到池中 - _http_client_pool[client_key] = client - logger.info(f"✅ 创建新HTTP客户端并加入池: {client_key} (池大小: {len(_http_client_pool)})") - - return client - - -async def cleanup_http_clients(): - """清理所有HTTP客户端(应用关闭时调用)""" - global _http_client_pool - - logger.info(f"🧹 开始清理HTTP客户端池 (共 {len(_http_client_pool)} 个客户端)") - - for key, client in list(_http_client_pool.items()): - try: - if not client.is_closed: - await client.aclose() - logger.debug(f"✅ 关闭HTTP客户端: {key}") - except Exception as e: - logger.error(f"❌ 关闭HTTP客户端失败 {key}: {e}") - - _http_client_pool.clear() - logger.info("✅ HTTP客户端池清理完成") - class AIService: - 
"""AI服务统一接口 - 支持从用户设置或全局配置初始化""" - + """AI服务统一接口""" + def __init__( self, api_provider: Optional[str] = None, @@ -127,99 +32,53 @@ class AIService: default_model: Optional[str] = None, default_temperature: Optional[float] = None, default_max_tokens: Optional[int] = None, - enable_mcp_adapter: bool = True + default_system_prompt: Optional[str] = None, + enable_mcp_adapter: bool = True, + config: Optional[AIClientConfig] = None, ): - """ - 初始化AI客户端(优化并发性能) - - Args: - api_provider: API提供商 (openai/anthropic),为None时使用全局配置 - api_key: API密钥,为None时使用全局配置 - api_base_url: API基础URL,为None时使用全局配置 - default_model: 默认模型,为None时使用全局配置 - default_temperature: 默认温度,为None时使用全局配置 - default_max_tokens: 默认最大tokens,为None时使用全局配置 - """ - # 保存用户设置或使用全局配置 self.api_provider = api_provider or app_settings.default_ai_provider self.default_model = default_model or app_settings.default_model self.default_temperature = default_temperature or app_settings.default_temperature self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens + self.default_system_prompt = default_system_prompt + self.config = config or default_config - # 使用全局MCP适配器单例 - self.enable_mcp_adapter = enable_mcp_adapter - if enable_mcp_adapter: - self.mcp_adapter = universal_mcp_adapter - logger.info("✅ MCP通用适配器已启用(使用全局单例)") - else: - self.mcp_adapter = None - logger.info("⚠️ MCP适配器已禁用") + self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None - # 初始化OpenAI客户端(使用HTTP客户端池) + self._openai_provider: Optional[OpenAIProvider] = None + self._anthropic_provider: Optional[AnthropicProvider] = None + self._gemini_provider: Optional[GeminiProvider] = None + + # 初始化 OpenAI openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key if openai_key: - try: - base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url - - # 从池中获取或创建HTTP客户端(复用连接) - http_client = _get_or_create_http_client("openai", base_url, openai_key) - - client_kwargs = { - "api_key": 
openai_key, - "http_client": http_client - } - - if base_url: - client_kwargs["base_url"] = base_url - - self.openai_client = AsyncOpenAI(**client_kwargs) - self.openai_http_client = http_client - self.openai_api_key = openai_key - self.openai_base_url = base_url - logger.info("✅ OpenAI客户端初始化成功(复用HTTP连接)") - except Exception as e: - logger.error(f"OpenAI客户端初始化失败: {e}") - self.openai_client = None - self.openai_http_client = None - self.openai_api_key = None - self.openai_base_url = None - else: - self.openai_client = None - self.openai_http_client = None - self.openai_api_key = None - self.openai_base_url = None - # 只有当用户明确选择OpenAI作为提供商时才警告 - if self.api_provider == "openai": - logger.warning("⚠️ OpenAI API key未配置,但被设置为当前AI提供商") + base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url + client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config) + self._openai_provider = OpenAIProvider(client) - # 初始化Anthropic客户端(使用HTTP客户端池) + # 初始化 Anthropic anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key if anthropic_key: - try: - base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url - - # 从池中获取或创建HTTP客户端(复用连接) - http_client = _get_or_create_http_client("anthropic", base_url, anthropic_key) - - client_kwargs = { - "api_key": anthropic_key, - "http_client": http_client - } - - if base_url: - client_kwargs["base_url"] = base_url - - self.anthropic_client = AsyncAnthropic(**client_kwargs) - logger.info("✅ Anthropic客户端初始化成功(复用HTTP连接)") - except Exception as e: - logger.error(f"Anthropic客户端初始化失败: {e}") - self.anthropic_client = None - else: - self.anthropic_client = None - # 只有当用户明确选择Anthropic作为提供商时才警告 - if self.api_provider == "anthropic": - logger.warning("⚠️ Anthropic API key未配置,但被设置为当前AI提供商") - + base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url + client = AnthropicClient(anthropic_key, base_url, 
self.config) + self._anthropic_provider = AnthropicProvider(client) + + # 初始化 Gemini + if api_provider == "gemini" and api_key: + client = GeminiClient(api_key, api_base_url, self.config) + self._gemini_provider = GeminiProvider(client) + + def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider: + """获取对应的 Provider""" + p = provider or self.api_provider + if p == "openai" and self._openai_provider: + return self._openai_provider + if p == "anthropic" and self._anthropic_provider: + return self._anthropic_provider + if p == "gemini" and self._gemini_provider: + return self._gemini_provider + raise ValueError(f"Provider {p} 未初始化") + async def generate_text( self, prompt: str, @@ -228,44 +87,21 @@ class AIService: temperature: Optional[float] = None, max_tokens: Optional[int] = None, system_prompt: Optional[str] = None, - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[str] = None + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, ) -> Dict[str, Any]: - """ - 生成文本(支持工具调用) - - Args: - prompt: 用户提示词 - provider: AI提供商 (openai/anthropic) - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - system_prompt: 系统提示词 - tools: 可用工具列表(MCP工具格式) - tool_choice: 工具选择策略 (auto/required/none) - - Returns: - Dict包含: - - content: 文本内容(如果没有工具调用) - - tool_calls: 工具调用列表(如果AI决定调用工具) - - finish_reason: 完成原因 - """ - provider = provider or self.api_provider - model = model or self.default_model - temperature = temperature or self.default_temperature - max_tokens = max_tokens or self.default_max_tokens - - if provider == "openai": - return await self._generate_openai_with_tools( - prompt, model, temperature, max_tokens, system_prompt, tools, tool_choice - ) - elif provider == "anthropic": - return await self._generate_anthropic_with_tools( - prompt, model, temperature, max_tokens, system_prompt, tools, tool_choice - ) - else: - raise ValueError(f"不支持的AI提供商: {provider}") - + """生成文本""" + prov = self._get_provider(provider) + 
return await prov.generate( + prompt=prompt, + model=model or self.default_model, + temperature=temperature or self.default_temperature, + max_tokens=max_tokens or self.default_max_tokens, + system_prompt=system_prompt or self.default_system_prompt, + tools=tools, + tool_choice=tool_choice, + ) + async def generate_text_stream( self, prompt: str, @@ -273,593 +109,69 @@ class AIService: model: Optional[str] = None, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, ) -> AsyncGenerator[str, None]: - """ - 流式生成文本 - - Args: - prompt: 用户提示词 - provider: AI提供商 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大token数 - system_prompt: 系统提示词 - - Yields: - 生成的文本片段 - """ - provider = provider or self.api_provider - model = model or self.default_model - temperature = temperature or self.default_temperature - max_tokens = max_tokens or self.default_max_tokens - - if provider == "openai": - async for chunk in self._generate_openai_stream( - prompt, model, temperature, max_tokens, system_prompt - ): - yield chunk - elif provider == "anthropic": - async for chunk in self._generate_anthropic_stream( - prompt, model, temperature, max_tokens, system_prompt - ): - yield chunk - else: - raise ValueError(f"不支持的AI提供商: {provider}") - - async def _generate_openai( - self, - prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str] - ) -> str: - """使用OpenAI生成文本(带限流和重试)""" - if not self.openai_http_client: - raise ValueError("OpenAI客户端未初始化,请检查API key配置") - - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - # 使用全局信号量限流 - async with _global_semaphore: - # 请求间隔 - await asyncio.sleep(_request_delay) - - # 重试机制 - max_retries = 3 - for attempt in range(max_retries): - try: - if attempt > 0: - wait_time = min(2 ** attempt, 10) # 指数退避 - logger.warning(f"⚠️ 
OpenAI API调用失败,{wait_time}秒后重试(第{attempt + 1}/{max_retries}次)") - await asyncio.sleep(wait_time) - - logger.info(f"🔵 开始调用OpenAI API(尝试 {attempt + 1}/{max_retries})") - logger.info(f" - 模型: {model}") - logger.info(f" - 温度: {temperature}") - logger.info(f" - 最大tokens: {max_tokens}") - logger.info(f" - Prompt长度: {len(prompt)} 字符") - logger.info(f" - 消息数量: {len(messages)}") - - url = f"{self.openai_base_url}/chat/completions" - headers = { - "Authorization": f"Bearer {self.openai_api_key}", - "Content-Type": "application/json" - } - payload = { - "model": model, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens - } - - logger.debug(f" - 请求URL: {url}") - logger.debug(f" - 请求头: Authorization=Bearer ***") - - response = await self.openai_http_client.post(url, headers=headers, json=payload) - response.raise_for_status() - - data = response.json() - - logger.info(f"✅ OpenAI API调用成功") - logger.info(f" - 响应ID: {data.get('id', 'N/A')}") - logger.info(f" - 选项数量: {len(data.get('choices', []))}") - logger.debug(f" - 完整API响应: {data}") - - if not data.get('choices'): - logger.error("❌ OpenAI返回的choices为空") - raise ValueError("API返回的响应格式错误:choices字段为空") - - choice = data['choices'][0] - message = choice.get('message', {}) - finish_reason = choice.get('finish_reason') - - # DeepSeek R1特殊处理:只使用content(最终答案),忽略reasoning_content(思考过程) - # reasoning_content是AI的思考过程,不是我们需要的JSON结果 - content = message.get('content', '') - - # 检查是否因达到长度限制而截断 - if finish_reason == 'length': - logger.warning(f"⚠️ 响应因达到max_tokens限制而被截断") - logger.warning(f" - 当前max_tokens: {max_tokens}") - logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+)") - - if content: - logger.info(f" - 返回内容长度: {len(content)} 字符") - logger.info(f" - 完成原因: {finish_reason}") - logger.info(f" - 返回内容预览(前200字符): {content[:200]}") - return content - else: - logger.error("❌ AI返回了空内容") - logger.error(f" - 完整响应: {data}") - logger.error(f" - 完成原因: {finish_reason}") - - # 提供更详细的错误信息 - if finish_reason == 'length': - raise 
ValueError(f"AI响应被截断且无有效内容。请增加max_tokens参数(当前: {max_tokens},建议: 2000+)") - else: - raise ValueError(f"AI返回了空内容(finish_reason: {finish_reason}),请检查API配置或稍后重试") - - except httpx.ConnectError as e: - logger.error(f"❌ OpenAI API连接失败 (尝试 {attempt + 1}/{max_retries}): {str(e)}") - if attempt == max_retries - 1: - raise Exception(f"连接失败,已重试{max_retries}次。请检查网络连接或API地址: {str(e)}") - continue - - except httpx.HTTPStatusError as e: - logger.error(f"❌ OpenAI API调用失败 (HTTP {e.response.status_code}, 尝试 {attempt + 1}/{max_retries})") - logger.error(f" - 错误信息: {e.response.text}") - - # 某些错误不需要重试(如401、403) - if e.response.status_code in [401, 403, 404]: - raise Exception(f"API返回错误 ({e.response.status_code}): {e.response.text}") - - if attempt == max_retries - 1: - raise Exception(f"API返回错误 ({e.response.status_code}): {e.response.text}") - continue - - except httpx.TimeoutException as e: - logger.error(f"❌ OpenAI API超时 (尝试 {attempt + 1}/{max_retries})") - if attempt == max_retries - 1: - raise Exception(f"API请求超时,已重试{max_retries}次: {str(e)}") - continue - - except Exception as e: - logger.error(f"❌ OpenAI API调用失败 (尝试 {attempt + 1}/{max_retries})") - logger.error(f" - 错误类型: {type(e).__name__}") - logger.error(f" - 错误信息: {str(e)}") - - if attempt == max_retries - 1: - raise - continue - + """流式生成""" + prov = self._get_provider(provider) + async for chunk in prov.generate_stream( + prompt=prompt, + model=model or self.default_model, + temperature=temperature or self.default_temperature, + max_tokens=max_tokens or self.default_max_tokens, + system_prompt=system_prompt or self.default_system_prompt, + ): + yield chunk - async def _generate_openai_with_tools( + async def call_with_json_retry( self, prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str], - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[str] = None - ) -> Dict[str, Any]: - """使用OpenAI生成文本(支持工具调用,集成MCP适配器)""" - if not self.openai_http_client: - raise 
ValueError("OpenAI客户端未初始化,请检查API key配置") + system_prompt: Optional[str] = None, + max_retries: int = 3, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + provider: Optional[str] = None, + model: Optional[str] = None, + expected_type: Optional[str] = None, + ) -> Union[Dict, List]: + """带重试的 JSON 调用""" + last_response = "" - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - # 如果启用了MCP适配器且有工具,使用适配器处理 - if self.enable_mcp_adapter and self.mcp_adapter and tools: - logger.info(f"🎯 使用MCP适配器处理工具调用") + for attempt in range(1, max_retries + 1): + current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt) - # 生成API标识符 - api_identifier = f"openai_{self.openai_base_url or 'default'}" + result = await self.generate_text( + prompt=current_prompt, + provider=provider, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + ) - # 定义API调用函数 - async def call_api(message: str, tools_param: Optional[List] = None, tool_choice_param: Optional[str] = None): - """实际调用OpenAI API的函数""" - call_messages = messages.copy() - call_messages[-1]["content"] = message - - url = f"{self.openai_base_url}/chat/completions" - headers = { - "Authorization": f"Bearer {self.openai_api_key}", - "Content-Type": "application/json" - } - payload = { - "model": model, - "messages": call_messages, - "temperature": temperature, - "max_tokens": max_tokens - } - - # 只在tools_param不为None时添加工具参数 - if tools_param is not None: - # 清理工具定义,移除$schema字段(某些API不支持) - cleaned_tools = [] - for tool in tools_param: - cleaned_tool = tool.copy() - if "function" in cleaned_tool and "parameters" in cleaned_tool["function"]: - params = cleaned_tool["function"]["parameters"].copy() - # 移除$schema字段 - params.pop("$schema", None) - cleaned_tool["function"]["parameters"] = params - cleaned_tools.append(cleaned_tool) - - 
payload["tools"] = cleaned_tools - if tool_choice_param: - payload["tool_choice"] = tool_choice_param - - response = await self.openai_http_client.post(url, headers=headers, json=payload) - response.raise_for_status() - return response.json() - - # 定义测试函数(检测API是否支持Function Calling) - async def test_fc(): - """测试Function Calling支持""" - test_tools = [{ - "type": "function", - "function": { - "name": "test_function", - "description": "测试函数", - "parameters": {"type": "object", "properties": {}} - } - }] - try: - result = await call_api("测试", tools_param=test_tools, tool_choice_param="none") - return result - except Exception as e: - logger.debug(f"Function Calling测试失败: {e}") - raise + last_response = result.get("content", "") try: - # 使用适配器处理(自动检测、降级、缓存) - result = await self.mcp_adapter.call_with_fallback( - api_identifier=api_identifier, - tools=tools, - user_message=prompt, - call_function=call_api, - test_function=test_fc - ) - - # 转换结果格式 - if result.has_tool_calls: - return { - "tool_calls": result.tool_calls, - "content": result.raw_response, - "finish_reason": "tool_calls" - } - else: - return { - "content": result.raw_response, - "finish_reason": "stop" - } - + data = parse_json(last_response) + if expected_type == "object" and not isinstance(data, dict): + raise ValueError("期望对象") + if expected_type == "array" and not isinstance(data, list): + raise ValueError("期望数组") + return data except Exception as e: - logger.error(f"❌ MCP适配器调用失败: {str(e)}") - # 降级到原始实现 - logger.warning("⚠️ 降级到原始OpenAI调用") + if attempt == max_retries: + raise ValueError(f"JSON 解析失败: {e}") - # 原始实现(无适配器或降级) - try: - logger.info(f"🔵 开始调用OpenAI API(原始模式)") - logger.info(f" - 模型: {model}") - logger.info(f" - 工具数量: {len(tools) if tools else 0}") - - url = f"{self.openai_base_url}/chat/completions" - headers = { - "Authorization": f"Bearer {self.openai_api_key}", - "Content-Type": "application/json" - } - payload = { - "model": model, - "messages": messages, - "temperature": temperature, - 
"max_tokens": max_tokens - } - - # 添加工具参数 - if tools: - payload["tools"] = tools - if tool_choice: - if tool_choice == "required": - payload["tool_choice"] = "required" - elif tool_choice == "auto": - payload["tool_choice"] = "auto" - elif tool_choice == "none": - payload["tool_choice"] = "none" - - response = await self.openai_http_client.post(url, headers=headers, json=payload) - response.raise_for_status() - - data = response.json() - - logger.info(f"✅ OpenAI API调用成功") - logger.debug(f" - 完整API响应: {data}") - - if not data.get('choices'): - logger.error(f"❌ API返回的choices为空") - logger.error(f" - 完整响应: {data}") - logger.error(f" - 响应键: {list(data.keys())}") - raise ValueError(f"API返回的响应格式错误:choices字段为空。完整响应: {data}") - - choice = data['choices'][0] - message = choice.get('message', {}) - finish_reason = choice.get('finish_reason') - - # 检查是否有工具调用 - tool_calls = message.get('tool_calls') - if tool_calls: - logger.info(f"🔧 AI请求调用 {len(tool_calls)} 个工具") - return { - "tool_calls": tool_calls, - "content": message.get('content', ''), - "finish_reason": finish_reason - } - - # 没有工具调用,返回普通内容 - content = message.get('content', '') - if content: - return { - "content": content, - "finish_reason": finish_reason - } - else: - raise ValueError(f"AI返回了空内容(finish_reason: {finish_reason})") - - except httpx.HTTPStatusError as e: - logger.error(f"❌ OpenAI API调用失败 (HTTP {e.response.status_code})") - logger.error(f" - 错误信息: {e.response.text}") - raise Exception(f"API返回错误 ({e.response.status_code}): {e.response.text}") - except Exception as e: - logger.error(f"❌ OpenAI API调用失败: {str(e)}") - raise + raise ValueError("JSON 调用失败") - async def _generate_anthropic_with_tools( - self, - prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str], - tools: Optional[List[Dict[str, Any]]] = None, - tool_choice: Optional[str] = None - ) -> Dict[str, Any]: - """使用Anthropic生成文本(支持工具调用)""" - if not self.anthropic_client: - raise 
ValueError("Anthropic客户端未初始化,请检查API key配置") - - try: - logger.info(f"🔵 开始调用Anthropic API(支持工具调用)") - logger.info(f" - 模型: {model}") - logger.info(f" - 工具数量: {len(tools) if tools else 0}") - - kwargs = { - "model": model, - "max_tokens": max_tokens, - "temperature": temperature, - "messages": [{"role": "user", "content": prompt}] - } - - if system_prompt: - kwargs["system"] = system_prompt - - # 添加工具参数 - if tools: - kwargs["tools"] = tools - if tool_choice == "required": - kwargs["tool_choice"] = {"type": "any"} - elif tool_choice == "auto": - kwargs["tool_choice"] = {"type": "auto"} - - response = await self.anthropic_client.messages.create(**kwargs) - - # 检查是否有工具调用 - tool_calls = [] - content_text = "" - - for block in response.content: - if block.type == "tool_use": - tool_calls.append({ - "id": block.id, - "type": "function", - "function": { - "name": block.name, - "arguments": block.input - } - }) - elif block.type == "text": - content_text += block.text - - if tool_calls: - logger.info(f"🔧 AI请求调用 {len(tool_calls)} 个工具") - return { - "tool_calls": tool_calls, - "content": content_text, - "finish_reason": response.stop_reason - } - - return { - "content": content_text, - "finish_reason": response.stop_reason - } - - except Exception as e: - logger.error(f"❌ Anthropic API调用失败: {str(e)}") - raise + @staticmethod + def _add_json_hint(prompt: str, failed: str, attempt: int) -> str: + return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..." 
+ + @staticmethod + def _clean_json_response(text: str) -> str: + """清洗 JSON 响应""" + return clean_json_response(text) - async def _generate_openai_stream( - self, - prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str] - ) -> AsyncGenerator[str, None]: - """使用OpenAI流式生成文本""" - if not self.openai_http_client: - raise ValueError("OpenAI客户端未初始化,请检查API key配置") - - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - try: - logger.info(f"🔵 开始调用OpenAI流式API(直接HTTP请求)") - logger.info(f" - 模型: {model}") - logger.info(f" - Prompt长度: {len(prompt)} 字符") - logger.info(f" - 最大tokens: {max_tokens}") - - url = f"{self.openai_base_url}/chat/completions" - headers = { - "Authorization": f"Bearer {self.openai_api_key}", - "Content-Type": "application/json" - } - payload = { - "model": model, - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens, - "stream": True - } - - async with self.openai_http_client.stream('POST', url, headers=headers, json=payload) as response: - response.raise_for_status() - logger.info(f"✅ OpenAI流式API连接成功,开始接收数据...") - - chunk_count = 0 - has_content = False - finish_reason = None - - async for line in response.aiter_lines(): - if line.startswith('data: '): - data_str = line[6:] - if data_str.strip() == '[DONE]': - break - - try: - import json - data = json.loads(data_str) - if 'choices' in data and len(data['choices']) > 0: - choice = data['choices'][0] - delta = choice.get('delta', {}) - finish_reason = choice.get('finish_reason') or finish_reason - - # DeepSeek R1特殊处理:只收集content(最终答案),忽略reasoning_content(思考过程) - # reasoning_content是AI的思考过程,不是我们需要的JSON结果 - content = delta.get('content', '') - - if content: - chunk_count += 1 - has_content = True - yield content - except json.JSONDecodeError: - continue - - # 检查是否因长度限制截断 - if finish_reason == 'length': - logger.warning(f"⚠️ 
流式响应因达到max_tokens限制而被截断") - logger.warning(f" - 当前max_tokens: {max_tokens}") - logger.warning(f" - 建议: 增加max_tokens参数(推荐2000+)") - - if not has_content: - logger.warning(f"⚠️ 流式响应未返回任何内容") - logger.warning(f" - 完成原因: {finish_reason}") - - logger.info(f"✅ OpenAI流式生成完成,共接收 {chunk_count} 个chunk,完成原因: {finish_reason}") - - except httpx.TimeoutException as e: - logger.error(f"❌ OpenAI流式API超时") - logger.error(f" - 错误: {str(e)}") - logger.error(f" - 提示: 请检查网络连接或考虑缩短prompt长度") - raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e - except httpx.HTTPStatusError as e: - logger.error(f"❌ OpenAI流式API调用失败 (HTTP {e.response.status_code})") - logger.error(f" - 错误信息: {await e.response.aread()}") - raise - except Exception as e: - logger.error(f"❌ OpenAI流式API调用失败: {str(e)}") - logger.error(f" - 错误类型: {type(e).__name__}") - raise - - async def _generate_anthropic( - self, - prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str] - ) -> str: - """使用Anthropic生成文本""" - if not self.anthropic_client: - raise ValueError("Anthropic客户端未初始化,请检查API key配置") - - try: - response = await self.anthropic_client.messages.create( - model=model, - max_tokens=max_tokens, - temperature=temperature, - system=system_prompt or "", - messages=[{"role": "user", "content": prompt}] - ) - return response.content[0].text - except Exception as e: - logger.error(f"Anthropic API调用失败: {str(e)}") - raise - - async def _generate_anthropic_stream( - self, - prompt: str, - model: str, - temperature: float, - max_tokens: int, - system_prompt: Optional[str] - ) -> AsyncGenerator[str, None]: - """使用Anthropic流式生成文本""" - if not self.anthropic_client: - raise ValueError("Anthropic客户端未初始化,请检查API key配置") - - try: - logger.info(f"🔵 开始调用Anthropic流式API") - logger.info(f" - 模型: {model}") - logger.info(f" - Prompt长度: {len(prompt)} 字符") - logger.info(f" - 最大tokens: {max_tokens}") - - async with self.anthropic_client.messages.stream( - model=model, - max_tokens=max_tokens, - 
temperature=temperature, - system=system_prompt or "", - messages=[{"role": "user", "content": prompt}] - ) as stream: - logger.info(f"✅ Anthropic流式API连接成功,开始接收数据...") - - chunk_count = 0 - async for text in stream.text_stream: - chunk_count += 1 - yield text - - logger.info(f"✅ Anthropic流式生成完成,共接收 {chunk_count} 个chunk") - - except httpx.TimeoutException as e: - logger.error(f"❌ Anthropic流式API超时") - logger.error(f" - 错误: {str(e)}") - raise TimeoutError(f"AI服务超时(180秒),请稍后重试或减少上下文长度") from e - except Exception as e: - logger.error(f"❌ Anthropic流式API调用失败: {str(e)}") - logger.error(f" - 错误类型: {type(e).__name__}") - raise - async def generate_text_with_mcp( self, prompt: str, @@ -870,515 +182,96 @@ class AIService: tool_choice: str = "auto", **kwargs ) -> Dict[str, Any]: - """ - 支持MCP工具的AI文本生成(非流式) - - Args: - prompt: 用户提示词 - user_id: 用户ID,用于获取MCP工具 - db_session: 数据库会话 - enable_mcp: 是否启用MCP增强 - max_tool_rounds: 最大工具调用轮次 - tool_choice: 工具选择策略(auto/required/none) - **kwargs: 其他AI参数(provider, model, temperature等) - - Returns: - { - "content": "AI生成的最终文本", - "tool_calls_made": 2, # 实际调用的工具次数 - "tools_used": ["exa_search", "filesystem_read"], - "finish_reason": "stop", - "mcp_enhanced": True - } - """ + """支持MCP工具的AI文本生成""" from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError - # 初始化返回结果 - result = { - "content": "", - "tool_calls_made": 0, - "tools_used": [], - "finish_reason": "", - "mcp_enhanced": False - } - - # 1. 
获取MCP工具(如果启用) + result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False} tools = None + if enable_mcp: try: - tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db_session - ) + tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session) if tools: - logger.info(f"MCP增强: 加载了 {len(tools)} 个工具") result["mcp_enhanced"] = True - except MCPToolServiceError as e: - logger.error(f"获取MCP工具失败,降级为普通生成: {e}") + except MCPToolServiceError: tools = None - # 2. 工具调用循环 - conversation_history = [ - {"role": "user", "content": prompt} - ] + original_prompt = prompt # 保存原始提示词 for round_num in range(max_tool_rounds): - logger.info(f"MCP工具调用轮次: {round_num + 1}/{max_tool_rounds}") + logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}轮") + logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}") - # 调用AI - ai_response = await self.generate_text( - prompt=conversation_history[-1]["content"], - tools=tools if round_num == 0 else None, # 只在第一轮传递工具 - tool_choice=tool_choice if round_num == 0 else None, - **kwargs - ) + ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs) + logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}") - # 检查是否有工具调用 tool_calls = ai_response.get("tool_calls", []) if not tool_calls: - # AI返回最终内容 - result["content"] = ai_response.get("content", "") + content = ai_response.get("content", "") + result["content"] = content result["finish_reason"] = ai_response.get("finish_reason", "stop") + logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}") + + # 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案 + if not content.strip() and result["tool_calls_made"] > 0: + logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)") + prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。" + tools = None + 
tool_choice = "none" # 强制不使用工具 + continue + break - # 3. 执行工具调用 - logger.info(f"AI请求调用 {len(tool_calls)} 个工具") + logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用") + for idx, tc in enumerate(tool_calls): + logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}") try: - tool_results = await mcp_tool_service.execute_tool_calls( - user_id=user_id, - tool_calls=tool_calls, - db_session=db_session - ) + logger.debug(f" 开始执行工具调用...") + tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session) + logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}") - # 记录使用的工具 - for tool_call in tool_calls: - tool_name = tool_call["function"]["name"] - if tool_name not in result["tools_used"]: - result["tools_used"].append(tool_name) + # 🔍 检查工具结果 + for idx, tr in enumerate(tool_results): + success = tr.get("success", False) + content_preview = tr.get("content", "")[:200] if tr.get("content") else "None" + logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}") + for tc in tool_calls: + name = tc["function"]["name"] + if name not in result["tools_used"]: + result["tools_used"].append(name) result["tool_calls_made"] += len(tool_calls) - # 4. 构建工具上下文 - tool_context = await mcp_tool_service.build_tool_context( - tool_results, - format="markdown" - ) + tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown") + logger.debug(f" 工具上下文长度: {len(tool_context)}") + logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}") - # 5. 
更新对话历史 - conversation_history.append({ - "role": "assistant", - "content": ai_response.get("content", ""), - "tool_calls": tool_calls - }) + # 🔧 改进:在最后一轮时,明确要求AI给出完整答案 + if round_num == max_tool_rounds - 1: + logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案") + prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。" + tool_choice = "none" + else: + prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。" + logger.debug(f" 新prompt长度: {len(prompt)}") - for tool_result in tool_results: - conversation_history.append({ - "role": "tool", - "tool_call_id": tool_result["tool_call_id"], - "content": tool_result["content"] - }) + tools = None # 工具调用后禁用工具列表,避免重复调用 + logger.debug(f" ✅ 工具调用成功,准备下一轮") - # 6. 构建下一轮提示 - next_prompt = ( - f"{prompt}\n\n" - f"{tool_context}\n\n" - f"请基于以上工具查询结果,继续完成任务。" - ) - conversation_history.append({ - "role": "user", - "content": next_prompt - }) - - except Exception as e: - logger.error(f"执行MCP工具失败: {e}", exc_info=True) - # 降级:返回当前AI响应 + except Exception as tool_error: + logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True) + logger.error(f" 错误类型: {type(tool_error).__name__}") + logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}") result["content"] = ai_response.get("content", "") result["finish_reason"] = "tool_error" break - else: - # 达到最大轮次 - logger.info(f"达到MCP最大调用轮次 {max_tool_rounds}") - result["content"] = conversation_history[-1].get("content", "") - result["finish_reason"] = "max_rounds" - return result - - async def generate_text_stream_with_mcp( - self, - prompt: str, - user_id: str, - db_session, - enable_mcp: bool = True, - mcp_planning_prompt: Optional[str] = None, - **kwargs - ) -> AsyncGenerator[str, None]: - """ - 支持MCP工具的AI流式文本生成(两阶段模式) - - Args: - prompt: 用户提示词 - user_id: 用户ID - db_session: 数据库会话 - enable_mcp: 是否启用MCP增强 - mcp_planning_prompt: MCP规划阶段的提示词(可选) - **kwargs: 其他AI参数 - - Yields: - 流式文本chunk - """ - from app.services.mcp_tool_service import mcp_tool_service 
- - # 阶段1: 工具调用阶段(非流式) - enhanced_prompt = prompt - - if enable_mcp: - try: - # 获取MCP工具 - tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db_session - ) - - if tools: - logger.info(f"MCP增强(流式): 加载了 {len(tools)} 个工具") - - # 使用规划提示让AI决定需要查询什么 - if not mcp_planning_prompt: - mcp_planning_prompt = ( - f"任务: {prompt}\n\n" - f"请分析这个任务,决定是否需要查询外部信息。" - f"如果需要,请调用相应的工具获取信息。" - ) - - # 非流式调用获取工具结果 - planning_result = await self.generate_text_with_mcp( - prompt=mcp_planning_prompt, - user_id=user_id, - db_session=db_session, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - **kwargs - ) - - # 如果有工具调用,将结果融入提示 - if planning_result["tool_calls_made"] > 0: - enhanced_prompt = ( - f"{prompt}\n\n" - f"【参考资料】\n" - f"{planning_result.get('content', '')}" - ) - logger.info( - f"MCP工具规划完成,调用了 " - f"{planning_result['tool_calls_made']} 次工具" - ) - - except Exception as e: - logger.error(f"MCP工具规划失败,使用原始提示: {e}") - - # 阶段2: 内容生成阶段(流式) - async for chunk in self.generate_text_stream( - prompt=enhanced_prompt, - **kwargs - ): - yield chunk - - # ========== JSON 统一调用和自动重试 ========== - - @staticmethod - def _clean_json_response(text: str) -> str: - """ - 清洗 AI 返回的 JSON 响应 - - 去除常见的格式问题: - - markdown 代码块标记 (```json ```) - - 前后空白字符 - - 注释文字 - - Args: - text: AI 返回的原始文本 - - Returns: - 清洗后的 JSON 字符串 - """ - if not text: - return text - - # 去除 markdown 代码块标记 - text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE) - text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE) - text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE) - - # 去除前后空白 - text = text.strip() - - # 尝试提取第一个完整的 JSON 对象或数组 - # 查找第一个 { 或 [ - start_idx = -1 - for i, char in enumerate(text): - if char in ('{', '['): - start_idx = i - break - - if start_idx == -1: - return text - - # 从第一个括号开始提取 - text = text[start_idx:] - - # 查找匹配的结束括号 - bracket_stack = [] - end_idx = -1 - in_string = False - escape_next = False - - for i, char in enumerate(text): - if 
escape_next: - escape_next = False - continue - - if char == '\\': - escape_next = True - continue - - if char == '"': - in_string = not in_string - continue - - if in_string: - continue - - if char in ('{', '['): - bracket_stack.append(char) - elif char == '}': - if bracket_stack and bracket_stack[-1] == '{': - bracket_stack.pop() - if not bracket_stack: - end_idx = i + 1 - break - elif char == ']': - if bracket_stack and bracket_stack[-1] == '[': - bracket_stack.pop() - if not bracket_stack: - end_idx = i + 1 - break - - if end_idx > 0: - return text[:end_idx] - - return text - - @staticmethod - def _add_json_format_hint(original_prompt: str, failed_response: str, attempt: int) -> str: - """ - 重试时添加格式纠正提示 - - Args: - original_prompt: 原始提示词 - failed_response: 上次失败的响应(截断显示) - attempt: 当前尝试次数 - - Returns: - 增强后的提示词 - """ - error_preview = failed_response[:300] if failed_response else "无响应" - - return f"""{original_prompt} - -⚠️ 【第 {attempt} 次重试】上一次返回格式错误,请严格遵守以下规则: - -🔴 格式要求(必须严格遵守): -1. 只返回纯 JSON 对象或数组,不要有任何其他文字 -2. 不要使用 ```json``` 或 ``` 包裹 JSON -3. 不要添加任何解释、说明或注释 -4. 确保 JSON 格式完全正确: - - 所有括号必须匹配 {{}} [] - - 所有字符串必须用双引号 "" - - 键值对用冒号分隔 : - - 多个元素用逗号分隔 , - - 不要有多余的逗号 - -❌ 上一次的错误返回示例: -{error_preview}... - -✅ 请现在重新生成正确的 JSON 格式内容。""" - - async def call_with_json_retry( - self, - prompt: str, - system_prompt: Optional[str] = None, - max_retries: int = 3, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - provider: Optional[str] = None, - model: Optional[str] = None, - expected_type: Optional[str] = None # "object" 或 "array" - ) -> Dict[str, Any] | List[Dict[str, Any]]: - """ - 统一的 JSON 调用方法,自动重试和格式修复 - - 这是一个专门用于需要返回 JSON 格式的 AI 调用封装,会自动: - 1. 清洗 AI 返回的内容(去除 markdown 标记等) - 2. 解析 JSON 并验证格式 - 3. 
失败时自动重试,并在提示词中添加纠正指引 - - Args: - prompt: 用户提示词 - system_prompt: 系统提示词(可选) - max_retries: 最大重试次数,默认 3 次 - temperature: 温度参数(可选,使用默认值) - max_tokens: 最大 token 数(可选,使用默认值) - provider: AI 提供商(可选,使用默认值) - model: 模型名称(可选,使用默认值) - expected_type: 期望的 JSON 类型 "object" 或 "array"(可选,用于额外验证) - - Returns: - 解析后的 JSON 对象(dict)或数组(list) - - Raises: - ValueError: 重试次数用尽仍未获得有效 JSON - - Examples: - >>> # 获取 JSON 对象 - >>> result = await ai_service.call_with_json_retry( - ... prompt="生成一个角色", - ... expected_type="object" - ... ) - >>> print(result["name"]) - - >>> # 获取 JSON 数组 - >>> results = await ai_service.call_with_json_retry( - ... prompt="生成3个角色", - ... expected_type="array" - ... ) - >>> print(len(results)) - """ - last_error = None - last_response = "" - - for attempt in range(1, max_retries + 1): - try: - logger.info(f"🔄 JSON 调用尝试 {attempt}/{max_retries}") - - # 第一次使用原始提示词,之后使用增强提示词 - current_prompt = prompt if attempt == 1 else self._add_json_format_hint( - prompt, last_response, attempt - ) - - # 调用 AI 生成内容 - if provider == "openai" and self.openai_client: - response = await self._generate_openai( - prompt=current_prompt, - model=model or self.default_model, - temperature=temperature or self.default_temperature, - max_tokens=max_tokens or self.default_max_tokens, - system_prompt=system_prompt - ) - elif provider == "anthropic" and self.anthropic_client: - response = await self._generate_anthropic( - prompt=current_prompt, - model=model or self.default_model, - temperature=temperature or self.default_temperature, - max_tokens=max_tokens or self.default_max_tokens, - system_prompt=system_prompt - ) - else: - # 使用默认提供商 - if self.api_provider == "openai": - response = await self._generate_openai( - prompt=current_prompt, - model=model or self.default_model, - temperature=temperature or self.default_temperature, - max_tokens=max_tokens or self.default_max_tokens, - system_prompt=system_prompt - ) - else: - response = await self._generate_anthropic( - prompt=current_prompt, - 
model=model or self.default_model, - temperature=temperature or self.default_temperature, - max_tokens=max_tokens or self.default_max_tokens, - system_prompt=system_prompt - ) - - last_response = response - - # 清洗响应内容 - cleaned = self._clean_json_response(response) - logger.debug(f"清洗后的内容: {cleaned[:200]}...") - - # 解析 JSON - try: - data = json.loads(cleaned) - except json.JSONDecodeError as e: - logger.warning(f"⚠️ JSON 解析失败: {e}") - logger.debug(f"原始响应: {response[:500]}") - logger.debug(f"清洗后: {cleaned[:500]}") - raise - - # 可选:验证 JSON 类型 - if expected_type: - if expected_type == "object" and not isinstance(data, dict): - raise ValueError(f"期望 JSON 对象,但得到 {type(data).__name__}") - elif expected_type == "array" and not isinstance(data, list): - raise ValueError(f"期望 JSON 数组,但得到 {type(data).__name__}") - - logger.info(f"✅ JSON 解析成功 (尝试 {attempt}/{max_retries})") - if isinstance(data, dict): - logger.info(f" 返回对象,包含 {len(data)} 个键") - elif isinstance(data, list): - logger.info(f" 返回数组,包含 {len(data)} 个元素") - - return data - - except json.JSONDecodeError as e: - last_error = e - logger.warning(f"⚠️ 第 {attempt} 次尝试失败: JSON 解析错误") - logger.warning(f" 错误位置: {e.msg} at line {e.lineno} column {e.colno}") - - if attempt < max_retries: - logger.info(f" 准备第 {attempt + 1} 次重试...") - continue - else: - logger.error(f"❌ JSON 解析失败,已达到最大重试次数 {max_retries}") - logger.error(f" 最后的响应内容:\n{last_response[:1000]}") - raise ValueError( - f"AI 返回内容无法解析为 JSON,已重试 {max_retries} 次。\n" - f"最后错误: {e}\n" - f"响应预览: {last_response[:200]}..." 
- ) - - except ValueError as e: - last_error = e - logger.warning(f"⚠️ 第 {attempt} 次尝试失败: {e}") - - if attempt < max_retries: - logger.info(f" 准备第 {attempt + 1} 次重试...") - continue - else: - logger.error(f"❌ 验证失败,已达到最大重试次数 {max_retries}") - raise ValueError( - f"AI 返回的 JSON 格式不符合要求,已重试 {max_retries} 次。\n" - f"错误: {e}" - ) - - except Exception as e: - logger.error(f"❌ 第 {attempt} 次调用出现未预期错误: {type(e).__name__}: {e}") - if attempt < max_retries: - logger.info(f" 准备第 {attempt + 1} 次重试...") - last_error = e - continue - else: - raise - - # 理论上不会到达这里,但以防万一 - raise ValueError(f"JSON 调用失败,已重试 {max_retries} 次。最后错误: {last_error}") -# 创建全局AI服务实例 +# 全局实例 ai_service = AIService() @@ -1388,27 +281,16 @@ def create_user_ai_service( api_base_url: str, model_name: str, temperature: float, - max_tokens: int + max_tokens: int, + system_prompt: Optional[str] = None, ) -> AIService: - """ - 根据用户设置创建AI服务实例 - - Args: - api_provider: API提供商 - api_key: API密钥 - api_base_url: API基础URL - model_name: 模型名称 - temperature: 温度参数 - max_tokens: 最大tokens - - Returns: - AIService实例 - """ + """创建用户 AI 服务""" return AIService( api_provider=api_provider, api_key=api_key, api_base_url=api_base_url, default_model=model_name, default_temperature=temperature, - default_max_tokens=max_tokens + default_max_tokens=max_tokens, + default_system_prompt=system_prompt, ) \ No newline at end of file diff --git a/backend/app/services/auto_character_service.py b/backend/app/services/auto_character_service.py index 6511aa8..84bd545 100644 --- a/backend/app/services/auto_character_service.py +++ b/backend/app/services/auto_character_service.py @@ -263,7 +263,7 @@ class AutoCharacterService: user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1 + max_tool_rounds=2 ) content = result.get("content", "") # 使用统一的JSON清洗方法 @@ -362,7 +362,7 @@ class AutoCharacterService: user_id=user_id, db_session=db, enable_mcp=True, - max_tool_rounds=1 + max_tool_rounds=2 ) content = result.get("content", "") # 使用统一的JSON清洗方法 
diff --git a/backend/app/services/json_helper.py b/backend/app/services/json_helper.py new file mode 100644 index 0000000..2662235 --- /dev/null +++ b/backend/app/services/json_helper.py @@ -0,0 +1,149 @@ +"""JSON 处理工具类""" +import json +import re +from typing import Any, Dict, List, Union +from app.logger import get_logger + +logger = get_logger(__name__) + + +def clean_json_response(text: str) -> str: + """清洗 AI 返回的 JSON(改进版 - 流式安全)""" + try: + if not text: + logger.warning("⚠️ clean_json_response: 输入为空") + return text + + original_length = len(text) + logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}") + + # 去除 markdown 代码块 + text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE) + text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE) + text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE) + text = text.strip() + + if len(text) != original_length: + logger.debug(f" 移除markdown后长度: {len(text)}") + + # 尝试直接解析(快速路径) + try: + json.loads(text) + logger.debug(f"✅ 直接解析成功,无需清洗") + return text + except: + pass + + # 找到第一个 { 或 [ + start = -1 + for i, c in enumerate(text): + if c in ('{', '['): + start = i + break + + if start == -1: + logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [") + logger.debug(f" 文本预览: {text[:200]}") + return text + + if start > 0: + logger.debug(f" 跳过前{start}个字符") + text = text[start:] + + # 改进的括号匹配算法(更严格的字符串处理) + stack = [] + i = 0 + end = -1 + + while i < len(text): + c = text[i] + + # 处理字符串(关键:正确处理转义) + if c == '"': + # 计算前面有多少个连续的反斜杠 + num_backslashes = 0 + j = i - 1 + while j >= 0 and text[j] == '\\': + num_backslashes += 1 + j -= 1 + + # 偶数个反斜杠(包括0)表示引号未被转义 + if num_backslashes % 2 == 0: + # 这是字符串边界,跳过整个字符串 + i += 1 + while i < len(text): + if text[i] == '"': + # 再次检查转义 + num_backslashes = 0 + j = i - 1 + while j >= 0 and text[j] == '\\': + num_backslashes += 1 + j -= 1 + if num_backslashes % 2 == 0: + # 字符串结束 + break + i += 1 + i += 1 + continue + + # 处理括号(只有在字符串外部才有效) + if c == '{' or c == '[': + stack.append(c) + elif 
c == '}': + if len(stack) > 0 and stack[-1] == '{': + stack.pop() + if len(stack) == 0: + end = i + 1 + logger.debug(f"✅ 找到JSON结束位置: {end}") + break + else: + logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1] if stack else 'empty'}") + elif c == ']': + if len(stack) > 0 and stack[-1] == '[': + stack.pop() + if len(stack) == 0: + end = i + 1 + logger.debug(f"✅ 找到JSON结束位置: {end}") + break + else: + logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1] if stack else 'empty'}") + + i += 1 + + # 提取结果 + if end > 0: + result = text[:end] + logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}") + else: + result = text + logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)})") + logger.debug(f" 栈状态: {stack}") + + # 验证清洗后的结果 + try: + json.loads(result) + logger.debug(f"✅ 清洗后JSON验证成功") + except json.JSONDecodeError as e: + logger.error(f"❌ 清洗后JSON仍然无效: {e}") + logger.debug(f" 结果预览: {result[:500]}") + logger.debug(f" 结果结尾: ...{result[-200:]}") + + return result + + except Exception as e: + logger.error(f"❌ clean_json_response 出错: {e}") + logger.error(f" 文本长度: {len(text) if text else 0}") + logger.error(f" 文本预览: {text[:200] if text else 'None'}") + raise + + +def parse_json(text: str) -> Union[Dict, List]: + """解析 JSON""" + try: + cleaned = clean_json_response(text) + return json.loads(cleaned) + except Exception as e: + logger.error(f"❌ parse_json 出错: {e}") + logger.error(f" 原始文本长度: {len(text) if text else 0}") + logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}") + raise \ No newline at end of file diff --git a/backend/app/services/mcp_test_service.py b/backend/app/services/mcp_test_service.py index bb88c99..f061bba 100644 --- a/backend/app/services/mcp_test_service.py +++ b/backend/app/services/mcp_test_service.py @@ -175,20 +175,34 @@ class MCPTestService: db=db_session ) - ai_response = await ai_service.generate_text( + # 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下 + # AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks + # 因此这里需要特殊处理 + accumulated_text = "" + 
tool_calls = None + + async for chunk in ai_service.generate_text_stream( prompt=prompts["user"], system_prompt=prompts["system"], tools=openai_tools, tool_choice="required" - ) + ): + # 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls + if isinstance(chunk, dict): + if "tool_calls" in chunk: + tool_calls = chunk["tool_calls"] + if "content" in chunk: + accumulated_text += chunk.get("content", "") + else: + accumulated_text += chunk # 5. 检查AI是否返回工具调用 - if not ai_response.get("tool_calls"): + if not tool_calls: logger.error(f"❌ AI未返回工具调用") return MCPTestResult( success=False, message="❌ AI Function Calling失败", - error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}", + error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}", tools_count=len(tools), suggestions=[ "请确认使用的AI模型支持Function Calling", @@ -198,7 +212,7 @@ class MCPTestService: ) # 6. 解析工具调用 - tool_call = ai_response["tool_calls"][0] + tool_call = tool_calls[0] function = tool_call["function"] tool_name = function["name"] test_arguments = function["arguments"] diff --git a/backend/app/services/mcp_tool_service.py b/backend/app/services/mcp_tool_service.py index aa7cde6..6d8c315 100644 --- a/backend/app/services/mcp_tool_service.py +++ b/backend/app/services/mcp_tool_service.py @@ -386,17 +386,30 @@ class MCPToolService: try: # 解析插件名和工具名 + logger.debug(f"🔍 解析工具名称: {function_name}") if "_" in function_name: plugin_name, tool_name = function_name.split("_", 1) + logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}") else: raise ValueError(f"无效的工具名称格式: {function_name}") # 解析参数 arguments_str = tool_call["function"]["arguments"] + logger.debug(f"🔍 解析参数:") + logger.debug(f" 原始类型: {type(arguments_str)}") + logger.debug(f" 原始内容: {arguments_str}") + if isinstance(arguments_str, str): - arguments = json.loads(arguments_str) + try: + arguments = json.loads(arguments_str) + logger.debug(f" ✅ JSON解析成功: {arguments}") + except json.JSONDecodeError as je: + logger.error(f" ❌ JSON解析失败: 
{je}") + logger.error(f" 原始字符串: '{arguments_str}'") + raise ValueError(f"参数JSON解析失败: {je}") else: arguments = arguments_str + logger.debug(f" 直接使用dict类型参数") logger.info( f"执行工具: {plugin_name}.{tool_name}, " diff --git a/backend/app/services/plot_analyzer.py b/backend/app/services/plot_analyzer.py index f5866c2..89441b8 100644 --- a/backend/app/services/plot_analyzer.py +++ b/backend/app/services/plot_analyzer.py @@ -71,24 +71,15 @@ class PlotAnalyzer: # 调用AI进行分析 # 注意:不指定max_tokens,使用用户在设置中配置的值 logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...") - response = await self.ai_service.generate_text( + accumulated_text = "" + async for chunk in self.ai_service.generate_text_stream( prompt=prompt, temperature=0.3 # 降低温度以获得更稳定的JSON输出 - ) + ): + accumulated_text += chunk - # 🔍 添加调试日志:查看AI返回的原始内容 - # logger.info(f"🔍 AI返回类型: {type(response)}") - # logger.info(f"🔍 AI返回内容(前500字符): {str(response)}") - - # 从返回的字典中提取content字段 - if isinstance(response, dict): - response_text = response.get('content', '') - if not response_text: - logger.error("❌ AI返回的字典中没有content字段或content为空") - return None - else: - # 兼容旧的字符串返回格式 - response_text = response + # 提取内容 + response_text = accumulated_text # 解析JSON结果 analysis_result = self._parse_analysis_response(response_text) diff --git a/backend/app/services/plot_expansion_service.py b/backend/app/services/plot_expansion_service.py index 07eb5e1..3d9320b 100644 --- a/backend/app/services/plot_expansion_service.py +++ b/backend/app/services/plot_expansion_service.py @@ -133,14 +133,16 @@ class PlotExpansionService: # 调用AI生成章节规划 logger.info(f"调用AI生成章节规划...") - ai_response = await self.ai_service.generate_text( + accumulated_text = "" + async for chunk in self.ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model - ) + ): + accumulated_text += chunk # 提取内容 - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + ai_content = accumulated_text # 解析AI响应 chapter_plans = 
self._parse_expansion_response(ai_content, outline.id) @@ -236,14 +238,16 @@ class PlotExpansionService: # 调用AI生成当前批次 logger.info(f"调用AI生成第{batch_num + 1}批...") - ai_response = await self.ai_service.generate_text( + accumulated_text = "" + async for chunk in self.ai_service.generate_text_stream( prompt=prompt, provider=provider, model=model - ) + ): + accumulated_text += chunk # 提取内容 - ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response + ai_content = accumulated_text # 解析AI响应 batch_plans = self._parse_expansion_response(ai_content, outline.id) diff --git a/backend/app/services/prompt_service.py b/backend/app/services/prompt_service.py index 9768a1c..40580dc 100644 --- a/backend/app/services/prompt_service.py +++ b/backend/app/services/prompt_service.py @@ -6,142 +6,6 @@ import json class WritingStyleManager: """写作风格管理器""" - # 预设风格配置 - PRESET_STYLES = { - "natural": { - "name": "自然沉浸 (Natural & Immersive)", - "description": "祛除翻译腔,强调生活质感,像呼吸一样自然的叙事", - "prompt_content": """ -### 核心指令:自然沉浸风格 -请模拟人类作家在放松状态下的写作,通过以下规则消除“AI味”: - -1. **拒绝翻译腔与书面化**: - - 严禁使用“一种...的感觉”、“随着...”、“与此同时”等连接词。 - - 多用短句和“流水句”,模拟人类视线的移动和思维的跳跃。 - - 口语化叙述,但不要滥用语气词,而是通过句子的长短节奏来体现语气。 - -2. **生活化的颗粒度**: - - 描写不要宏大,要聚焦在具体的、微小的生活细节(如:杯子上的水渍、衣服的褶皱)。 - - 允许逻辑上的适度“松散”,不要让每句话都像说明书一样严丝合缝。 - -3. **具体的“展示”**: - - 不要写“他很生气”,要写他“把烟头按灭在还没吃完的米饭里”。 - - 避免使用抽象的形容词(如:巨大的、美丽的、悲伤的),必须用名词和动词来承载画面。 -""" - }, - "classical": { - "name": "古典雅致 (Classical & Elegant)", - "description": "白话文与古典韵味的结合,强调留白与炼字", - "prompt_content": """ -### 核心指令:古典雅致风格 -请模仿民国时期或古典白话小说的笔触,构建端庄且富有余味的叙事: - -1. **炼字与韵律**: - - 尽量使用双音节词或四字短语,但严禁堆砌辞藻。 - - 注重句子的声调韵律,读起来要有金石之声或流水之韵。 - - 适当使用倒装句或定语后置,增加古雅感。 - -2. **克制的修辞**: - - 少用现代的比喻(如“像机器一样”),多用取自自然的比喻(如“如风过林”)。 - - **意在言外**:不要把话说透,留三分余地。写景即是写情,不要将情感直接剖白。 - -3. 
**禁忌**: - - 严禁使用现代科技词汇(除非题材需要)、网络用语或过于西化的句式(如长定语从句)。 - - 避免滥用“之乎者也”,追求的是“神似”而非生硬的半文半白。 -""" - }, - "modern": { - "name": "冷硬现代 (Modern & Hard-boiled)", - "description": "海明威式的冰山理论,节奏极快,零度情感", - "prompt_content": """ -### 核心指令:冷硬现代风格 -请采用“极简主义”和“零度写作”手法,去除所有矫饰: - -1. **冰山理论**: - - **只写动作和对话,完全剔除心理描写和形容词堆砌。** - - 不要告诉读者角色感觉如何,通过角色的反应和环境的冷峻反馈来体现。 - -2. **电影蒙太奇节奏**: - - 句子要短、脆、硬。像手术刀一样切开场景。 - - 段落之间快速切换,不要用过渡句连接,直接跳切。 - -3. **高信息密度**: - - 删除所有废话。如果一个词删掉不影响理解,就删掉它。 - - 多用名词和强动词(Strong Verbs),少用副词(Adverbs)。例如:不要写“他重重地关上门”,写“他摔上了门”。 -""" - }, - "poetic": { - "name": "意识流 (Stream of Consciousness)", - "description": "注重感官通感与内心独白,打破现实与幻想的边界", - "prompt_content": """ -### 核心指令:意识流/诗意风格 -请侧重于主观感受的流动,而非客观事实的记录: - -1. **通感与陌生化**: - - 打通五感(如:听到了颜色的声音,闻到了悲伤的气味)。 - - 使用“陌生化”的语言,把熟悉的事物写得陌生,迫使读者重新审视。 - -2. **情绪的具象化**: - - **绝对禁止**直接出现“开心”、“痛苦”等抽象词汇。 - - 必须寻找“客观对应物”(Objective Correlative),将情绪投射到具体的景物上(如:生锈的铁轨、发霉的橘子)。 - -3. **流动的句式**: - - 句子可以很长,包含多重意象的叠加。 - - 允许思维的非线性跳跃,模拟梦境或深层潜意识的逻辑。 -""" - }, - "concise": { - "name": "白描速写 (Sketch & Concise)", - "description": "只有骨架的叙事,强调绝对的精准和功能性", - "prompt_content": """ -### 核心指令:白描速写风格 -请像速写画家一样,只勾勒线条,不涂抹色彩: - -1. **功能性第一**: - - 每一句话必须推动情节,或者揭示关键信息。 - - 如果一句话只是为了渲染气氛,删掉它。 - -2. **主谓宾结构**: - - 尽量使用简单的主谓宾结构,减少修饰语。 - - 避免复杂的从句和嵌套结构。 - -3. **直击核心**: - - 对话直接进入主题,去除寒暄和废话。 - - 环境描写仅限于对情节有物理影响的物体(如:挡路的石头、藏在桌下的枪)。 -""" - }, - "vivid": { - "name": "感官特写 (Sensory & Vivid)", - "description": "高分辨率的描写,强调材质、光影和微观细节", - "prompt_content": """ -### 核心指令:感官特写风格 -请将镜头推到特写级别(Macro Lens),捕捉常人忽略的细节: - -1. **反套路细节**: - - 不要写大众化的细节(如:蓝天白云),要写具有**独特性**的细节(如:云层边缘那抹像淤青一样的灰紫色)。 - - 关注物体的**质感(Texture)**:粗糙的、粘稠的、冰凉的、颗粒感的。 - -2. **动态捕捉**: - - 不要写静止的画面,要写光影的流变、灰尘的飞舞、肌肉的抽动。 - - 让读者产生生理性的反应(如:痛感、饥饿感、窒息感)。 - -3. 
**禁用词汇**: - - 禁止使用“映入眼帘”、“宛如画卷”等陈词滥调。 - - 必须用具体的动词带动感官描写。 -""" - } - } - - @classmethod - def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]: - """获取预设风格配置""" - return cls.PRESET_STYLES.get(preset_id) - - @classmethod - def get_all_presets(cls) -> Dict[str, Dict[str, str]]: - """获取所有预设风格""" - return cls.PRESET_STYLES - @staticmethod def apply_style_to_prompt(base_prompt: str, style_content: str) -> str: """ @@ -692,9 +556,8 @@ class PromptService: 6. **承上启下**: - 开头自然衔接上一章结尾(但不重复上一章内容) - - 结尾为下一章做好铺垫 -6. **记忆系统使用指南**: +7. **记忆系统使用指南**: - **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展 - **语义相关记忆**:参考相似情节的处理方式 - **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果 @@ -1308,16 +1171,15 @@ class PromptService: - 如果参数名是 snake_case(如 next_thought),就使用 snake_case - 保持与 schema 中定义的完全一致,包括大小写和命名风格""" - # 灵感模式提示词字典 - INSPIRATION_PROMPTS = { - "title": { - "system": """你是一位专业的小说创作顾问。 + # 灵感模式 - 书名生成(系统提示词) + INSPIRATION_TITLE_SYSTEM = """你是一位专业的小说创作顾问。 用户的原始想法:{initial_idea} 请根据用户的想法,生成6个吸引人的书名建议,要求: 1. 紧扣用户的原始想法和核心故事构思 2. 富有创意和吸引力 3. 涵盖不同的风格倾向 +4. 
书名中不要带有"《》"符号 返回JSON格式: {{ @@ -1325,11 +1187,13 @@ class PromptService: "options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"] }} -只返回纯JSON,不要有其他文字。""", - "user": "用户的想法:{initial_idea}\n请生成6个书名建议" - }, - "description": { - "system": """你是一位专业的小说创作顾问。 +只返回纯JSON,不要有其他文字。""" + + # 灵感模式 - 书名生成(用户提示词) + INSPIRATION_TITLE_USER = "用户的想法:{initial_idea}\n请生成6个书名建议" + + # 灵感模式 - 简介生成(系统提示词) + INSPIRATION_DESCRIPTION_SYSTEM = """你是一位专业的小说创作顾问。 用户的原始想法:{initial_idea} 已确定的书名:{title} @@ -1343,11 +1207,13 @@ class PromptService: 返回JSON格式: {{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}} -只返回纯JSON,不要有其他文字,不要换行。""", - "user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项" - }, - "theme": { - "system": """你是一位专业的小说创作顾问。 +只返回纯JSON,不要有其他文字,不要换行。""" + + # 灵感模式 - 简介生成(用户提示词) + INSPIRATION_DESCRIPTION_USER = "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项" + + # 灵感模式 - 主题生成(系统提示词) + INSPIRATION_THEME_SYSTEM = """你是一位专业的小说创作顾问。 用户的原始想法:{initial_idea} 小说信息: - 书名:{title} @@ -1363,11 +1229,13 @@ class PromptService: 返回JSON格式: {{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}} -只返回纯JSON,不要有其他文字,不要换行。""", - "user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项" - }, - "genre": { - "system": """你是一位专业的小说创作顾问。 +只返回纯JSON,不要有其他文字,不要换行。""" + + # 灵感模式 - 主题生成(用户提示词) + INSPIRATION_THEME_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项" + + # 灵感模式 - 类型生成(系统提示词) + INSPIRATION_GENRE_SYSTEM = """你是一位专业的小说创作顾问。 用户的原始想法:{initial_idea} 小说信息: - 书名:{title} @@ -1384,10 +1252,10 @@ class PromptService: 返回JSON格式: {{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}} -只返回紧凑的纯JSON,不要换行,不要有其他文字。""", - "user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签" - } - } +只返回紧凑的纯JSON,不要换行,不要有其他文字。""" + + # 灵感模式 - 类型生成(用户提示词) + INSPIRATION_GENRE_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签" # 灵感模式智能补全提示词 INSPIRATION_QUICK_COMPLETE = 
"""你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。 @@ -1887,7 +1755,26 @@ class PromptService: @classmethod def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]: """获取灵感模式指定步骤的提示词""" - return cls.INSPIRATION_PROMPTS.get(step) + # 根据步骤名称返回对应的system和user提示词 + step_map = { + "title": { + "system": cls.INSPIRATION_TITLE_SYSTEM, + "user": cls.INSPIRATION_TITLE_USER + }, + "description": { + "system": cls.INSPIRATION_DESCRIPTION_SYSTEM, + "user": cls.INSPIRATION_DESCRIPTION_USER + }, + "theme": { + "system": cls.INSPIRATION_THEME_SYSTEM, + "user": cls.INSPIRATION_THEME_USER + }, + "genre": { + "system": cls.INSPIRATION_GENRE_SYSTEM, + "user": cls.INSPIRATION_GENRE_USER + } + } + return step_map.get(step) @classmethod def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]: @@ -1997,17 +1884,12 @@ class PromptService: # 2. 降级到系统默认模板 logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)") - # 特殊处理灵感模式的提示词(存储在INSPIRATION_PROMPTS字典中) + # 特殊处理灵感模式的提示词(直接从类属性获取) if template_key.startswith("INSPIRATION_"): - # 提取步骤名称(如 INSPIRATION_TITLE -> title) - step = template_key.replace("INSPIRATION_", "").lower() - inspiration_prompt = cls.INSPIRATION_PROMPTS.get(step) - if inspiration_prompt: - # 返回JSON格式的提示词 - return json.dumps(inspiration_prompt, ensure_ascii=False) - # 如果是INSPIRATION_QUICK_COMPLETE - if template_key == "INSPIRATION_QUICK_COMPLETE": - return cls.INSPIRATION_QUICK_COMPLETE + # 直接从类属性获取 + template_content = getattr(cls, template_key, None) + if template_content: + return template_content # 其他模板直接从类属性获取 template_content = getattr(cls, template_key, None) @@ -2182,6 +2064,60 @@ class PromptService: "category": "世界构建", "description": "根据世界观自动生成完整的职业体系,包括主职业和副职业", "parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"] + }, + "INSPIRATION_TITLE_SYSTEM": { + "name": "灵感模式-书名生成(系统提示词)", + "category": "灵感模式", + "description": "根据用户的原始想法生成6个书名建议的系统提示词", + "parameters": 
["initial_idea"] + }, + "INSPIRATION_TITLE_USER": { + "name": "灵感模式-书名生成(用户提示词)", + "category": "灵感模式", + "description": "根据用户的原始想法生成6个书名建议的用户提示词", + "parameters": ["initial_idea"] + }, + "INSPIRATION_DESCRIPTION_SYSTEM": { + "name": "灵感模式-简介生成(系统提示词)", + "category": "灵感模式", + "description": "根据用户想法和书名生成6个简介选项的系统提示词", + "parameters": ["initial_idea", "title"] + }, + "INSPIRATION_DESCRIPTION_USER": { + "name": "灵感模式-简介生成(用户提示词)", + "category": "灵感模式", + "description": "根据用户想法和书名生成6个简介选项的用户提示词", + "parameters": ["initial_idea", "title"] + }, + "INSPIRATION_THEME_SYSTEM": { + "name": "灵感模式-主题生成(系统提示词)", + "category": "灵感模式", + "description": "根据书名和简介生成6个深刻的主题选项的系统提示词", + "parameters": ["initial_idea", "title", "description"] + }, + "INSPIRATION_THEME_USER": { + "name": "灵感模式-主题生成(用户提示词)", + "category": "灵感模式", + "description": "根据书名和简介生成6个深刻的主题选项的用户提示词", + "parameters": ["initial_idea", "title", "description"] + }, + "INSPIRATION_GENRE_SYSTEM": { + "name": "灵感模式-类型生成(系统提示词)", + "category": "灵感模式", + "description": "根据小说信息生成6个合适的类型标签的系统提示词", + "parameters": ["initial_idea", "title", "description", "theme"] + }, + "INSPIRATION_GENRE_USER": { + "name": "灵感模式-类型生成(用户提示词)", + "category": "灵感模式", + "description": "根据小说信息生成6个合适的类型标签的用户提示词", + "parameters": ["initial_idea", "title", "description", "theme"] + }, + "INSPIRATION_QUICK_COMPLETE": { + "name": "灵感模式-智能补全", + "category": "灵感模式", + "description": "根据用户提供的部分信息智能补全完整的小说方案", + "parameters": ["existing"] } } diff --git a/backend/app/utils/sse_response.py b/backend/app/utils/sse_response.py index 150c105..66eec40 100644 --- a/backend/app/utils/sse_response.py +++ b/backend/app/utils/sse_response.py @@ -23,11 +23,22 @@ class SSEResponse: Returns: 格式化后的SSE消息字符串 """ - message = "" - if event: - message += f"event: {event}\n" - message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - return message + try: + message = "" + if event: + message += f"event: {event}\n" + message += f"data: {json.dumps(data, 
ensure_ascii=False)}\n\n" + return message + except Exception as e: + logger.error(f"❌ SSE格式化失败: {type(e).__name__}: {e}") + logger.error(f" data类型: {type(data)}") + logger.error(f" data内容: {str(data)[:500]}") + # 返回错误消息而不是崩溃 + error_message = "" + if event: + error_message += f"event: {event}\n" + error_message += f'data: {{"type": "error", "error": "SSE格式化失败: {str(e)}", "code": 500}}\n\n' + return error_message @staticmethod async def send_progress( diff --git a/frontend/src/components/AIProjectGenerator.tsx b/frontend/src/components/AIProjectGenerator.tsx index 7889572..68f7d8f 100644 --- a/frontend/src/components/AIProjectGenerator.tsx +++ b/frontend/src/components/AIProjectGenerator.tsx @@ -190,7 +190,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: (result) => { @@ -236,7 +237,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(33 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: (result) => { @@ -273,7 +275,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(66 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: () => { @@ -336,15 +339,13 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - // 世界观生成占0%-20%,职业生成占20%-30% - const baseProgress = Math.floor(prog / 5); - setProgress(baseProgress); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); - // 检测职业体系生成阶段 - 必须包含"职业体系"才算职业阶段 + // 检测职业体系生成阶段 if (msg.includes('职业体系')) { if (msg.includes('开始') || msg.includes('生成')) { - // 职业开始时,世界观应该已完成 setGenerationSteps(prev => ({ ...prev, worldBuilding: 'completed', @@ -403,8 +404,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - // 角色生成占40%-70% - 
setProgress(40 + Math.floor(prog * 0.3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: (result) => { @@ -437,8 +438,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - // 大纲生成占70%-100% - setProgress(70 + Math.floor(prog * 0.3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: () => { @@ -533,8 +534,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - const baseProgress = Math.floor(prog / 5); - setProgress(baseProgress); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); // 检测职业体系生成阶段 @@ -604,7 +605,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(33 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: (result) => { @@ -647,7 +649,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(66 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: () => { @@ -707,7 +710,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(33 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: (result) => { @@ -746,7 +750,8 @@ export const AIProjectGenerator: React.FC = ({ }, { onProgress: (msg, prog) => { - setProgress(66 + Math.floor(prog / 3)); + // 直接使用后端返回的进度值 + setProgress(prog); setProgressMessage(msg); }, onResult: () => { diff --git a/frontend/src/pages/Inspiration.tsx b/frontend/src/pages/Inspiration.tsx index 66fb101..a306717 100644 --- a/frontend/src/pages/Inspiration.tsx +++ b/frontend/src/pages/Inspiration.tsx @@ -16,6 +16,8 @@ interface Message { options?: string[]; isMultiSelect?: boolean; optionsDisabled?: boolean; // 标记选项是否已禁用 + canRefine?: boolean; // 是否可以优化(用于支持多轮对话) + step?: Step; // 当前步骤(用于反馈) } interface WizardData { @@ -69,6 +71,11 @@ 
const Inspiration: React.FC = () => { const [wizardData, setWizardData] = useState>({}); // 保存用户的原始想法,用于保持上下文一致性 const [initialIdea, setInitialIdea] = useState(''); + + // 反馈相关状态 + const [feedbackValue, setFeedbackValue] = useState(''); + const [showFeedbackInput, setShowFeedbackInput] = useState(null); // 当前显示反馈输入的消息索引 + const [refining, setRefining] = useState(false); // 正在优化选项 // 生成配置 const [generationConfig, setGenerationConfig] = useState(null); @@ -248,6 +255,86 @@ const Inspiration: React.FC = () => { } }; + // 处理用户反馈,重新生成选项 + const handleRefineOptions = async (messageIndex: number, feedback: string) => { + if (!feedback.trim()) { + message.warning('请输入您的反馈意见'); + return; + } + + const targetMessage = messages[messageIndex]; + if (!targetMessage.options || !targetMessage.step) { + return; + } + + setRefining(true); + setShowFeedbackInput(null); + setFeedbackValue(''); + + // 先禁用旧的选项 + setMessages(prev => { + const newMessages = [...prev]; + if (newMessages[messageIndex]) { + newMessages[messageIndex] = { + ...newMessages[messageIndex], + optionsDisabled: true, + canRefine: false, // 同时禁用反馈功能 + }; + } + return newMessages; + }); + + try { + // 添加用户反馈消息 + const feedbackMessage: Message = { + type: 'user', + content: `💭 ${feedback}`, + }; + setMessages(prev => [...prev, feedbackMessage]); + + const step = targetMessage.step as 'title' | 'description' | 'theme' | 'genre'; + + // 构建上下文 + const context: any = { + initial_idea: initialIdea, + title: wizardData.title, + description: wizardData.description, + theme: wizardData.theme, + }; + + // 调用refine接口 + const response = await inspirationApi.refineOptions({ + step, + context, + feedback, + previous_options: targetMessage.options, + }); + + if (response.error) { + message.error(response.error); + return; + } + + // 添加新的AI消息 + const aiMessage: Message = { + type: 'ai', + content: response.prompt || `根据您的反馈,我重新生成了一些${step === 'title' ? '书名' : step === 'description' ? '简介' : step === 'theme' ? 
'主题' : '类型'}选项:`, + options: response.options || [], + isMultiSelect: step === 'genre', + canRefine: true, + step: step, + }; + setMessages(prev => [...prev, aiMessage]); + + message.success('已根据您的反馈重新生成选项'); + } catch (error: any) { + console.error('优化选项失败:', error); + message.error(error.response?.data?.detail || '优化失败,请重试'); + } finally { + setRefining(false); + } + }; + // 步骤顺序 const stepOrder: Step[] = ['idea', 'title', 'description', 'theme', 'genre', 'perspective', 'outline_mode', 'confirm']; @@ -297,7 +384,9 @@ const Inspiration: React.FC = () => { const aiMessage: Message = { type: 'ai', content: response.prompt || '请选择一个书名,或者输入你自己的:', - options: response.options + options: response.options, + canRefine: true, + step: 'title' }; setMessages(prev => [...prev, aiMessage]); setCurrentStep('title'); @@ -497,6 +586,24 @@ const Inspiration: React.FC = () => { updatedData.genre = [input]; } else if (currentStep === 'perspective') { updatedData.narrative_perspective = input; + setWizardData(updatedData); + + // 直接进入大纲模式选择 + const aiMessage: Message = { + type: 'ai', + content: `很好!现在请选择你想要的大纲模式: + +📋 一对一模式:传统模式,一个大纲对应一个章节,适合结构清晰、章节独立的小说。 + +📚 一对多模式:细化模式,一个大纲可以展开成多个章节,适合需要详细展开情节的小说。 + +请选择:`, + options: ['📋 一对一模式', '📚 一对多模式'] + }; + setMessages(prev => [...prev, aiMessage]); + setCurrentStep('outline_mode'); + setLoading(false); + return; } else if (currentStep === 'outline_mode') { // 大纲模式不支持自定义输入 message.warning('请从选项中选择一个大纲模式'); @@ -561,7 +668,16 @@ const Inspiration: React.FC = () => { const currentIndex = stepOrder.indexOf(currentStep); const nextStep = stepOrder[currentIndex + 1]; - if (nextStep === 'description') { + if (nextStep === 'perspective') { + // genre 步骤完成后,进入 perspective + const aiMessage: Message = { + type: 'ai', + content: '很好!接下来,请选择小说的叙事视角:', + options: ['第一人称', '第三人称', '全知视角'] + }; + setMessages(prev => [...prev, aiMessage]); + setCurrentStep('perspective'); + } else if (nextStep === 'description') { const requestData = { step: 'description' 
as const, context: { @@ -587,7 +703,9 @@ const Inspiration: React.FC = () => { const aiMessage: Message = { type: 'ai', content: response.prompt || '请选择一个简介,或者输入你自己的:', - options: response.options + options: response.options, + canRefine: true, + step: 'description' }; setMessages(prev => [...prev, aiMessage]); setCurrentStep('description'); @@ -620,7 +738,9 @@ const Inspiration: React.FC = () => { const aiMessage: Message = { type: 'ai', content: response.prompt || '请选择一个主题,或者输入你自己的:', - options: response.options + options: response.options, + canRefine: true, + step: 'theme' }; setMessages(prev => [...prev, aiMessage]); setCurrentStep('theme'); @@ -656,7 +776,9 @@ const Inspiration: React.FC = () => { type: 'ai', content: response.prompt || '请选择类型标签(可多选):', options: response.options, - isMultiSelect: true + isMultiSelect: true, + canRefine: true, + step: 'genre' }; setMessages(prev => [...prev, aiMessage]); setCurrentStep('genre'); @@ -767,7 +889,7 @@ const Inspiration: React.FC = () => { background: msg.optionsDisabled ? 'var(--color-bg-layout)' : msg.isMultiSelect && selectedOptions.includes(option) - ? 'var(--color-bg-spotlight)' // Need to ensure this exists or use safe fallback + ? 'var(--color-bg-spotlight)' : 'var(--color-bg-container)', opacity: msg.optionsDisabled ? 0.6 : 1, animation: 'floatIn 0.6s ease-out', @@ -802,19 +924,72 @@ const Inspiration: React.FC = () => { 确认选择 ({selectedOptions.length}) )} + + {/* 反馈优化区域 - 新增 */} + {msg.canRefine && !msg.optionsDisabled && !msg.isMultiSelect && ( +
+ {showFeedbackInput === index ? ( + +