refactor:1.重构系统提示词模板导入导出功能

This commit is contained in:
xiamuceer
2025-12-30 10:04:41 +08:00
parent 05c2981716
commit cb036deb15
3 changed files with 280 additions and 50 deletions
+188 -36
View File
@@ -5,6 +5,7 @@ from sqlalchemy import select, func, delete
from typing import List, Optional from typing import List, Optional
from datetime import datetime from datetime import datetime
import json import json
import hashlib
from app.database import get_db from app.database import get_db
from app.models.prompt_template import PromptTemplate from app.models.prompt_template import PromptTemplate
@@ -15,6 +16,8 @@ from app.schemas.prompt_template import (
PromptTemplateListResponse, PromptTemplateListResponse,
PromptTemplateCategoryResponse, PromptTemplateCategoryResponse,
PromptTemplateExport, PromptTemplateExport,
PromptTemplateExportItem,
PromptTemplateImportResult,
PromptTemplatePreviewRequest PromptTemplatePreviewRequest
) )
from app.services.prompt_service import PromptService from app.services.prompt_service import PromptService
@@ -22,6 +25,10 @@ from app.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def calculate_content_hash(content: str) -> str:
"""计算模板内容的SHA256哈希值"""
return hashlib.sha256(content.strip().encode('utf-8')).hexdigest()[:16]
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"]) router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
@@ -352,91 +359,236 @@ async def export_templates(
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
导出用户所有自定义模板 导出所有提示词模板(包括用户自定义和系统默认)
- 用户自定义的提示词标记为 is_customized=true
- 系统默认的提示词标记为 is_customized=false
""" """
# 从认证中间件获取用户ID # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None) user_id = getattr(request.state, 'user_id', None)
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="未登录") raise HTTPException(status_code=401, detail="未登录")
# 1. 查询用户自定义模板
result = await db.execute( result = await db.execute(
select(PromptTemplate).where(PromptTemplate.user_id == user_id) select(PromptTemplate).where(PromptTemplate.user_id == user_id)
) )
templates = result.scalars().all() user_templates = result.scalars().all()
# 转换为导出格式 # 2. 获取所有系统默认模板
export_data = [ system_templates = PromptService.get_all_system_templates()
{
"template_key": t.template_key, # 3. 构建用户自定义模板的键集合
"template_name": t.template_name, user_template_keys = {t.template_key for t in user_templates}
"template_content": t.template_content,
"description": t.description, # 4. 准备导出数据
"category": t.category, export_items = []
"parameters": t.parameters, customized_count = 0
"is_active": t.is_active system_default_count = 0
# 添加用户自定义的模板
for user_template in user_templates:
# 获取对应的系统模板用于计算哈希
system_template = next(
(t for t in system_templates if t["template_key"] == user_template.template_key),
None
)
system_hash = calculate_content_hash(system_template["content"]) if system_template else None
export_items.append(PromptTemplateExportItem(
template_key=user_template.template_key,
template_name=user_template.template_name,
template_content=user_template.template_content,
description=user_template.description,
category=user_template.category,
parameters=user_template.parameters,
is_active=user_template.is_active,
is_customized=True,
system_content_hash=system_hash
))
customized_count += 1
# 添加未自定义的系统默认模板
for sys_template in system_templates:
if sys_template['template_key'] not in user_template_keys:
export_items.append(PromptTemplateExportItem(
template_key=sys_template['template_key'],
template_name=sys_template['template_name'],
template_content=sys_template['content'],
description=sys_template['description'],
category=sys_template['category'],
parameters=json.dumps(sys_template['parameters']),
is_active=True,
is_customized=False,
system_content_hash=calculate_content_hash(sys_template['content'])
))
system_default_count += 1
statistics = {
"total": len(export_items),
"customized": customized_count,
"system_default": system_default_count
} }
for t in templates
]
logger.info(f"用户 {user_id} 导出了 {len(export_data)} 个模板") logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 "
f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})")
return PromptTemplateExport( return PromptTemplateExport(
templates=export_data, templates=export_items,
export_time=datetime.now() export_time=datetime.now(),
version="2.0",
statistics=statistics
) )
@router.post("/import") @router.post("/import", response_model=PromptTemplateImportResult)
async def import_templates( async def import_templates(
data: PromptTemplateExport, data: PromptTemplateExport,
request: Request, request: Request,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
""" """
导入提示词模板 智能导入提示词模板
- 如果导入的是系统默认且内容未修改 → 删除自定义记录(使用系统默认)
- 如果导入的是系统默认但内容已修改 → 创建自定义记录
- 如果导入的是用户自定义 → 创建/更新自定义记录
""" """
# 从认证中间件获取用户ID # 从认证中间件获取用户ID
user_id = getattr(request.state, 'user_id', None) user_id = getattr(request.state, 'user_id', None)
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="未登录") raise HTTPException(status_code=401, detail="未登录")
imported_count = 0 # 获取所有系统默认模板用于比对
updated_count = 0 system_templates = PromptService.get_all_system_templates()
system_template_dict = {t["template_key"]: t for t in system_templates}
# 统计信息
kept_system_default = 0 # 保持系统默认
created_or_updated = 0 # 创建或更新自定义
converted_to_custom = 0 # 从系统默认转为自定义
converted_templates = [] # 被转换的模板列表
for template_data in data.templates: for template_data in data.templates:
# 查找是否已存在 template_key = template_data.template_key
is_customized = template_data.is_customized
imported_content = template_data.template_content.strip()
# 查找当前用户是否已有该模板的自定义版本
result = await db.execute( result = await db.execute(
select(PromptTemplate).where( select(PromptTemplate).where(
PromptTemplate.user_id == user_id, PromptTemplate.user_id == user_id,
PromptTemplate.template_key == template_data.template_key PromptTemplate.template_key == template_key
) )
) )
existing = result.scalar_one_or_none() existing = result.scalar_one_or_none()
# 获取系统默认模板
system_template = system_template_dict.get(template_key)
if not is_customized:
# 导入的标记为系统默认
if system_template:
system_content = system_template["content"].strip()
# 比对内容是否与系统默认一致
if imported_content == system_content:
# 内容一致,删除自定义记录(如果有)
if existing: if existing:
# 更新现有模板 await db.delete(existing)
for key, value in template_data.model_dump().items(): logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)")
setattr(existing, key, value) kept_system_default += 1
updated_count += 1
else: else:
# 创建新模板 # 内容不一致,用户修改过,创建/更新为自定义
if existing:
# 更新现有自定义
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
# 创建新自定义
new_template = PromptTemplate( new_template = PromptTemplate(
user_id=user_id, user_id=user_id,
**template_data.model_dump() template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
) )
db.add(new_template) db.add(new_template)
imported_count += 1
converted_to_custom += 1
converted_templates.append({
"template_key": template_key,
"template_name": template_data.template_name,
"reason": "内容与系统默认不一致,已转为自定义"
})
logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义")
else:
# 系统中不存在该模板,作为自定义导入
if existing:
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
new_template = PromptTemplate(
user_id=user_id,
template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
)
db.add(new_template)
created_or_updated += 1
else:
# 导入的标记为用户自定义,直接创建/更新
if existing:
existing.template_name = template_data.template_name
existing.template_content = template_data.template_content
existing.description = template_data.description
existing.category = template_data.category
existing.parameters = template_data.parameters
existing.is_active = template_data.is_active
else:
new_template = PromptTemplate(
user_id=user_id,
template_key=template_data.template_key,
template_name=template_data.template_name,
template_content=template_data.template_content,
description=template_data.description,
category=template_data.category,
parameters=template_data.parameters,
is_active=template_data.is_active
)
db.add(new_template)
created_or_updated += 1
await db.commit() await db.commit()
logger.info(f"用户 {user_id} 导入了 {imported_count} 个新模板,更新了 {updated_count} 个模板")
return { statistics = {
"message": "导入成功", "total": len(data.templates),
"imported": imported_count, "kept_system_default": kept_system_default,
"updated": updated_count, "created_or_updated": created_or_updated,
"total": imported_count + updated_count "converted_to_custom": converted_to_custom
} }
logger.info(f"用户 {user_id} 导入完成: {statistics}")
return PromptTemplateImportResult(
message="导入成功",
statistics=statistics,
converted_templates=converted_templates
)
@router.post("/{template_key}/preview") @router.post("/{template_key}/preview")
async def preview_template( async def preview_template(
+23 -2
View File
@@ -55,11 +55,32 @@ class PromptTemplateCategoryResponse(BaseModel):
templates: List[PromptTemplateResponse] templates: List[PromptTemplateResponse]
class PromptTemplateExportItem(BaseModel):
"""提示词模板导出项模型"""
template_key: str = Field(..., description="模板键名")
template_name: str = Field(..., description="模板显示名称")
template_content: str = Field(..., description="模板内容")
description: Optional[str] = Field(None, description="模板描述")
category: Optional[str] = Field(None, description="模板分类")
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
is_active: bool = Field(True, description="是否启用")
is_customized: bool = Field(..., description="是否为用户自定义(false=系统默认,true=用户自定义)")
system_content_hash: Optional[str] = Field(None, description="系统默认内容的哈希值,用于比对")
class PromptTemplateExport(BaseModel): class PromptTemplateExport(BaseModel):
"""提示词模板导出模型""" """提示词模板导出模型"""
templates: List[PromptTemplateBase] templates: List[PromptTemplateExportItem]
export_time: datetime export_time: datetime
version: str = "1.0" version: str = "2.0"
statistics: Optional[dict] = Field(None, description="导出统计信息")
class PromptTemplateImportResult(BaseModel):
"""提示词模板导入结果"""
message: str
statistics: dict = Field(..., description="导入统计信息")
converted_templates: List[dict] = Field(default_factory=list, description="被转换为自定义的模板列表")
class PromptTemplatePreviewRequest(BaseModel): class PromptTemplatePreviewRequest(BaseModel):
+60 -3
View File
@@ -57,6 +57,7 @@ interface CategoryGroup {
export default function PromptTemplates() { export default function PromptTemplates() {
const navigate = useNavigate(); const navigate = useNavigate();
const [modal, contextHolder] = Modal.useModal();
const [categories, setCategories] = useState<CategoryGroup[]>([]); const [categories, setCategories] = useState<CategoryGroup[]>([]);
const [selectedCategory, setSelectedCategory] = useState<string>('0'); const [selectedCategory, setSelectedCategory] = useState<string>('0');
const [editingTemplate, setEditingTemplate] = useState<PromptTemplate | null>(null); const [editingTemplate, setEditingTemplate] = useState<PromptTemplate | null>(null);
@@ -124,7 +125,7 @@ export default function PromptTemplates() {
// 重置为系统默认 // 重置为系统默认
const handleReset = async (templateKey: string) => { const handleReset = async (templateKey: string) => {
Modal.confirm({ modal.confirm({
title: '确认重置', title: '确认重置',
content: '确定要重置为系统默认模板吗?这将覆盖您的自定义内容。', content: '确定要重置为系统默认模板吗?这将覆盖您的自定义内容。',
okText: '确定', okText: '确定',
@@ -161,6 +162,8 @@ export default function PromptTemplates() {
const handleExport = async () => { const handleExport = async () => {
try { try {
const response = await axios.post('/api/prompt-templates/export'); const response = await axios.post('/api/prompt-templates/export');
const stats = response.data.statistics;
const blob = new Blob([JSON.stringify(response.data, null, 2)], { type: 'application/json' }); const blob = new Blob([JSON.stringify(response.data, null, 2)], { type: 'application/json' });
const url = URL.createObjectURL(blob); const url = URL.createObjectURL(blob);
const a = document.createElement('a'); const a = document.createElement('a');
@@ -168,7 +171,15 @@ export default function PromptTemplates() {
a.download = `prompt-templates-${new Date().toISOString().split('T')[0]}.json`; a.download = `prompt-templates-${new Date().toISOString().split('T')[0]}.json`;
a.click(); a.click();
URL.revokeObjectURL(url); URL.revokeObjectURL(url);
if (stats) {
message.success(
`成功导出 ${stats.total} 个提示词配置(${stats.customized} 个自定义,${stats.system_default} 个系统默认)`,
5
);
} else {
message.success('导出成功'); message.success('导出成功');
}
} catch (error: any) { } catch (error: any) {
message.error(error.response?.data?.detail || '导出失败'); message.error(error.response?.data?.detail || '导出失败');
} }
@@ -179,8 +190,51 @@ export default function PromptTemplates() {
try { try {
const text = await file.text(); const text = await file.text();
const data = JSON.parse(text); const data = JSON.parse(text);
await axios.post('/api/prompt-templates/import', data); const response = await axios.post('/api/prompt-templates/import', data);
message.success('导入成功');
const result = response.data;
const stats = result.statistics;
// 构建详细的成功消息
let successMsg = `导入成功!\n`;
if (stats) {
successMsg += `• 保持系统默认:${stats.kept_system_default}\n`;
successMsg += `• 创建/更新自定义:${stats.created_or_updated}`;
if (stats.converted_to_custom > 0) {
successMsg += `\n• 检测到修改(已转为自定义):${stats.converted_to_custom}`;
}
}
// 如果有被转换的模板,显示详细信息
if (result.converted_templates && result.converted_templates.length > 0) {
modal.info({
title: '导入完成',
width: 600,
centered: true,
content: (
<div>
<p style={{ marginBottom: 16 }}>{successMsg}</p>
{result.converted_templates.length > 0 && (
<div>
<p style={{ fontWeight: 'bold', marginBottom: 8 }}></p>
<ul style={{ marginLeft: 20 }}>
{result.converted_templates.map((t: any) => (
<li key={t.template_key}>
{t.template_name} ({t.template_key})
</li>
))}
</ul>
</div>
)}
</div>
),
okText: '确定'
});
} else {
message.success(successMsg, 5);
}
loadTemplates(); loadTemplates();
} catch (error: any) { } catch (error: any) {
message.error(error.response?.data?.detail || '导入失败'); message.error(error.response?.data?.detail || '导入失败');
@@ -191,6 +245,8 @@ export default function PromptTemplates() {
const currentTemplates = getCurrentTemplates(); const currentTemplates = getCurrentTemplates();
return ( return (
<>
{contextHolder}
<div style={{ <div style={{
minHeight: '100vh', minHeight: '100vh',
background: 'linear-gradient(180deg, var(--color-bg-base) 0%, #EEF2F3 100%)', background: 'linear-gradient(180deg, var(--color-bg-base) 0%, #EEF2F3 100%)',
@@ -531,5 +587,6 @@ export default function PromptTemplates() {
</Space> </Space>
</Modal> </Modal>
</div> </div>
</>
); );
} }