From 4af9a31eba516de2ad975ea4f21691dcfaa2efe2 Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Fri, 24 Apr 2026 10:11:23 +0800 Subject: [PATCH] =?UTF-8?q?update:=20=E4=BF=AE=E5=A4=8D=E5=9F=BA=E4=BA=8E?= =?UTF-8?q?=E9=95=BF=E4=BA=ADmonkeycode=E6=89=AB=E6=8F=8F=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E7=9A=8412=E5=A4=84=E5=AE=89=E5=85=A8=E6=BC=8F?= =?UTF-8?q?=E6=B4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...a1b2c3d4e5f_expand_password_hash_length.py | 42 +++++++ ...b12cd34ef56_expand_password_hash_length.py | 38 ++++++ backend/app/api/admin.py | 15 ++- backend/app/api/auth.py | 33 ++++-- backend/app/api/changelog.py | 12 +- backend/app/api/mcp_plugins.py | 46 ++++++-- backend/app/api/settings.py | 7 +- backend/app/api/users.py | 19 +-- backend/app/api/wizard_stream.py | 40 +++---- backend/app/api/writing_styles.py | 6 +- backend/app/config.py | 3 +- backend/app/main.py | 22 +++- backend/app/middleware/auth_middleware.py | 7 +- backend/app/models/user.py | 4 +- backend/app/security.py | 108 ++++++++++++++++++ backend/app/services/workshop_client.py | 4 +- backend/app/user_password.py | 35 +++++- 17 files changed, 366 insertions(+), 75 deletions(-) create mode 100644 backend/alembic/postgres/versions/20260424_1006_9a1b2c3d4e5f_expand_password_hash_length.py create mode 100644 backend/alembic/sqlite/versions/20260424_1006_ab12cd34ef56_expand_password_hash_length.py create mode 100644 backend/app/security.py diff --git a/backend/alembic/postgres/versions/20260424_1006_9a1b2c3d4e5f_expand_password_hash_length.py b/backend/alembic/postgres/versions/20260424_1006_9a1b2c3d4e5f_expand_password_hash_length.py new file mode 100644 index 0000000..2c7ea5a --- /dev/null +++ b/backend/alembic/postgres/versions/20260424_1006_9a1b2c3d4e5f_expand_password_hash_length.py @@ -0,0 +1,42 @@ +"""expand password hash length + +Revision ID: 9a1b2c3d4e5f +Revises: 6eb27fce64de +Create Date: 2026-04-24 10:06:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9a1b2c3d4e5f' +down_revision: Union[str, None] = '6eb27fce64de' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.alter_column( + 'user_passwords', + 'password_hash', + existing_type=sa.String(length=64), + type_=sa.String(length=255), + existing_nullable=False, + existing_comment='密码哈希(SHA256)', + comment='密码哈希', + ) + + +def downgrade() -> None: + op.alter_column( + 'user_passwords', + 'password_hash', + existing_type=sa.String(length=255), + type_=sa.String(length=64), + existing_nullable=False, + existing_comment='密码哈希', + comment='密码哈希(SHA256)', + ) diff --git a/backend/alembic/sqlite/versions/20260424_1006_ab12cd34ef56_expand_password_hash_length.py b/backend/alembic/sqlite/versions/20260424_1006_ab12cd34ef56_expand_password_hash_length.py new file mode 100644 index 0000000..f67d88c --- /dev/null +++ b/backend/alembic/sqlite/versions/20260424_1006_ab12cd34ef56_expand_password_hash_length.py @@ -0,0 +1,38 @@ +"""expand password hash length + +Revision ID: ab12cd34ef56 +Revises: 6ff45db05863 +Create Date: 2026-04-24 10:06:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ab12cd34ef56' +down_revision: Union[str, None] = '6ff45db05863' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table('user_passwords', schema=None) as batch_op: + batch_op.alter_column( + 'password_hash', + existing_type=sa.String(length=64), + type_=sa.String(length=255), + existing_nullable=False, + ) + + +def downgrade() -> None: + with op.batch_alter_table('user_passwords', schema=None) as batch_op: + batch_op.alter_column( + 'password_hash', + existing_type=sa.String(length=255), + type_=sa.String(length=64), + existing_nullable=False, + ) diff --git a/backend/app/api/admin.py b/backend/app/api/admin.py index b42db08..682c53a 100644 --- a/backend/app/api/admin.py +++ b/backend/app/api/admin.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from typing import Optional, List from datetime import datetime import hashlib +import secrets from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db @@ -46,7 +47,7 @@ class ToggleStatusRequest(BaseModel): class ResetPasswordRequest(BaseModel): """重置密码请求""" - new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码") + new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则生成临时密码") class UserResponse(BaseModel): @@ -308,10 +309,14 @@ async def reset_password( raise HTTPException(status_code=404, detail="用户不存在") # 重置密码 - actual_password = await password_manager.set_password( + generated_password = data.new_password + if not generated_password: + generated_password = secrets.token_urlsafe(12) + + await password_manager.set_password( user_id=user_id, username=target_user.username, - password=data.new_password + password=generated_password ) logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码") @@ -319,7 +324,7 @@ async def reset_password( return { "success": True, "message": "密码重置成功", - "new_password": actual_password + "temporary_password": generated_password if not data.new_password else None } except HTTPException: @@ -385,4 +390,4 @@ async def delete_user( raise except Exception as e: logger.error(f"删除用户失败: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}") diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 0b4dd51..6f93675 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse from pydantic import BaseModel from typing import Optional import hashlib -import random +import secrets import re from datetime import datetime, timedelta, timezone from sqlalchemy import select @@ -21,6 +21,7 @@ from app.database import get_engine from app.models.user import User as UserModel from app.models.settings import Settings as SettingsModel from app.services.email_service import email_service +from app.security import create_session_token # 中国时区 UTC+8 CHINA_TZ = timezone(timedelta(hours=8)) @@ -43,6 +44,7 @@ _state_storage = {} # 邮箱验证码临时存储(生产环境应使用 Redis) _email_verification_storage = {} +MAX_VERIFICATION_ATTEMPTS = 5 EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$") @@ -246,12 +248,14 @@ def _validate_password(password: str): def _set_login_cookies(response: Response, user_id: str): """设置登录 Cookie""" max_age = settings.SESSION_EXPIRE_MINUTES * 60 + session_token = create_session_token(user_id, max_age) response.set_cookie( - key="user_id", - value=user_id, + key="session_token", + value=session_token, max_age=max_age, httponly=True, - samesite="lax" + samesite="lax", + secure=not settings.debug, ) china_now = get_china_now() @@ -263,12 +267,13 @@ def _set_login_cookies(response: Response, user_id: str): value=str(expire_at), max_age=max_age, httponly=False, - samesite="lax" + samesite="lax", + secure=not settings.debug, ) def _generate_verification_code() -> str: - return f"{random.randint(0, 999999):06d}" + return f"{secrets.randbelow(1000000):06d}" def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]: @@ -456,6 +461,7 @@ async def send_email_verification_code(request: EmailSendCodeRequest): "code": code, "expires_at": expires_at, "last_sent_at": now, + "attempts": 0, } logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}") @@ -493,6 +499,10 @@ async def email_register(request: EmailRegisterRequest, response: Response): raise HTTPException(status_code=400, detail="验证码已过期,请重新发送") if cached["code"] != code: + cached["attempts"] = cached.get("attempts", 0) + 1 + if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS: + _email_verification_storage.pop(_get_verification_storage_key("register", email), None) + raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送") raise HTTPException(status_code=400, detail="验证码错误") existing_user = await _find_user_by_email(email) @@ -541,6 +551,10 @@ async def email_login(request: EmailLoginRequest, response: Response): raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送") if cached["code"] != code: + cached["attempts"] = cached.get("attempts", 0) + 1 + if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS: + _email_verification_storage.pop(storage_key, None) + raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送") raise HTTPException(status_code=400, detail="登录验证码错误") _email_verification_storage.pop(storage_key, None) @@ -588,6 +602,10 @@ async def email_reset_password(request: EmailResetPasswordRequest): raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送") if cached["code"] != code: + cached["attempts"] = cached.get("attempts", 0) + 1 + if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS: + _email_verification_storage.pop(storage_key, None) + raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送") raise HTTPException(status_code=400, detail="重置密码验证码错误") await password_manager.set_password(user.user_id, email, request.new_password) @@ -757,6 +775,7 @@ async def logout(request: Request, response: Response): logger.info(f"🚪 [退出] 用户 {user_id} 退出登录") response.delete_cookie("user_id") + response.delete_cookie("session_token") response.delete_cookie("session_expire_at") return {"message": "退出登录成功"} @@ -782,8 +801,6 @@ async def get_password_status(request: Request): username = await password_manager.get_username(user.user_id) default_password = None - if has_password and not has_custom: - default_password = f"{user.username}@666" return PasswordStatusResponse( has_password=has_password, diff --git a/backend/app/api/changelog.py b/backend/app/api/changelog.py index 86adb61..632bb25 100644 --- a/backend/app/api/changelog.py +++ b/backend/app/api/changelog.py @@ -2,7 +2,7 @@ 更新日志API 提供GitHub提交历史的缓存和代理服务 """ -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Request, Depends from typing import List, Optional import httpx from datetime import datetime, timedelta @@ -13,6 +13,12 @@ logger = logging.getLogger(__name__) router = APIRouter() + +def require_login(request: Request): + if not hasattr(request.state, "user") or not request.state.user: + raise HTTPException(status_code=401, detail="需要登录") + return request.state.user + # GitHub API配置 GITHUB_API_BASE = "https://api.github.com" REPO_OWNER = "xiamuceer-j" @@ -173,7 +179,7 @@ async def get_changelog( @router.post("/changelog/refresh") -async def refresh_changelog(): +async def refresh_changelog(user=Depends(require_login)): """ 刷新更新日志缓存 @@ -230,4 +236,4 @@ async def refresh_changelog(): raise HTTPException( status_code=500, detail=f"刷新缓存失败: {str(e)}" - ) \ No newline at end of file + ) diff --git a/backend/app/api/mcp_plugins.py b/backend/app/api/mcp_plugins.py index 3767ff8..d732389 100644 --- a/backend/app/api/mcp_plugins.py +++ b/backend/app/api/mcp_plugins.py @@ -24,11 +24,22 @@ from app.user_manager import User from app.mcp import mcp_client, MCPPluginConfig, PluginStatus from app.services.mcp_test_service import mcp_test_service from app.logger import get_logger +from app.security import validate_public_http_url logger = get_logger(__name__) router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"]) +HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"} + + +def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]: + if plugin_type in HTTP_PLUGIN_TYPES: + if not server_url: + raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url") + return validate_public_http_url(server_url) + return server_url + def require_login(request: Request) -> User: """依赖:要求用户已登录""" @@ -53,7 +64,8 @@ async def _register_plugin_background( try: logger.info(f"后台注册MCP插件: {plugin_name}") - if plugin_type in ["http", "streamable_http", "sse"] and server_url: + if plugin_type in HTTP_PLUGIN_TYPES and server_url: + server_url = _validate_mcp_server_url(plugin_type, server_url) success = await mcp_client.register(MCPPluginConfig( user_id=user_id, plugin_name=plugin_name, @@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool: Returns: 是否注册成功 """ - if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url: + if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url: + server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url) return await mcp_client.register(MCPPluginConfig( user_id=user_id, plugin_name=plugin.plugin_name, - url=plugin.server_url, + url=server_url, plugin_type=plugin.plugin_type, headers=plugin.headers, timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0 @@ -187,6 +200,10 @@ async def create_plugin( # 创建插件数据 plugin_data = data.model_dump() + plugin_data["server_url"] = _validate_mcp_server_url( + plugin_data.get("plugin_type", "http"), + plugin_data.get("server_url") + ) # 如果没有提供display_name,使用plugin_name作为默认值 if not plugin_data.get("display_name"): @@ -278,12 +295,9 @@ async def create_plugin_simple( "sort_order": 0 } - if server_type in ["http", "streamable_http", "sse"]: - plugin_data["server_url"] = server_config.get("url") + if server_type in HTTP_PLUGIN_TYPES: + plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url")) plugin_data["headers"] = server_config.get("headers", {}) - - if not plugin_data["server_url"]: - raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段") elif server_type == "stdio": plugin_data["command"] = server_config.get("command") @@ -415,6 +429,12 @@ async def update_plugin( # 更新字段 update_data = data.model_dump(exclude_unset=True) + target_type = update_data.get("plugin_type", plugin.plugin_type) + if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES: + update_data["server_url"] = _validate_mcp_server_url( + target_type, + update_data.get("server_url", plugin.server_url) + ) for key, value in update_data.items(): setattr(plugin, key, value) @@ -501,7 +521,8 @@ async def toggle_plugin( if enabled: # 启用:注册到统一门面 try: - if plugin_type in ["http", "streamable_http", "sse"] and server_url: + if plugin_type in HTTP_PLUGIN_TYPES and server_url: + server_url = _validate_mcp_server_url(plugin_type, server_url) success = await mcp_client.register(MCPPluginConfig( user_id=user.user_id, plugin_name=plugin_name, @@ -647,11 +668,12 @@ async def _ensure_plugin_registered( """ try: # 使用ensure_registered方法,它会检查是否已注册 - if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url: + if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url: + server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url) return await mcp_client.ensure_registered( user_id=user_id, plugin_name=plugin.plugin_name, - url=plugin.server_url, + url=server_url, plugin_type=plugin.plugin_type, headers=plugin.headers ) @@ -912,4 +934,4 @@ async def call_mcp_tool( raise except Exception as e: logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}") - raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}") diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index d33b2e7..13ec0bb 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -26,6 +26,7 @@ from app.logger import get_logger from app.config import settings as app_settings, PROJECT_ROOT from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider from app.services.email_service import email_service +from app.security import validate_public_http_url logger = get_logger(__name__) @@ -452,7 +453,8 @@ async def delete_settings( async def get_available_models( api_key: str, api_base_url: str, - provider: str = "openai" + provider: str = "openai", + user: User = Depends(require_login) ): """ 从配置的 API 获取可用的模型列表 @@ -467,6 +469,7 @@ async def get_available_models( """ try: provider = normalize_provider(provider) + api_base_url = validate_public_http_url(api_base_url) async with httpx.AsyncClient(timeout=10.0) as client: if provider == "openai" or provider == "azure" or provider == "custom": # OpenAI 兼容接口获取模型列表 @@ -1291,4 +1294,4 @@ async def create_preset_from_current( ) logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}") - return await create_preset(create_request, user, db) \ No newline at end of file + return await create_preset(create_request, user, db) diff --git a/backend/app/api/users.py b/backend/app/api/users.py index 6b72707..9104cc9 100644 --- a/backend/app/api/users.py +++ b/backend/app/api/users.py @@ -32,7 +32,7 @@ class SetAdminRequest(BaseModel): class ResetPasswordRequest(BaseModel): user_id: str - new_password: Optional[str] = None # 如果为空则使用默认密码 + new_password: Optional[str] = None # 如果为空则由系统生成临时密码 @router.get("/current") @@ -140,7 +140,7 @@ async def reset_user_password( 重置用户密码(仅管理员) 如果提供了 new_password,则设置为指定密码 - 如果未提供 new_password,则重置为默认密码(username@666) + 如果未提供 new_password,则由系统生成临时密码 限制: - 不能重置自己的密码(应该使用修改密码功能) @@ -162,10 +162,15 @@ async def reset_user_password( # 重置密码 try: - actual_password = await password_manager.set_password( + generated_password = data.new_password + if not generated_password: + import secrets + generated_password = secrets.token_urlsafe(12) + + await password_manager.set_password( target_user.user_id, target_user.username, - data.new_password + generated_password ) # 如果使用了默认密码,返回密码供管理员告知用户 @@ -177,8 +182,8 @@ async def reset_user_password( } if not data.new_password: - response_data["default_password"] = actual_password - response_data["message"] = f"密码已重置为默认密码: {actual_password}" + response_data["temporary_password"] = generated_password + response_data["message"] = "密码已重置为系统生成的临时密码,请尽快通知用户修改" return response_data @@ -186,4 +191,4 @@ async def reset_user_password( raise HTTPException( status_code=500, detail=f"重置密码失败: {str(e)}" - ) \ No newline at end of file + ) diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 5576d5d..56cdbab 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -26,6 +26,18 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"]) logger = get_logger(__name__) +async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None: + if not project_id or not user_id: + return None + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.user_id == user_id, + ) + ) + return result.scalar_one_or_none() + + async def world_building_generator( data: Dict[str, Any], db: AsyncSession, @@ -326,12 +338,9 @@ async def career_system_generator( # 获取项目信息 yield await tracker.loading("加载项目信息...") - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() + project = await get_owned_project(db, project_id, user_id) if not project: - yield await tracker.error("项目不存在", 404) + yield await tracker.error("项目不存在或无权访问", 404) return # 设置用户信息以启用MCP @@ -599,12 +608,9 @@ async def characters_generator( # 验证项目 yield await tracker.loading("验证项目...", 0.3) - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() + project = await get_owned_project(db, project_id, user_id) if not project: - yield await tracker.error("项目不存在", 404) + yield await tracker.error("项目不存在或无权访问", 404) return project.wizard_step = 2 @@ -1270,12 +1276,9 @@ async def outline_generator( # 获取项目信息 yield await tracker.loading("加载项目信息...", 0.3) - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() + project = await get_owned_project(db, project_id, user_id) if not project: - yield await tracker.error("项目不存在", 404) + yield await tracker.error("项目不存在或无权访问", 404) return # 获取角色信息 @@ -1551,12 +1554,9 @@ async def world_building_regenerate_generator( # 获取项目信息 yield await tracker.loading("加载项目信息...") - result = await db.execute( - select(Project).where(Project.id == project_id) - ) - project = result.scalar_one_or_none() + project = await get_owned_project(db, project_id, user_id) if not project: - yield await tracker.error("项目不存在", 404) + yield await tracker.error("项目不存在或无权访问", 404) return # 提取参数 diff --git a/backend/app/api/writing_styles.py b/backend/app/api/writing_styles.py index 045b9cd..4d7884e 100644 --- a/backend/app/api/writing_styles.py +++ b/backend/app/api/writing_styles.py @@ -275,15 +275,19 @@ async def get_project_styles( @router.get("/{style_id}", response_model=WritingStyleResponse) async def get_writing_style( style_id: int, + request: Request, db: AsyncSession = Depends(get_db) ): """获取单个写作风格详情""" + user_id = get_current_user_id(request) result = await db.execute( select(WritingStyle).where(WritingStyle.id == style_id) ) style = result.scalar_one_or_none() if not style: raise HTTPException(status_code=404, detail="写作风格不存在") + if style.user_id is not None and style.user_id != user_id: + raise HTTPException(status_code=404, detail="写作风格不存在") # 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound) result = await db.execute( @@ -501,4 +505,4 @@ async def initialize_default_styles( 该接口保留用于兼容性,直接返回项目可用的所有风格 """ # 直接返回项目可用的所有风格(全局预设 + 用户自定义) - return await get_project_styles(project_id, request, db) \ No newline at end of file + return await get_project_styles(project_id, request, db) diff --git a/backend/app/config.py b/backend/app/config.py index 9a3e117..5d8393c 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -29,7 +29,7 @@ class Settings(BaseSettings): app_version: str = "1.0.0" app_host: str = "0.0.0.0" app_port: int = 8000 - debug: bool = True + debug: bool = False # 日志配置 log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL @@ -106,6 +106,7 @@ class Settings(BaseSettings): # 会话配置 SESSION_EXPIRE_MINUTES: int = 120 # 会话过期时间(分钟),默认2小时 SESSION_REFRESH_THRESHOLD_MINUTES: int = 30 # 会话刷新阈值(分钟),剩余时间少于此值时可刷新 + SESSION_SECRET_KEY: Optional[str] = None # 会话签名密钥,生产环境必须配置为高强度随机值 # 系统 SMTP 默认配置(可被管理员系统设置覆盖) SMTP_PROVIDER: str = "qq" diff --git a/backend/app/main.py b/backend/app/main.py index d0f3d76..7006b1e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,5 +1,5 @@ """FastAPI应用主入口""" -from fastapi import FastAPI, Request, status +from fastapi import FastAPI, Request, status, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse, FileResponse @@ -106,7 +106,7 @@ async def health_check(): @app.get("/health/db-sessions") -async def db_session_stats(): +async def db_session_stats(request: Request): """ 数据库会话统计(监控连接泄漏) @@ -118,6 +118,8 @@ async def db_session_stats(): - generator_exits: SSE断开次数 - last_check: 最后检查时间 """ + if not getattr(request.state, "is_admin", False): + raise HTTPException(status_code=403, detail="需要管理员权限") return { "status": "ok", "session_stats": _session_stats, @@ -176,8 +178,18 @@ if static_dir.exists(): ) file_path = static_dir / full_path - if file_path.is_file(): - return FileResponse(file_path) + try: + resolved_file = file_path.resolve() + resolved_static = static_dir.resolve() + resolved_file.relative_to(resolved_static) + except ValueError: + return JSONResponse( + status_code=404, + content={"detail": "页面不存在"} + ) + + if resolved_file.is_file(): + return FileResponse(resolved_file) index_file = static_dir / "index.html" if index_file.exists(): @@ -207,4 +219,4 @@ if __name__ == "__main__": host=config_settings.app_host, port=config_settings.app_port, reload=config_settings.debug - ) \ No newline at end of file + ) diff --git a/backend/app/middleware/auth_middleware.py b/backend/app/middleware/auth_middleware.py index d2c73a9..fda01b9 100644 --- a/backend/app/middleware/auth_middleware.py +++ b/backend/app/middleware/auth_middleware.py @@ -6,6 +6,7 @@ from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from app.user_manager import user_manager from app.logger import get_logger +from app.security import verify_session_token logger = get_logger(__name__) @@ -46,8 +47,8 @@ class AuthMiddleware(BaseHTTPMiddleware): request.state.is_proxy_request = False request.state.proxy_instance_id = None - # 从 Cookie 中获取用户 ID - user_id = request.cookies.get("user_id") + # 优先验证签名会话 Cookie;不再信任客户端可伪造的明文 user_id。 + user_id = verify_session_token(request.cookies.get("session_token")) if user_id: user = await user_manager.get_user(user_id) @@ -77,4 +78,4 @@ class AuthMiddleware(BaseHTTPMiddleware): # 继续处理请求 response = await call_next(request) - return response \ No newline at end of file + return response diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 89d0bb5..845f998 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -41,7 +41,7 @@ class UserPassword(Base): user_id = Column(String(100), primary_key=True, index=True, comment="用户ID") username = Column(String(100), nullable=False, comment="用户名") - password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256)") + password_hash = Column(String(255), nullable=False, comment="密码哈希") has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码") created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间") - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间") \ No newline at end of file + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间") diff --git a/backend/app/security.py b/backend/app/security.py new file mode 100644 index 0000000..c75bcc2 --- /dev/null +++ b/backend/app/security.py @@ -0,0 +1,108 @@ +"""Security helpers for sessions and outbound URL validation.""" +import base64 +import hashlib +import hmac +import ipaddress +import json +import secrets +import socket +import time +from typing import Iterable +from urllib.parse import urlparse + +from fastapi import HTTPException + +from app.config import settings + + +def _session_secret() -> bytes: + secret = ( + getattr(settings, "SESSION_SECRET_KEY", None) + or getattr(settings, "session_secret_key", None) + or settings.LINUXDO_CLIENT_SECRET + or settings.LOCAL_AUTH_PASSWORD + or settings.openai_api_key + ) + if not secret: + secret = "mumuainovel-development-session-secret" + return str(secret).encode("utf-8") + + +def create_session_token(user_id: str, max_age_seconds: int) -> str: + payload = { + "uid": user_id, + "exp": int(time.time()) + max_age_seconds, + "nonce": secrets.token_urlsafe(16), + } + payload_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8") + payload_b64 = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode("ascii") + signature = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest() + signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii") + return f"{payload_b64}.{signature_b64}" + + +def verify_session_token(token: str) -> str | None: + if not token or "." not in token: + return None + payload_b64, signature_b64 = token.split(".", 1) + expected = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest() + try: + provided = base64.urlsafe_b64decode(signature_b64 + "=" * (-len(signature_b64) % 4)) + except Exception: + return None + if not hmac.compare_digest(expected, provided): + return None + try: + payload_raw = base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4)) + payload = json.loads(payload_raw.decode("utf-8")) + except Exception: + return None + if int(payload.get("exp", 0)) < int(time.time()): + return None + user_id = payload.get("uid") + return user_id if isinstance(user_id, str) and user_id else None + + +def _is_forbidden_ip(ip: ipaddress._BaseAddress) -> bool: + return any([ + ip.is_private, + ip.is_loopback, + ip.is_link_local, + ip.is_multicast, + ip.is_reserved, + ip.is_unspecified, + ]) + + +def validate_public_http_url(raw_url: str, *, allowed_schemes: Iterable[str] = ("https", "http")) -> str: + """Validate an outbound URL to reduce SSRF risk.""" + if not raw_url or not isinstance(raw_url, str): + raise HTTPException(status_code=400, detail="URL不能为空") + + parsed = urlparse(raw_url.strip()) + if parsed.scheme not in set(allowed_schemes): + raise HTTPException(status_code=400, detail="仅支持 HTTP/HTTPS URL") + if not parsed.hostname: + raise HTTPException(status_code=400, detail="URL缺少主机名") + if parsed.username or parsed.password: + raise HTTPException(status_code=400, detail="URL不允许包含认证信息") + + host = parsed.hostname.strip().rstrip(".") + if host.lower() in {"localhost", "localhost.localdomain"}: + raise HTTPException(status_code=400, detail="URL不允许指向本机地址") + + try: + ip = ipaddress.ip_address(host) + if _is_forbidden_ip(ip): + raise HTTPException(status_code=400, detail="URL不允许指向内网或保留地址") + except ValueError: + try: + infos = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80), type=socket.SOCK_STREAM) + except socket.gaierror: + raise HTTPException(status_code=400, detail="URL主机名无法解析") + for info in infos: + resolved_ip = ipaddress.ip_address(info[4][0]) + if _is_forbidden_ip(resolved_ip): + raise HTTPException(status_code=400, detail="URL解析到内网或保留地址") + + return raw_url.strip().rstrip("/") diff --git a/backend/app/services/workshop_client.py b/backend/app/services/workshop_client.py index 71a4251..53d67b8 100644 --- a/backend/app/services/workshop_client.py +++ b/backend/app/services/workshop_client.py @@ -38,7 +38,7 @@ class WorkshopClient: url = f"{self.base_url}/api/prompt-workshop{path}" try: - async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client: + async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.request( method=method, url=url, @@ -173,4 +173,4 @@ class WorkshopClient: # 全局客户端实例 -workshop_client = WorkshopClient() \ No newline at end of file +workshop_client = WorkshopClient() diff --git a/backend/app/user_password.py b/backend/app/user_password.py index 012e9a8..c9829af 100644 --- a/backend/app/user_password.py +++ b/backend/app/user_password.py @@ -3,6 +3,8 @@ """ import asyncio import hashlib +import hmac +import secrets from typing import Optional from datetime import datetime from sqlalchemy import select @@ -34,7 +36,28 @@ class UserPasswordManager: def _hash_password(self, password: str) -> str: """密码哈希""" - return hashlib.sha256(password.encode()).hexdigest() + salt = secrets.token_hex(16) + iterations = 260000 + digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), iterations).hex() + return f"pbkdf2_sha256${iterations}${salt}${digest}" + + def _verify_hash(self, password: str, stored_hash: str) -> bool: + if stored_hash.startswith("pbkdf2_sha256$"): + try: + _, iterations, salt, digest = stored_hash.split("$", 3) + candidate = hashlib.pbkdf2_hmac( + "sha256", + password.encode(), + salt.encode(), + int(iterations), + ).hex() + return hmac.compare_digest(candidate, digest) + except Exception: + return False + + # Legacy unsalted SHA-256 hash support for existing deployments. + legacy_hash = hashlib.sha256(password.encode()).hexdigest() + return hmac.compare_digest(legacy_hash, stored_hash) async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str: """ @@ -104,8 +127,12 @@ class UserPasswordManager: if not pwd_record: return False - password_hash = self._hash_password(password) - return pwd_record.password_hash == password_hash + verified = self._verify_hash(password, pwd_record.password_hash) + if verified and not pwd_record.password_hash.startswith("pbkdf2_sha256$"): + pwd_record.password_hash = self._hash_password(password) + pwd_record.updated_at = datetime.now() + await session.commit() + return verified async def has_password(self, user_id: str) -> bool: """ @@ -175,4 +202,4 @@ class UserPasswordManager: # 全局密码管理器实例 -password_manager = UserPasswordManager() \ No newline at end of file +password_manager = UserPasswordManager()