refactor:1.重构系统提示词模板导入导出功能
This commit is contained in:
@@ -5,6 +5,7 @@ from sqlalchemy import select, func, delete
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.prompt_template import PromptTemplate
|
from app.models.prompt_template import PromptTemplate
|
||||||
@@ -15,6 +16,8 @@ from app.schemas.prompt_template import (
|
|||||||
PromptTemplateListResponse,
|
PromptTemplateListResponse,
|
||||||
PromptTemplateCategoryResponse,
|
PromptTemplateCategoryResponse,
|
||||||
PromptTemplateExport,
|
PromptTemplateExport,
|
||||||
|
PromptTemplateExportItem,
|
||||||
|
PromptTemplateImportResult,
|
||||||
PromptTemplatePreviewRequest
|
PromptTemplatePreviewRequest
|
||||||
)
|
)
|
||||||
from app.services.prompt_service import PromptService
|
from app.services.prompt_service import PromptService
|
||||||
@@ -22,6 +25,10 @@ from app.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
def calculate_content_hash(content: str) -> str:
|
||||||
|
"""计算模板内容的SHA256哈希值"""
|
||||||
|
return hashlib.sha256(content.strip().encode('utf-8')).hexdigest()[:16]
|
||||||
|
|
||||||
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
|
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
|
||||||
|
|
||||||
|
|
||||||
@@ -352,91 +359,236 @@ async def export_templates(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
导出用户所有自定义模板
|
导出所有提示词模板(包括用户自定义和系统默认)
|
||||||
|
- 用户自定义的提示词标记为 is_customized=true
|
||||||
|
- 系统默认的提示词标记为 is_customized=false
|
||||||
"""
|
"""
|
||||||
# 从认证中间件获取用户ID
|
# 从认证中间件获取用户ID
|
||||||
user_id = getattr(request.state, 'user_id', None)
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(status_code=401, detail="未登录")
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
|
# 1. 查询用户自定义模板
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
select(PromptTemplate).where(PromptTemplate.user_id == user_id)
|
||||||
)
|
)
|
||||||
templates = result.scalars().all()
|
user_templates = result.scalars().all()
|
||||||
|
|
||||||
# 转换为导出格式
|
# 2. 获取所有系统默认模板
|
||||||
export_data = [
|
system_templates = PromptService.get_all_system_templates()
|
||||||
{
|
|
||||||
"template_key": t.template_key,
|
# 3. 构建用户自定义模板的键集合
|
||||||
"template_name": t.template_name,
|
user_template_keys = {t.template_key for t in user_templates}
|
||||||
"template_content": t.template_content,
|
|
||||||
"description": t.description,
|
# 4. 准备导出数据
|
||||||
"category": t.category,
|
export_items = []
|
||||||
"parameters": t.parameters,
|
customized_count = 0
|
||||||
"is_active": t.is_active
|
system_default_count = 0
|
||||||
|
|
||||||
|
# 添加用户自定义的模板
|
||||||
|
for user_template in user_templates:
|
||||||
|
# 获取对应的系统模板用于计算哈希
|
||||||
|
system_template = next(
|
||||||
|
(t for t in system_templates if t["template_key"] == user_template.template_key),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
system_hash = calculate_content_hash(system_template["content"]) if system_template else None
|
||||||
|
|
||||||
|
export_items.append(PromptTemplateExportItem(
|
||||||
|
template_key=user_template.template_key,
|
||||||
|
template_name=user_template.template_name,
|
||||||
|
template_content=user_template.template_content,
|
||||||
|
description=user_template.description,
|
||||||
|
category=user_template.category,
|
||||||
|
parameters=user_template.parameters,
|
||||||
|
is_active=user_template.is_active,
|
||||||
|
is_customized=True,
|
||||||
|
system_content_hash=system_hash
|
||||||
|
))
|
||||||
|
customized_count += 1
|
||||||
|
|
||||||
|
# 添加未自定义的系统默认模板
|
||||||
|
for sys_template in system_templates:
|
||||||
|
if sys_template['template_key'] not in user_template_keys:
|
||||||
|
export_items.append(PromptTemplateExportItem(
|
||||||
|
template_key=sys_template['template_key'],
|
||||||
|
template_name=sys_template['template_name'],
|
||||||
|
template_content=sys_template['content'],
|
||||||
|
description=sys_template['description'],
|
||||||
|
category=sys_template['category'],
|
||||||
|
parameters=json.dumps(sys_template['parameters']),
|
||||||
|
is_active=True,
|
||||||
|
is_customized=False,
|
||||||
|
system_content_hash=calculate_content_hash(sys_template['content'])
|
||||||
|
))
|
||||||
|
system_default_count += 1
|
||||||
|
|
||||||
|
statistics = {
|
||||||
|
"total": len(export_items),
|
||||||
|
"customized": customized_count,
|
||||||
|
"system_default": system_default_count
|
||||||
}
|
}
|
||||||
for t in templates
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"用户 {user_id} 导出了 {len(export_data)} 个模板")
|
logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 "
|
||||||
|
f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})")
|
||||||
|
|
||||||
return PromptTemplateExport(
|
return PromptTemplateExport(
|
||||||
templates=export_data,
|
templates=export_items,
|
||||||
export_time=datetime.now()
|
export_time=datetime.now(),
|
||||||
|
version="2.0",
|
||||||
|
statistics=statistics
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import")
|
@router.post("/import", response_model=PromptTemplateImportResult)
|
||||||
async def import_templates(
|
async def import_templates(
|
||||||
data: PromptTemplateExport,
|
data: PromptTemplateExport,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
导入提示词模板
|
智能导入提示词模板
|
||||||
|
- 如果导入的是系统默认且内容未修改 → 删除自定义记录(使用系统默认)
|
||||||
|
- 如果导入的是系统默认但内容已修改 → 创建自定义记录
|
||||||
|
- 如果导入的是用户自定义 → 创建/更新自定义记录
|
||||||
"""
|
"""
|
||||||
# 从认证中间件获取用户ID
|
# 从认证中间件获取用户ID
|
||||||
user_id = getattr(request.state, 'user_id', None)
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(status_code=401, detail="未登录")
|
raise HTTPException(status_code=401, detail="未登录")
|
||||||
|
|
||||||
imported_count = 0
|
# 获取所有系统默认模板用于比对
|
||||||
updated_count = 0
|
system_templates = PromptService.get_all_system_templates()
|
||||||
|
system_template_dict = {t["template_key"]: t for t in system_templates}
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
kept_system_default = 0 # 保持系统默认
|
||||||
|
created_or_updated = 0 # 创建或更新自定义
|
||||||
|
converted_to_custom = 0 # 从系统默认转为自定义
|
||||||
|
converted_templates = [] # 被转换的模板列表
|
||||||
|
|
||||||
for template_data in data.templates:
|
for template_data in data.templates:
|
||||||
# 查找是否已存在
|
template_key = template_data.template_key
|
||||||
|
is_customized = template_data.is_customized
|
||||||
|
imported_content = template_data.template_content.strip()
|
||||||
|
|
||||||
|
# 查找当前用户是否已有该模板的自定义版本
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(PromptTemplate).where(
|
select(PromptTemplate).where(
|
||||||
PromptTemplate.user_id == user_id,
|
PromptTemplate.user_id == user_id,
|
||||||
PromptTemplate.template_key == template_data.template_key
|
PromptTemplate.template_key == template_key
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# 获取系统默认模板
|
||||||
|
system_template = system_template_dict.get(template_key)
|
||||||
|
|
||||||
|
if not is_customized:
|
||||||
|
# 导入的标记为系统默认
|
||||||
|
if system_template:
|
||||||
|
system_content = system_template["content"].strip()
|
||||||
|
|
||||||
|
# 比对内容是否与系统默认一致
|
||||||
|
if imported_content == system_content:
|
||||||
|
# 内容一致,删除自定义记录(如果有)
|
||||||
if existing:
|
if existing:
|
||||||
# 更新现有模板
|
await db.delete(existing)
|
||||||
for key, value in template_data.model_dump().items():
|
logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)")
|
||||||
setattr(existing, key, value)
|
kept_system_default += 1
|
||||||
updated_count += 1
|
|
||||||
else:
|
else:
|
||||||
# 创建新模板
|
# 内容不一致,用户修改过,创建/更新为自定义
|
||||||
|
if existing:
|
||||||
|
# 更新现有自定义
|
||||||
|
existing.template_name = template_data.template_name
|
||||||
|
existing.template_content = template_data.template_content
|
||||||
|
existing.description = template_data.description
|
||||||
|
existing.category = template_data.category
|
||||||
|
existing.parameters = template_data.parameters
|
||||||
|
existing.is_active = template_data.is_active
|
||||||
|
else:
|
||||||
|
# 创建新自定义
|
||||||
new_template = PromptTemplate(
|
new_template = PromptTemplate(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
**template_data.model_dump()
|
template_key=template_data.template_key,
|
||||||
|
template_name=template_data.template_name,
|
||||||
|
template_content=template_data.template_content,
|
||||||
|
description=template_data.description,
|
||||||
|
category=template_data.category,
|
||||||
|
parameters=template_data.parameters,
|
||||||
|
is_active=template_data.is_active
|
||||||
)
|
)
|
||||||
db.add(new_template)
|
db.add(new_template)
|
||||||
imported_count += 1
|
|
||||||
|
converted_to_custom += 1
|
||||||
|
converted_templates.append({
|
||||||
|
"template_key": template_key,
|
||||||
|
"template_name": template_data.template_name,
|
||||||
|
"reason": "内容与系统默认不一致,已转为自定义"
|
||||||
|
})
|
||||||
|
logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义")
|
||||||
|
else:
|
||||||
|
# 系统中不存在该模板,作为自定义导入
|
||||||
|
if existing:
|
||||||
|
existing.template_name = template_data.template_name
|
||||||
|
existing.template_content = template_data.template_content
|
||||||
|
existing.description = template_data.description
|
||||||
|
existing.category = template_data.category
|
||||||
|
existing.parameters = template_data.parameters
|
||||||
|
existing.is_active = template_data.is_active
|
||||||
|
else:
|
||||||
|
new_template = PromptTemplate(
|
||||||
|
user_id=user_id,
|
||||||
|
template_key=template_data.template_key,
|
||||||
|
template_name=template_data.template_name,
|
||||||
|
template_content=template_data.template_content,
|
||||||
|
description=template_data.description,
|
||||||
|
category=template_data.category,
|
||||||
|
parameters=template_data.parameters,
|
||||||
|
is_active=template_data.is_active
|
||||||
|
)
|
||||||
|
db.add(new_template)
|
||||||
|
created_or_updated += 1
|
||||||
|
else:
|
||||||
|
# 导入的标记为用户自定义,直接创建/更新
|
||||||
|
if existing:
|
||||||
|
existing.template_name = template_data.template_name
|
||||||
|
existing.template_content = template_data.template_content
|
||||||
|
existing.description = template_data.description
|
||||||
|
existing.category = template_data.category
|
||||||
|
existing.parameters = template_data.parameters
|
||||||
|
existing.is_active = template_data.is_active
|
||||||
|
else:
|
||||||
|
new_template = PromptTemplate(
|
||||||
|
user_id=user_id,
|
||||||
|
template_key=template_data.template_key,
|
||||||
|
template_name=template_data.template_name,
|
||||||
|
template_content=template_data.template_content,
|
||||||
|
description=template_data.description,
|
||||||
|
category=template_data.category,
|
||||||
|
parameters=template_data.parameters,
|
||||||
|
is_active=template_data.is_active
|
||||||
|
)
|
||||||
|
db.add(new_template)
|
||||||
|
created_or_updated += 1
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
logger.info(f"用户 {user_id} 导入了 {imported_count} 个新模板,更新了 {updated_count} 个模板")
|
|
||||||
|
|
||||||
return {
|
statistics = {
|
||||||
"message": "导入成功",
|
"total": len(data.templates),
|
||||||
"imported": imported_count,
|
"kept_system_default": kept_system_default,
|
||||||
"updated": updated_count,
|
"created_or_updated": created_or_updated,
|
||||||
"total": imported_count + updated_count
|
"converted_to_custom": converted_to_custom
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(f"用户 {user_id} 导入完成: {statistics}")
|
||||||
|
|
||||||
|
return PromptTemplateImportResult(
|
||||||
|
message="导入成功",
|
||||||
|
statistics=statistics,
|
||||||
|
converted_templates=converted_templates
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{template_key}/preview")
|
@router.post("/{template_key}/preview")
|
||||||
async def preview_template(
|
async def preview_template(
|
||||||
|
|||||||
@@ -55,11 +55,32 @@ class PromptTemplateCategoryResponse(BaseModel):
|
|||||||
templates: List[PromptTemplateResponse]
|
templates: List[PromptTemplateResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateExportItem(BaseModel):
|
||||||
|
"""提示词模板导出项模型"""
|
||||||
|
template_key: str = Field(..., description="模板键名")
|
||||||
|
template_name: str = Field(..., description="模板显示名称")
|
||||||
|
template_content: str = Field(..., description="模板内容")
|
||||||
|
description: Optional[str] = Field(None, description="模板描述")
|
||||||
|
category: Optional[str] = Field(None, description="模板分类")
|
||||||
|
parameters: Optional[str] = Field(None, description="模板参数定义(JSON)")
|
||||||
|
is_active: bool = Field(True, description="是否启用")
|
||||||
|
is_customized: bool = Field(..., description="是否为用户自定义(false=系统默认,true=用户自定义)")
|
||||||
|
system_content_hash: Optional[str] = Field(None, description="系统默认内容的哈希值,用于比对")
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateExport(BaseModel):
|
class PromptTemplateExport(BaseModel):
|
||||||
"""提示词模板导出模型"""
|
"""提示词模板导出模型"""
|
||||||
templates: List[PromptTemplateBase]
|
templates: List[PromptTemplateExportItem]
|
||||||
export_time: datetime
|
export_time: datetime
|
||||||
version: str = "1.0"
|
version: str = "2.0"
|
||||||
|
statistics: Optional[dict] = Field(None, description="导出统计信息")
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateImportResult(BaseModel):
|
||||||
|
"""提示词模板导入结果"""
|
||||||
|
message: str
|
||||||
|
statistics: dict = Field(..., description="导入统计信息")
|
||||||
|
converted_templates: List[dict] = Field(default_factory=list, description="被转换为自定义的模板列表")
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplatePreviewRequest(BaseModel):
|
class PromptTemplatePreviewRequest(BaseModel):
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ interface CategoryGroup {
|
|||||||
|
|
||||||
export default function PromptTemplates() {
|
export default function PromptTemplates() {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
const [modal, contextHolder] = Modal.useModal();
|
||||||
const [categories, setCategories] = useState<CategoryGroup[]>([]);
|
const [categories, setCategories] = useState<CategoryGroup[]>([]);
|
||||||
const [selectedCategory, setSelectedCategory] = useState<string>('0');
|
const [selectedCategory, setSelectedCategory] = useState<string>('0');
|
||||||
const [editingTemplate, setEditingTemplate] = useState<PromptTemplate | null>(null);
|
const [editingTemplate, setEditingTemplate] = useState<PromptTemplate | null>(null);
|
||||||
@@ -124,7 +125,7 @@ export default function PromptTemplates() {
|
|||||||
|
|
||||||
// 重置为系统默认
|
// 重置为系统默认
|
||||||
const handleReset = async (templateKey: string) => {
|
const handleReset = async (templateKey: string) => {
|
||||||
Modal.confirm({
|
modal.confirm({
|
||||||
title: '确认重置',
|
title: '确认重置',
|
||||||
content: '确定要重置为系统默认模板吗?这将覆盖您的自定义内容。',
|
content: '确定要重置为系统默认模板吗?这将覆盖您的自定义内容。',
|
||||||
okText: '确定',
|
okText: '确定',
|
||||||
@@ -161,6 +162,8 @@ export default function PromptTemplates() {
|
|||||||
const handleExport = async () => {
|
const handleExport = async () => {
|
||||||
try {
|
try {
|
||||||
const response = await axios.post('/api/prompt-templates/export');
|
const response = await axios.post('/api/prompt-templates/export');
|
||||||
|
const stats = response.data.statistics;
|
||||||
|
|
||||||
const blob = new Blob([JSON.stringify(response.data, null, 2)], { type: 'application/json' });
|
const blob = new Blob([JSON.stringify(response.data, null, 2)], { type: 'application/json' });
|
||||||
const url = URL.createObjectURL(blob);
|
const url = URL.createObjectURL(blob);
|
||||||
const a = document.createElement('a');
|
const a = document.createElement('a');
|
||||||
@@ -168,7 +171,15 @@ export default function PromptTemplates() {
|
|||||||
a.download = `prompt-templates-${new Date().toISOString().split('T')[0]}.json`;
|
a.download = `prompt-templates-${new Date().toISOString().split('T')[0]}.json`;
|
||||||
a.click();
|
a.click();
|
||||||
URL.revokeObjectURL(url);
|
URL.revokeObjectURL(url);
|
||||||
|
|
||||||
|
if (stats) {
|
||||||
|
message.success(
|
||||||
|
`成功导出 ${stats.total} 个提示词配置(${stats.customized} 个自定义,${stats.system_default} 个系统默认)`,
|
||||||
|
5
|
||||||
|
);
|
||||||
|
} else {
|
||||||
message.success('导出成功');
|
message.success('导出成功');
|
||||||
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
message.error(error.response?.data?.detail || '导出失败');
|
message.error(error.response?.data?.detail || '导出失败');
|
||||||
}
|
}
|
||||||
@@ -179,8 +190,51 @@ export default function PromptTemplates() {
|
|||||||
try {
|
try {
|
||||||
const text = await file.text();
|
const text = await file.text();
|
||||||
const data = JSON.parse(text);
|
const data = JSON.parse(text);
|
||||||
await axios.post('/api/prompt-templates/import', data);
|
const response = await axios.post('/api/prompt-templates/import', data);
|
||||||
message.success('导入成功');
|
|
||||||
|
const result = response.data;
|
||||||
|
const stats = result.statistics;
|
||||||
|
|
||||||
|
// 构建详细的成功消息
|
||||||
|
let successMsg = `导入成功!\n`;
|
||||||
|
if (stats) {
|
||||||
|
successMsg += `• 保持系统默认:${stats.kept_system_default} 个\n`;
|
||||||
|
successMsg += `• 创建/更新自定义:${stats.created_or_updated} 个`;
|
||||||
|
|
||||||
|
if (stats.converted_to_custom > 0) {
|
||||||
|
successMsg += `\n• 检测到修改(已转为自定义):${stats.converted_to_custom} 个`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有被转换的模板,显示详细信息
|
||||||
|
if (result.converted_templates && result.converted_templates.length > 0) {
|
||||||
|
modal.info({
|
||||||
|
title: '导入完成',
|
||||||
|
width: 600,
|
||||||
|
centered: true,
|
||||||
|
content: (
|
||||||
|
<div>
|
||||||
|
<p style={{ marginBottom: 16 }}>{successMsg}</p>
|
||||||
|
{result.converted_templates.length > 0 && (
|
||||||
|
<div>
|
||||||
|
<p style={{ fontWeight: 'bold', marginBottom: 8 }}>以下模板内容与系统默认不一致,已转为自定义:</p>
|
||||||
|
<ul style={{ marginLeft: 20 }}>
|
||||||
|
{result.converted_templates.map((t: any) => (
|
||||||
|
<li key={t.template_key}>
|
||||||
|
{t.template_name} ({t.template_key})
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
),
|
||||||
|
okText: '确定'
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
message.success(successMsg, 5);
|
||||||
|
}
|
||||||
|
|
||||||
loadTemplates();
|
loadTemplates();
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
message.error(error.response?.data?.detail || '导入失败');
|
message.error(error.response?.data?.detail || '导入失败');
|
||||||
@@ -191,6 +245,8 @@ export default function PromptTemplates() {
|
|||||||
const currentTemplates = getCurrentTemplates();
|
const currentTemplates = getCurrentTemplates();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<>
|
||||||
|
{contextHolder}
|
||||||
<div style={{
|
<div style={{
|
||||||
minHeight: '100vh',
|
minHeight: '100vh',
|
||||||
background: 'linear-gradient(180deg, var(--color-bg-base) 0%, #EEF2F3 100%)',
|
background: 'linear-gradient(180deg, var(--color-bg-base) 0%, #EEF2F3 100%)',
|
||||||
@@ -531,5 +587,6 @@ export default function PromptTemplates() {
|
|||||||
</Space>
|
</Space>
|
||||||
</Modal>
|
</Modal>
|
||||||
</div>
|
</div>
|
||||||
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user