2025-11-29 22:01:02 +08:00
|
|
|
"""提示词模板管理 API"""
|
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy import select, func, delete
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import json
|
2025-12-30 10:04:41 +08:00
|
|
|
import hashlib
|
2025-11-29 22:01:02 +08:00
|
|
|
|
|
|
|
|
from app.database import get_db
|
|
|
|
|
from app.models.prompt_template import PromptTemplate
|
|
|
|
|
from app.schemas.prompt_template import (
|
|
|
|
|
PromptTemplateCreate,
|
|
|
|
|
PromptTemplateUpdate,
|
|
|
|
|
PromptTemplateResponse,
|
|
|
|
|
PromptTemplateListResponse,
|
|
|
|
|
PromptTemplateCategoryResponse,
|
|
|
|
|
PromptTemplateExport,
|
2025-12-30 10:04:41 +08:00
|
|
|
PromptTemplateExportItem,
|
|
|
|
|
PromptTemplateImportResult,
|
2025-11-29 22:01:02 +08:00
|
|
|
PromptTemplatePreviewRequest
|
|
|
|
|
)
|
|
|
|
|
from app.services.prompt_service import PromptService
|
|
|
|
|
from app.logger import get_logger
|
|
|
|
|
|
|
|
|
|
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
2025-12-30 10:04:41 +08:00
|
|
|
def calculate_content_hash(content: str) -> str:
    """Return a short fingerprint of a template's content.

    Leading and trailing whitespace is stripped, the remainder is encoded as
    UTF-8 and hashed with SHA-256. Only the first 16 hexadecimal characters
    of the digest are returned, not the full hash.

    Args:
        content: Template text to fingerprint.

    Returns:
        A 16-character lowercase hex string.
    """
    normalized = content.strip()
    full_digest = hashlib.sha256(normalized.encode('utf-8')).hexdigest()
    return full_digest[:16]
