From cb036deb1589dd226641e0796594ed62c09402f6 Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Tue, 30 Dec 2025 10:04:41 +0800 Subject: [PATCH] =?UTF-8?q?refactor:1.=E9=87=8D=E6=9E=84=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E8=AF=8D=E6=A8=A1=E6=9D=BF=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=E5=AF=BC=E5=87=BA=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/prompt_templates.py | 236 ++++++++++++++++++++----- backend/app/schemas/prompt_template.py | 25 ++- frontend/src/pages/PromptTemplates.tsx | 69 +++++++- 3 files changed, 280 insertions(+), 50 deletions(-) diff --git a/backend/app/api/prompt_templates.py b/backend/app/api/prompt_templates.py index 28629a9..3defc7f 100644 --- a/backend/app/api/prompt_templates.py +++ b/backend/app/api/prompt_templates.py @@ -5,6 +5,7 @@ from sqlalchemy import select, func, delete from typing import List, Optional from datetime import datetime import json +import hashlib from app.database import get_db from app.models.prompt_template import PromptTemplate @@ -15,6 +16,8 @@ from app.schemas.prompt_template import ( PromptTemplateListResponse, PromptTemplateCategoryResponse, PromptTemplateExport, + PromptTemplateExportItem, + PromptTemplateImportResult, PromptTemplatePreviewRequest ) from app.services.prompt_service import PromptService @@ -22,6 +25,10 @@ from app.logger import get_logger logger = get_logger(__name__) +def calculate_content_hash(content: str) -> str: + """计算模板内容的SHA256哈希值""" + return hashlib.sha256(content.strip().encode('utf-8')).hexdigest()[:16] + router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"]) @@ -352,90 +359,235 @@ async def export_templates( db: AsyncSession = Depends(get_db) ): """ - 导出用户所有自定义模板 + 导出所有提示词模板(包括用户自定义和系统默认) + - 用户自定义的提示词标记为 is_customized=true + - 系统默认的提示词标记为 is_customized=false """ # 从认证中间件获取用户ID user_id = getattr(request.state, 'user_id', None) if not user_id: raise HTTPException(status_code=401, detail="未登录") + # 1. 查询用户自定义模板 result = await db.execute( select(PromptTemplate).where(PromptTemplate.user_id == user_id) ) - templates = result.scalars().all() + user_templates = result.scalars().all() - # 转换为导出格式 - export_data = [ - { - "template_key": t.template_key, - "template_name": t.template_name, - "template_content": t.template_content, - "description": t.description, - "category": t.category, - "parameters": t.parameters, - "is_active": t.is_active - } - for t in templates - ] + # 2. 获取所有系统默认模板 + system_templates = PromptService.get_all_system_templates() - logger.info(f"用户 {user_id} 导出了 {len(export_data)} 个模板") + # 3. 构建用户自定义模板的键集合 + user_template_keys = {t.template_key for t in user_templates} + + # 4. 准备导出数据 + export_items = [] + customized_count = 0 + system_default_count = 0 + + # 添加用户自定义的模板 + for user_template in user_templates: + # 获取对应的系统模板用于计算哈希 + system_template = next( + (t for t in system_templates if t["template_key"] == user_template.template_key), + None + ) + system_hash = calculate_content_hash(system_template["content"]) if system_template else None + + export_items.append(PromptTemplateExportItem( + template_key=user_template.template_key, + template_name=user_template.template_name, + template_content=user_template.template_content, + description=user_template.description, + category=user_template.category, + parameters=user_template.parameters, + is_active=user_template.is_active, + is_customized=True, + system_content_hash=system_hash + )) + customized_count += 1 + + # 添加未自定义的系统默认模板 + for sys_template in system_templates: + if sys_template['template_key'] not in user_template_keys: + export_items.append(PromptTemplateExportItem( + template_key=sys_template['template_key'], + template_name=sys_template['template_name'], + template_content=sys_template['content'], + description=sys_template['description'], + category=sys_template['category'], + parameters=json.dumps(sys_template['parameters']), + is_active=True, + is_customized=False, + system_content_hash=calculate_content_hash(sys_template['content']) + )) + system_default_count += 1 + + statistics = { + "total": len(export_items), + "customized": customized_count, + "system_default": system_default_count + } + + logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 " + f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})") return PromptTemplateExport( - templates=export_data, - export_time=datetime.now() + templates=export_items, + export_time=datetime.now(), + version="2.0", + statistics=statistics ) -@router.post("/import") +@router.post("/import", response_model=PromptTemplateImportResult) async def import_templates( data: PromptTemplateExport, request: Request, db: AsyncSession = Depends(get_db) ): """ - 导入提示词模板 + 智能导入提示词模板 + - 如果导入的是系统默认且内容未修改 → 删除自定义记录(使用系统默认) + - 如果导入的是系统默认但内容已修改 → 创建自定义记录 + - 如果导入的是用户自定义 → 创建/更新自定义记录 """ # 从认证中间件获取用户ID user_id = getattr(request.state, 'user_id', None) if not user_id: raise HTTPException(status_code=401, detail="未登录") - imported_count = 0 - updated_count = 0 + # 获取所有系统默认模板用于比对 + system_templates = PromptService.get_all_system_templates() + system_template_dict = {t["template_key"]: t for t in system_templates} + + # 统计信息 + kept_system_default = 0 # 保持系统默认 + created_or_updated = 0 # 创建或更新自定义 + converted_to_custom = 0 # 从系统默认转为自定义 + converted_templates = [] # 被转换的模板列表 for template_data in data.templates: - # 查找是否已存在 + template_key = template_data.template_key + is_customized = template_data.is_customized + imported_content = template_data.template_content.strip() + + # 查找当前用户是否已有该模板的自定义版本 result = await db.execute( select(PromptTemplate).where( PromptTemplate.user_id == user_id, - PromptTemplate.template_key == template_data.template_key + PromptTemplate.template_key == template_key ) ) existing = result.scalar_one_or_none() - if existing: - # 更新现有模板 - for key, value in template_data.model_dump().items(): - setattr(existing, key, value) - updated_count += 1 + # 获取系统默认模板 + system_template = system_template_dict.get(template_key) + + if not is_customized: + # 导入的标记为系统默认 + if system_template: + system_content = system_template["content"].strip() + + # 比对内容是否与系统默认一致 + if imported_content == system_content: + # 内容一致,删除自定义记录(如果有) + if existing: + await db.delete(existing) + logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)") + kept_system_default += 1 + else: + # 内容不一致,用户修改过,创建/更新为自定义 + if existing: + # 更新现有自定义 + existing.template_name = template_data.template_name + existing.template_content = template_data.template_content + existing.description = template_data.description + existing.category = template_data.category + existing.parameters = template_data.parameters + existing.is_active = template_data.is_active + else: + # 创建新自定义 + new_template = PromptTemplate( + user_id=user_id, + template_key=template_data.template_key, + template_name=template_data.template_name, + template_content=template_data.template_content, + description=template_data.description, + category=template_data.category, + parameters=template_data.parameters, + is_active=template_data.is_active + ) + db.add(new_template) + + converted_to_custom += 1 + converted_templates.append({ + "template_key": template_key, + "template_name": template_data.template_name, + "reason": "内容与系统默认不一致,已转为自定义" + }) + logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义") + else: + # 系统中不存在该模板,作为自定义导入 + if existing: + existing.template_name = template_data.template_name + existing.template_content = template_data.template_content + existing.description = template_data.description + existing.category = template_data.category + existing.parameters = template_data.parameters + existing.is_active = template_data.is_active + else: + new_template = PromptTemplate( + user_id=user_id, + template_key=template_data.template_key, + template_name=template_data.template_name, + template_content=template_data.template_content, + description=template_data.description, + category=template_data.category, + parameters=template_data.parameters, + is_active=template_data.is_active + ) + db.add(new_template) + created_or_updated += 1 else: - # 创建新模板 - new_template = PromptTemplate( - user_id=user_id, - **template_data.model_dump() - ) - db.add(new_template) - imported_count += 1 + # 导入的标记为用户自定义,直接创建/更新 + if existing: + existing.template_name = template_data.template_name + existing.template_content = template_data.template_content + existing.description = template_data.description + existing.category = template_data.category + existing.parameters = template_data.parameters + existing.is_active = template_data.is_active + else: + new_template = PromptTemplate( + user_id=user_id, + template_key=template_data.template_key, + template_name=template_data.template_name, + template_content=template_data.template_content, + description=template_data.description, + category=template_data.category, + parameters=template_data.parameters, + is_active=template_data.is_active + ) + db.add(new_template) + created_or_updated += 1 await db.commit() - logger.info(f"用户 {user_id} 导入了 {imported_count} 个新模板,更新了 {updated_count} 个模板") - return { - "message": "导入成功", - "imported": imported_count, - "updated": updated_count, - "total": imported_count + updated_count + statistics = { + "total": len(data.templates), + "kept_system_default": kept_system_default, + "created_or_updated": created_or_updated, + "converted_to_custom": converted_to_custom } + + logger.info(f"用户 {user_id} 导入完成: {statistics}") + + return PromptTemplateImportResult( + message="导入成功", + statistics=statistics, + converted_templates=converted_templates + ) @router.post("/{template_key}/preview") diff --git a/backend/app/schemas/prompt_template.py b/backend/app/schemas/prompt_template.py index 3115c49..3568ba1 100644 --- a/backend/app/schemas/prompt_template.py +++ b/backend/app/schemas/prompt_template.py @@ -55,11 +55,32 @@ class PromptTemplateCategoryResponse(BaseModel): templates: List[PromptTemplateResponse] +class PromptTemplateExportItem(BaseModel): + """提示词模板导出项模型""" + template_key: str = Field(..., description="模板键名") + template_name: str = Field(..., description="模板显示名称") + template_content: str = Field(..., description="模板内容") + description: Optional[str] = Field(None, description="模板描述") + category: Optional[str] = Field(None, description="模板分类") + parameters: Optional[str] = Field(None, description="模板参数定义(JSON)") + is_active: bool = Field(True, description="是否启用") + is_customized: bool = Field(..., description="是否为用户自定义(false=系统默认,true=用户自定义)") + system_content_hash: Optional[str] = Field(None, description="系统默认内容的哈希值,用于比对") + + class PromptTemplateExport(BaseModel): """提示词模板导出模型""" - templates: List[PromptTemplateBase] + templates: List[PromptTemplateExportItem] export_time: datetime - version: str = "1.0" + version: str = "2.0" + statistics: Optional[dict] = Field(None, description="导出统计信息") + + +class PromptTemplateImportResult(BaseModel): + """提示词模板导入结果""" + message: str + statistics: dict = Field(..., description="导入统计信息") + converted_templates: List[dict] = Field(default_factory=list, description="被转换为自定义的模板列表") class PromptTemplatePreviewRequest(BaseModel): diff --git a/frontend/src/pages/PromptTemplates.tsx b/frontend/src/pages/PromptTemplates.tsx index 2eca8cf..b5e0fdc 100644 --- a/frontend/src/pages/PromptTemplates.tsx +++ b/frontend/src/pages/PromptTemplates.tsx @@ -57,6 +57,7 @@ interface CategoryGroup { export default function PromptTemplates() { const navigate = useNavigate(); + const [modal, contextHolder] = Modal.useModal(); const [categories, setCategories] = useState([]); const [selectedCategory, setSelectedCategory] = useState('0'); const [editingTemplate, setEditingTemplate] = useState(null); @@ -124,7 +125,7 @@ export default function PromptTemplates() { // 重置为系统默认 const handleReset = async (templateKey: string) => { - Modal.confirm({ + modal.confirm({ title: '确认重置', content: '确定要重置为系统默认模板吗?这将覆盖您的自定义内容。', okText: '确定', @@ -161,6 +162,8 @@ export default function PromptTemplates() { const handleExport = async () => { try { const response = await axios.post('/api/prompt-templates/export'); + const stats = response.data.statistics; + const blob = new Blob([JSON.stringify(response.data, null, 2)], { type: 'application/json' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); @@ -168,7 +171,15 @@ export default function PromptTemplates() { a.download = `prompt-templates-${new Date().toISOString().split('T')[0]}.json`; a.click(); URL.revokeObjectURL(url); - message.success('导出成功'); + + if (stats) { + message.success( + `成功导出 ${stats.total} 个提示词配置(${stats.customized} 个自定义,${stats.system_default} 个系统默认)`, + 5 + ); + } else { + message.success('导出成功'); + } } catch (error: any) { message.error(error.response?.data?.detail || '导出失败'); } @@ -179,8 +190,51 @@ export default function PromptTemplates() { try { const text = await file.text(); const data = JSON.parse(text); - await axios.post('/api/prompt-templates/import', data); - message.success('导入成功'); + const response = await axios.post('/api/prompt-templates/import', data); + + const result = response.data; + const stats = result.statistics; + + // 构建详细的成功消息 + let successMsg = `导入成功!\n`; + if (stats) { + successMsg += `• 保持系统默认:${stats.kept_system_default} 个\n`; + successMsg += `• 创建/更新自定义:${stats.created_or_updated} 个`; + + if (stats.converted_to_custom > 0) { + successMsg += `\n• 检测到修改(已转为自定义):${stats.converted_to_custom} 个`; + } + } + + // 如果有被转换的模板,显示详细信息 + if (result.converted_templates && result.converted_templates.length > 0) { + modal.info({ + title: '导入完成', + width: 600, + centered: true, + content: ( +
+

{successMsg}

+ {result.converted_templates.length > 0 && ( +
+

以下模板内容与系统默认不一致,已转为自定义:

+
    + {result.converted_templates.map((t: any) => ( +
  • + {t.template_name} ({t.template_key}) +
  • + ))} +
+
+ )} +
+ ), + okText: '确定' + }); + } else { + message.success(successMsg, 5); + } + loadTemplates(); } catch (error: any) { message.error(error.response?.data?.detail || '导入失败'); @@ -191,7 +245,9 @@ export default function PromptTemplates() { const currentTemplates = getCurrentTemplates(); return ( -
+ {contextHolder} +
-
+
+ ); } \ No newline at end of file