fix:修复提示词工坊公开接口验证逻辑
This commit is contained in:
@@ -36,6 +36,45 @@ def get_user_identifier(user_id: str) -> str:
|
|||||||
return f"{INSTANCE_ID}:{user_id}"
|
return f"{INSTANCE_ID}:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_identifier_from_request(request: Request) -> str:
|
||||||
|
"""
|
||||||
|
从请求中获取用户标识
|
||||||
|
- 服务端模式:优先从 Header 获取(来自其他实例的代理请求)
|
||||||
|
- 客户端模式:从本地用户获取
|
||||||
|
"""
|
||||||
|
if is_workshop_server():
|
||||||
|
# 服务端模式:检查是否来自其他实例的代理请求
|
||||||
|
instance_id = request.headers.get("X-Instance-ID")
|
||||||
|
header_user_id = request.headers.get("X-User-ID")
|
||||||
|
if instance_id and header_user_id:
|
||||||
|
# 来自其他实例的请求,使用 Header 中的用户标识
|
||||||
|
return header_user_id
|
||||||
|
|
||||||
|
# 本地用户
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="未登录或用户ID缺失")
|
||||||
|
return get_user_identifier(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_optional_user_identifier(request: Request) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取可选的用户标识(用于公开API,可以没有用户)
|
||||||
|
"""
|
||||||
|
if is_workshop_server():
|
||||||
|
# 服务端模式:检查是否来自其他实例的代理请求
|
||||||
|
instance_id = request.headers.get("X-Instance-ID")
|
||||||
|
header_user_id = request.headers.get("X-User-ID")
|
||||||
|
if instance_id and header_user_id:
|
||||||
|
return header_user_id
|
||||||
|
|
||||||
|
# 本地用户
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if user_id:
|
||||||
|
return get_user_identifier(user_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _item_to_dict(item: PromptWorkshopItem, is_liked: bool = False) -> dict:
|
def _item_to_dict(item: PromptWorkshopItem, is_liked: bool = False) -> dict:
|
||||||
"""将模型转换为字典"""
|
"""将模型转换为字典"""
|
||||||
return {
|
return {
|
||||||
@@ -120,9 +159,8 @@ async def get_items(
|
|||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取提示词列表"""
|
"""获取提示词列表(公开接口,不需要登录)"""
|
||||||
user_id = getattr(request.state, 'user_id', None)
|
user_identifier = get_optional_user_identifier(request)
|
||||||
user_identifier = get_user_identifier(user_id) if user_id else None
|
|
||||||
|
|
||||||
if is_workshop_server():
|
if is_workshop_server():
|
||||||
# 服务端模式:直接查询本地数据库
|
# 服务端模式:直接查询本地数据库
|
||||||
@@ -317,8 +355,7 @@ async def toggle_like(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""点赞/取消点赞"""
|
"""点赞/取消点赞"""
|
||||||
user_id = get_current_user_id(request)
|
user_identifier = get_user_identifier_from_request(request)
|
||||||
user_identifier = get_user_identifier(user_id)
|
|
||||||
|
|
||||||
if is_workshop_server():
|
if is_workshop_server():
|
||||||
# 检查是否已点赞
|
# 检查是否已点赞
|
||||||
@@ -393,10 +430,27 @@ async def submit_prompt(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""提交提示词"""
|
"""提交提示词"""
|
||||||
user_id = get_current_user_id(request)
|
user_identifier = get_user_identifier_from_request(request)
|
||||||
user_identifier = get_user_identifier(user_id)
|
|
||||||
|
|
||||||
# 获取用户显示名称
|
# 获取用户显示名称
|
||||||
|
submitter_name = "未知用户"
|
||||||
|
if is_workshop_server():
|
||||||
|
# 服务端模式:检查是否来自代理请求
|
||||||
|
instance_id = request.headers.get("X-Instance-ID")
|
||||||
|
if instance_id:
|
||||||
|
# 代理请求,从请求数据中获取提交者名称
|
||||||
|
submitter_name = data.author_display_name or "未知用户"
|
||||||
|
else:
|
||||||
|
# 本地请求
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if user_id:
|
||||||
|
from app.user_manager import user_manager
|
||||||
|
user = await user_manager.get_user(user_id)
|
||||||
|
submitter_name = user.display_name if user else "未知用户"
|
||||||
|
else:
|
||||||
|
# 客户端模式:本地用户
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if user_id:
|
||||||
from app.user_manager import user_manager
|
from app.user_manager import user_manager
|
||||||
user = await user_manager.get_user(user_id)
|
user = await user_manager.get_user(user_id)
|
||||||
submitter_name = user.display_name if user else "未知用户"
|
submitter_name = user.display_name if user else "未知用户"
|
||||||
@@ -448,8 +502,7 @@ async def get_my_submissions(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取我的提交记录"""
|
"""获取我的提交记录"""
|
||||||
user_id = get_current_user_id(request)
|
user_identifier = get_user_identifier_from_request(request)
|
||||||
user_identifier = get_user_identifier(user_id)
|
|
||||||
|
|
||||||
if is_workshop_server():
|
if is_workshop_server():
|
||||||
query = select(PromptSubmission).where(
|
query = select(PromptSubmission).where(
|
||||||
@@ -483,8 +536,7 @@ async def withdraw_submission(
|
|||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""撤回待审核的提交"""
|
"""撤回待审核的提交"""
|
||||||
user_id = get_current_user_id(request)
|
user_identifier = get_user_identifier_from_request(request)
|
||||||
user_identifier = get_user_identifier(user_id)
|
|
||||||
|
|
||||||
if is_workshop_server():
|
if is_workshop_server():
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
|
|||||||
Reference in New Issue
Block a user