|
|
|
|
|
|
2025-11-29 22:01:02 +08:00
|
|
|
# Router for the endpoints declared in this module; all paths are mounted under /prompt-templates.
router = APIRouter(prefix="/prompt-templates", tags=["提示词模板管理"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("", response_model=PromptTemplateListResponse)
async def get_all_templates(
    request: Request,
    category: Optional[str] = Query(None, description="按分类筛选"),
    is_active: Optional[bool] = Query(None, description="按启用状态筛选"),
    db: AsyncSession = Depends(get_db)
):
    """
    List the current user's stored prompt templates.

    Args:
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        category: When non-empty, only templates in this category are returned.
        is_active: When not None, only templates with this enabled flag are returned.
        db: Async database session.

    Returns:
        The matching templates ordered by category then template key, their
        count, and the sorted list of non-empty categories found across ALL of
        the user's templates (the category list ignores the two filters).

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Collect the WHERE conditions first, then apply them in one call.
    conditions = [PromptTemplate.user_id == user_id]
    if category:
        conditions.append(PromptTemplate.category == category)
    if is_active is not None:
        conditions.append(PromptTemplate.is_active == is_active)

    listing_stmt = (
        select(PromptTemplate)
        .where(*conditions)
        .order_by(PromptTemplate.category, PromptTemplate.template_key)
    )
    listing = await db.execute(listing_stmt)
    templates = listing.scalars().all()

    # Distinct categories are computed over every template the user owns.
    category_stmt = (
        select(PromptTemplate.category)
        .where(PromptTemplate.user_id == user_id)
        .distinct()
    )
    category_rows = await db.execute(category_stmt)
    non_empty_categories = filter(None, category_rows.scalars().all())

    return PromptTemplateListResponse(
        templates=templates,
        total=len(templates),
        categories=sorted(non_empty_categories)
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/categories", response_model=List[PromptTemplateCategoryResponse])
async def get_templates_by_category(
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Return prompt templates grouped by category, merging custom and system ones.

    The user's stored templates are combined with every system default whose
    key the user has not customized. System defaults are wrapped in transient
    ``PromptTemplate`` objects (never added to the session) that use the
    template key as a placeholder id and the current time as both timestamps.

    Args:
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        One ``PromptTemplateCategoryResponse`` per category, sorted by category
        name, each holding its templates sorted by template key. Templates
        without a category are grouped under "未分类".

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Step 1: the user's own (customized) templates.
    stmt = (
        select(PromptTemplate)
        .where(PromptTemplate.user_id == user_id)
        .order_by(PromptTemplate.category, PromptTemplate.template_key)
    )
    customized = (await db.execute(stmt)).scalars().all()

    # Step 2: the built-in defaults and the set of keys the user has overridden.
    builtin_templates = PromptService.get_all_system_templates()
    customized_keys = {row.template_key for row in customized}

    merged = []
    now = datetime.now()

    # Step 3: user rows come first and are flagged as not-system-default.
    # NOTE(review): this assigns an attribute on session-loaded objects; whether
    # that is ever flushed depends on how get_db manages the session — confirm.
    for row in customized:
        row.is_system_default = False
        merged.append(row)

    # Step 4: add a transient object for every default the user has not overridden.
    for builtin in builtin_templates:
        if builtin['template_key'] in customized_keys:
            continue
        merged.append(PromptTemplate(
            id=builtin['template_key'],  # template key doubles as a temporary id
            user_id=user_id,
            template_key=builtin['template_key'],
            template_name=builtin['template_name'],
            template_content=builtin['content'],
            description=builtin['description'],
            category=builtin['category'],
            parameters=json.dumps(builtin['parameters']),
            is_active=True,
            is_system_default=True,
            created_at=now,
            updated_at=now
        ))

    # Step 5: bucket by category, falling back to the "uncategorized" label.
    grouped = {}
    for item in merged:
        grouped.setdefault(item.category or "未分类", []).append(item)

    # Step 6: emit categories in name order, templates in key order.
    response = []
    for category_name in sorted(grouped):
        members = sorted(grouped[category_name], key=lambda t: t.template_key)
        response.append(PromptTemplateCategoryResponse(
            category=category_name,
            count=len(members),
            templates=members
        ))
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/system-defaults")
async def get_system_defaults(
    request: Request
):
    """
    Return every system default prompt template.

    Args:
        request: Incoming request; ``request.state.user_id`` must be set.

    Returns:
        A dict with ``templates`` (whatever
        ``PromptService.get_all_system_templates`` returns) and ``total``
        (its length).

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    # Only authenticated callers may read the defaults.
    if not getattr(request.state, 'user_id', None):
        raise HTTPException(status_code=401, detail="未登录")

    defaults = PromptService.get_all_system_templates()
    return {"templates": defaults, "total": len(defaults)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/{template_key}", response_model=PromptTemplateResponse)
async def get_template(
    template_key: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch one of the current user's stored prompt templates by key.

    Args:
        template_key: Key of the template to look up.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        The matching ``PromptTemplate`` row.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when the user has no
            stored template with this key.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    lookup = select(PromptTemplate).where(
        PromptTemplate.user_id == user_id,
        PromptTemplate.template_key == template_key
    )
    found = (await db.execute(lookup)).scalar_one_or_none()

    if found:
        return found
    raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("", response_model=PromptTemplateResponse)
async def create_or_update_template(
    data: PromptTemplateCreate,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Create the user's template for ``data.template_key`` or update it (upsert).

    When a row already exists for the user and key, only the fields explicitly
    set in the payload are copied onto it; otherwise a new row is built from
    the full payload.

    Args:
        data: Template payload.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        The created or updated ``PromptTemplate`` after commit and refresh.

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    lookup = select(PromptTemplate).where(
        PromptTemplate.user_id == user_id,
        PromptTemplate.template_key == data.template_key
    )
    template = (await db.execute(lookup)).scalar_one_or_none()

    if not template:
        # No stored row yet: build one from the complete payload.
        template = PromptTemplate(
            user_id=user_id,
            **data.model_dump()
        )
        db.add(template)
        logger.info(f"用户 {user_id} 创建模板 {data.template_key}")
    else:
        # Existing row: apply only the fields the client actually sent.
        changes = data.model_dump(exclude_unset=True)
        for field_name, new_value in changes.items():
            setattr(template, field_name, new_value)
        logger.info(f"用户 {user_id} 更新模板 {data.template_key}")

    await db.commit()
    await db.refresh(template)
    return template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/{template_key}", response_model=PromptTemplateResponse)
async def update_template(
    template_key: str,
    data: PromptTemplateUpdate,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Update an existing stored prompt template of the current user.

    Args:
        template_key: Key of the template to update.
        data: Partial update; only fields explicitly set are applied.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        The updated ``PromptTemplate`` after commit and refresh.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when the user has no
            stored template with this key.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    lookup = select(PromptTemplate).where(
        PromptTemplate.user_id == user_id,
        PromptTemplate.template_key == template_key
    )
    template = (await db.execute(lookup)).scalar_one_or_none()
    if not template:
        raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")

    # Copy across only what the client supplied.
    changes = data.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(template, field_name, changes[field_name])

    await db.commit()
    await db.refresh(template)
    logger.info(f"用户 {user_id} 更新模板 {template_key}")

    return template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/{template_key}")
async def delete_template(
    template_key: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Delete one of the current user's stored (custom) prompt templates.

    Args:
        template_key: Key of the template to delete.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A dict with a confirmation ``message`` and the deleted ``template_key``.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when the user has no
            stored template with this key.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    lookup = select(PromptTemplate).where(
        PromptTemplate.user_id == user_id,
        PromptTemplate.template_key == template_key
    )
    doomed = (await db.execute(lookup)).scalar_one_or_none()
    if not doomed:
        raise HTTPException(status_code=404, detail=f"模板 {template_key} 不存在")

    await db.delete(doomed)
    await db.commit()
    logger.info(f"用户 {user_id} 删除模板 {template_key}")

    return {"message": "模板已删除", "template_key": template_key}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/{template_key}/reset")
async def reset_to_default(
    template_key: str,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Reset a template to the system default by removing the user's custom copy.

    Args:
        template_key: Key of the template to reset.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A dict with a ``message`` (different depending on whether a custom
        copy was actually deleted) and the ``template_key``.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when
            ``PromptService.get_system_template_info`` returns nothing for
            this key.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # A reset only makes sense for keys that have a system default.
    if not PromptService.get_system_template_info(template_key):
        raise HTTPException(status_code=404, detail=f"系统默认模板 {template_key} 不存在")

    lookup = select(PromptTemplate).where(
        PromptTemplate.user_id == user_id,
        PromptTemplate.template_key == template_key
    )
    custom_copy = (await db.execute(lookup)).scalar_one_or_none()

    if not custom_copy:
        # Nothing stored for this user, so the default is already in effect.
        logger.info(f"用户 {user_id} 的模板 {template_key} 本来就是系统默认")
        return {"message": "已是系统默认状态", "template_key": template_key}

    await db.delete(custom_copy)
    await db.commit()
    logger.info(f"用户 {user_id} 删除自定义模板 {template_key},恢复为系统默认")
    return {"message": "已重置为系统默认", "template_key": template_key}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/export", response_model=PromptTemplateExport)
async def export_templates(
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Export all prompt templates (user-customized plus system defaults).

    - The user's stored templates are exported with ``is_customized=True``.
    - System defaults the user has not overridden are exported with
      ``is_customized=False``.

    Every item carries ``system_content_hash``: the content hash of the
    system default with the same key, or None when no such default exists.

    Args:
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A ``PromptTemplateExport`` with the items, the export time, version
        "2.0" and a statistics dict (``total``, ``customized``,
        ``system_default``).

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # 1. The user's customized templates.
    result = await db.execute(
        select(PromptTemplate).where(PromptTemplate.user_id == user_id)
    )
    user_templates = result.scalars().all()

    # 2. All system default templates.
    system_templates = PromptService.get_all_system_templates()

    # 3. Index the system templates by key once, so each user template is
    #    matched with a dict lookup instead of a linear scan. setdefault keeps
    #    the FIRST entry for a key, matching a first-match search.
    system_by_key = {}
    for sys_template in system_templates:
        system_by_key.setdefault(sys_template["template_key"], sys_template)

    user_template_keys = {t.template_key for t in user_templates}

    # 4. Build the export payload.
    export_items = []
    customized_count = 0
    system_default_count = 0

    # User-customized templates first.
    for user_template in user_templates:
        system_template = system_by_key.get(user_template.template_key)
        system_hash = calculate_content_hash(system_template["content"]) if system_template else None

        export_items.append(PromptTemplateExportItem(
            template_key=user_template.template_key,
            template_name=user_template.template_name,
            template_content=user_template.template_content,
            description=user_template.description,
            category=user_template.category,
            parameters=user_template.parameters,
            is_active=user_template.is_active,
            is_customized=True,
            system_content_hash=system_hash
        ))
        customized_count += 1

    # Then the system defaults the user has not overridden.
    for sys_template in system_templates:
        if sys_template['template_key'] not in user_template_keys:
            export_items.append(PromptTemplateExportItem(
                template_key=sys_template['template_key'],
                template_name=sys_template['template_name'],
                template_content=sys_template['content'],
                description=sys_template['description'],
                category=sys_template['category'],
                parameters=json.dumps(sys_template['parameters']),
                is_active=True,
                is_customized=False,
                system_content_hash=calculate_content_hash(sys_template['content'])
            ))
            system_default_count += 1

    statistics = {
        "total": len(export_items),
        "customized": customized_count,
        "system_default": system_default_count
    }

    logger.info(f"用户 {user_id} 导出了 {statistics['total']} 个模板 "
                f"(自定义: {statistics['customized']}, 系统默认: {statistics['system_default']})")

    return PromptTemplateExport(
        templates=export_items,
        export_time=datetime.now(),
        version="2.0",
        statistics=statistics
    )
|
|
|
|
|
|
|
|
|
|
|
2025-12-30 10:04:41 +08:00
|
|
|
def _upsert_custom_template(
    db: AsyncSession,
    user_id,
    existing: Optional[PromptTemplate],
    template_data: PromptTemplateExportItem,
) -> None:
    """Copy an imported item onto the user's custom record, creating the record if absent."""
    if existing:
        existing.template_name = template_data.template_name
        existing.template_content = template_data.template_content
        existing.description = template_data.description
        existing.category = template_data.category
        existing.parameters = template_data.parameters
        existing.is_active = template_data.is_active
        return

    db.add(PromptTemplate(
        user_id=user_id,
        template_key=template_data.template_key,
        template_name=template_data.template_name,
        template_content=template_data.template_content,
        description=template_data.description,
        category=template_data.category,
        parameters=template_data.parameters,
        is_active=template_data.is_active
    ))


@router.post("/import", response_model=PromptTemplateImportResult)
async def import_templates(
    data: PromptTemplateExport,
    request: Request,
    db: AsyncSession = Depends(get_db)
):
    """
    Import prompt templates, deciding per item how to store it.

    - Item marked as system default and content equal (after stripping
      whitespace) to the current system default: any custom record is deleted
      so the system default applies.
    - Item marked as system default but content differs: a custom record is
      created or updated, and the item is reported in ``converted_templates``.
    - Item marked as customized, or no system default exists for its key: a
      custom record is created or updated.

    All changes are committed once, after every item has been processed.

    Args:
        data: Export payload to import.
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        db: Async database session.

    Returns:
        A ``PromptTemplateImportResult`` with a message, a statistics dict
        (``total``, ``kept_system_default``, ``created_or_updated``,
        ``converted_to_custom``) and the list of converted templates.

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # System defaults indexed by key, used for the content comparison.
    system_templates = PromptService.get_all_system_templates()
    system_template_dict = {t["template_key"]: t for t in system_templates}

    kept_system_default = 0   # items left on the system default
    created_or_updated = 0    # custom records created or updated
    converted_to_custom = 0   # "system default" items stored as custom
    converted_templates = []  # details of the converted items

    for template_data in data.templates:
        template_key = template_data.template_key
        is_customized = template_data.is_customized
        imported_content = template_data.template_content.strip()

        # Does the user already have a custom record for this key?
        result = await db.execute(
            select(PromptTemplate).where(
                PromptTemplate.user_id == user_id,
                PromptTemplate.template_key == template_key
            )
        )
        existing = result.scalar_one_or_none()

        system_template = system_template_dict.get(template_key)

        if is_customized or not system_template:
            # Explicitly customized, or no system default to compare against:
            # store it as the user's custom record.
            _upsert_custom_template(db, user_id, existing, template_data)
            created_or_updated += 1
            continue

        # Marked as system default and a default exists: compare contents.
        system_content = system_template["content"].strip()

        if imported_content == system_content:
            # Identical to the default: drop any custom record.
            if existing:
                await db.delete(existing)
                logger.info(f"用户 {user_id} 的模板 {template_key} 恢复为系统默认(删除自定义)")
            kept_system_default += 1
            continue

        # Content was modified: keep it as a custom record and report it.
        _upsert_custom_template(db, user_id, existing, template_data)
        converted_to_custom += 1
        converted_templates.append({
            "template_key": template_key,
            "template_name": template_data.template_name,
            "reason": "内容与系统默认不一致,已转为自定义"
        })
        logger.info(f"用户 {user_id} 的模板 {template_key} 内容已修改,转为自定义")

    await db.commit()

    statistics = {
        "total": len(data.templates),
        "kept_system_default": kept_system_default,
        "created_or_updated": created_or_updated,
        "converted_to_custom": converted_to_custom
    }

    logger.info(f"用户 {user_id} 导入完成: {statistics}")

    return PromptTemplateImportResult(
        message="导入成功",
        statistics=statistics,
        converted_templates=converted_templates
    )
|
2025-11-29 22:01:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/{template_key}/preview")
async def preview_template(
    template_key: str,
    data: PromptTemplatePreviewRequest,
    request: Request
):
    """
    Preview a prompt template by rendering it with the supplied parameters.

    Rendering is delegated to ``PromptService.format_prompt``. Failures are
    reported in the response body rather than raised.

    Args:
        template_key: Path parameter; not used by the rendering itself.
        data: Template content and the parameters to substitute.
        request: Incoming request; ``request.state.user_id`` must be set.

    Returns:
        On success: ``success=True``, the rendered content and the list of
        parameter names supplied. On failure: ``success=False``, an ``error``
        message and ``rendered_content=None``.

    Raises:
        HTTPException: 401 when no user id is present on the request state.
    """
    user_id = getattr(request.state, 'user_id', None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    try:
        rendered = PromptService.format_prompt(
            data.template_content,
            **data.parameters
        )
        return {
            "success": True,
            "rendered_content": rendered,
            "parameters_used": list(data.parameters.keys())
        }
    except KeyError as e:
        # A KeyError is reported as a missing required parameter.
        failure = f"缺少必需的参数: {str(e)}"
    except Exception as e:
        # Any other error is reported as a generic rendering failure.
        failure = f"渲染失败: {str(e)}"

    return {
        "success": False,
        "error": failure,
        "rendered_content": None
    }
|