From ed1fde42e9dc7329f85426ff05ee4c665555ae19 Mon Sep 17 00:00:00 2001 From: xiamuceer-j Date: Tue, 27 Jan 2026 14:21:50 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E5=B7=A5=E5=9D=8A=E5=85=AC=E5=BC=80=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E9=AA=8C=E8=AF=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/prompt_workshop.py | 80 ++++++++++++++++++++++++------ 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/backend/app/api/prompt_workshop.py b/backend/app/api/prompt_workshop.py index d996f2d..10adc4b 100644 --- a/backend/app/api/prompt_workshop.py +++ b/backend/app/api/prompt_workshop.py @@ -36,6 +36,45 @@ def get_user_identifier(user_id: str) -> str: return f"{INSTANCE_ID}:{user_id}" +def get_user_identifier_from_request(request: Request) -> str: + """ + 从请求中获取用户标识 + - 服务端模式:优先从 Header 获取(来自其他实例的代理请求) + - 客户端模式:从本地用户获取 + """ + if is_workshop_server(): + # 服务端模式:检查是否来自其他实例的代理请求 + instance_id = request.headers.get("X-Instance-ID") + header_user_id = request.headers.get("X-User-ID") + if instance_id and header_user_id: + # 来自其他实例的请求,使用 Header 中的用户标识 + return header_user_id + + # 本地用户 + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="未登录或用户ID缺失") + return get_user_identifier(user_id) + + +def get_optional_user_identifier(request: Request) -> Optional[str]: + """ + 获取可选的用户标识(用于公开API,可以没有用户) + """ + if is_workshop_server(): + # 服务端模式:检查是否来自其他实例的代理请求 + instance_id = request.headers.get("X-Instance-ID") + header_user_id = request.headers.get("X-User-ID") + if instance_id and header_user_id: + return header_user_id + + # 本地用户 + user_id = getattr(request.state, 'user_id', None) + if user_id: + return get_user_identifier(user_id) + return None + + def _item_to_dict(item: PromptWorkshopItem, is_liked: bool = False) -> dict: """将模型转换为字典""" return { @@ -120,9 +159,8 @@ async def get_items( limit: int = 20, db: AsyncSession = Depends(get_db) ): - """获取提示词列表""" - user_id = getattr(request.state, 'user_id', None) - user_identifier = get_user_identifier(user_id) if user_id else None + """获取提示词列表(公开接口,不需要登录)""" + user_identifier = get_optional_user_identifier(request) if is_workshop_server(): # 服务端模式:直接查询本地数据库 @@ -317,8 +355,7 @@ async def toggle_like( db: AsyncSession = Depends(get_db) ): """点赞/取消点赞""" - user_id = get_current_user_id(request) - user_identifier = get_user_identifier(user_id) + user_identifier = get_user_identifier_from_request(request) if is_workshop_server(): # 检查是否已点赞 @@ -393,13 +430,30 @@ async def submit_prompt( db: AsyncSession = Depends(get_db) ): """提交提示词""" - user_id = get_current_user_id(request) - user_identifier = get_user_identifier(user_id) + user_identifier = get_user_identifier_from_request(request) # 获取用户显示名称 - from app.user_manager import user_manager - user = await user_manager.get_user(user_id) - submitter_name = user.display_name if user else "未知用户" + submitter_name = "未知用户" + if is_workshop_server(): + # 服务端模式:检查是否来自代理请求 + instance_id = request.headers.get("X-Instance-ID") + if instance_id: + # 代理请求,从请求数据中获取提交者名称 + submitter_name = data.author_display_name or "未知用户" + else: + # 本地请求 + user_id = getattr(request.state, 'user_id', None) + if user_id: + from app.user_manager import user_manager + user = await user_manager.get_user(user_id) + submitter_name = user.display_name if user else "未知用户" + else: + # 客户端模式:本地用户 + user_id = getattr(request.state, 'user_id', None) + if user_id: + from app.user_manager import user_manager + user = await user_manager.get_user(user_id) + submitter_name = user.display_name if user else "未知用户" if is_workshop_server(): # 直接创建提交记录 @@ -448,8 +502,7 @@ async def get_my_submissions( db: AsyncSession = Depends(get_db) ): """获取我的提交记录""" - user_id = get_current_user_id(request) - user_identifier = get_user_identifier(user_id) + user_identifier = get_user_identifier_from_request(request) if is_workshop_server(): query = select(PromptSubmission).where( @@ -483,8 +536,7 @@ async def withdraw_submission( db: AsyncSession = Depends(get_db) ): """撤回待审核的提交""" - user_id = get_current_user_id(request) - user_identifier = get_user_identifier(user_id) + user_identifier = get_user_identifier_from_request(request) if is_workshop_server(): result = await db.execute(