update: 修复基于长亭monkeycode扫描结果的12处安全漏洞

This commit is contained in:
xiamuceer
2026-04-24 10:11:23 +08:00
parent 63bfabc6de
commit 4af9a31eba
17 changed files with 366 additions and 75 deletions
@@ -0,0 +1,42 @@
"""expand password hash length
Revision ID: 9a1b2c3d4e5f
Revises: 6eb27fce64de
Create Date: 2026-04-24 10:06:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '9a1b2c3d4e5f'
down_revision: Union[str, None] = '6eb27fce64de'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.alter_column(
'user_passwords',
'password_hash',
existing_type=sa.String(length=64),
type_=sa.String(length=255),
existing_nullable=False,
existing_comment='密码哈希(SHA256',
comment='密码哈希',
)
def downgrade() -> None:
op.alter_column(
'user_passwords',
'password_hash',
existing_type=sa.String(length=255),
type_=sa.String(length=64),
existing_nullable=False,
existing_comment='密码哈希',
comment='密码哈希(SHA256',
)
@@ -0,0 +1,38 @@
"""expand password hash length
Revision ID: ab12cd34ef56
Revises: 6ff45db05863
Create Date: 2026-04-24 10:06:00
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ab12cd34ef56'
down_revision: Union[str, None] = '6ff45db05863'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
batch_op.alter_column(
'password_hash',
existing_type=sa.String(length=64),
type_=sa.String(length=255),
existing_nullable=False,
)
def downgrade() -> None:
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
batch_op.alter_column(
'password_hash',
existing_type=sa.String(length=255),
type_=sa.String(length=64),
existing_nullable=False,
)
+9 -4
View File
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
import hashlib
import secrets
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
@@ -46,7 +47,7 @@ class ToggleStatusRequest(BaseModel):
class ResetPasswordRequest(BaseModel):
"""重置密码请求"""
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则生成临时密码")
class UserResponse(BaseModel):
@@ -308,10 +309,14 @@ async def reset_password(
raise HTTPException(status_code=404, detail="用户不存在")
# 重置密码
actual_password = await password_manager.set_password(
generated_password = data.new_password
if not generated_password:
generated_password = secrets.token_urlsafe(12)
await password_manager.set_password(
user_id=user_id,
username=target_user.username,
password=data.new_password
password=generated_password
)
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
@@ -319,7 +324,7 @@ async def reset_password(
return {
"success": True,
"message": "密码重置成功",
"new_password": actual_password
"temporary_password": generated_password if not data.new_password else None
}
except HTTPException:
+25 -8
View File
@@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional
import hashlib
import random
import secrets
import re
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
@@ -21,6 +21,7 @@ from app.database import get_engine
from app.models.user import User as UserModel
from app.models.settings import Settings as SettingsModel
from app.services.email_service import email_service
from app.security import create_session_token
# 中国时区 UTC+8
CHINA_TZ = timezone(timedelta(hours=8))
@@ -43,6 +44,7 @@ _state_storage = {}
# 邮箱验证码临时存储(生产环境应使用 Redis)
_email_verification_storage = {}
MAX_VERIFICATION_ATTEMPTS = 5
EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$")
@@ -246,12 +248,14 @@ def _validate_password(password: str):
def _set_login_cookies(response: Response, user_id: str):
"""设置登录 Cookie"""
max_age = settings.SESSION_EXPIRE_MINUTES * 60
session_token = create_session_token(user_id, max_age)
response.set_cookie(
key="user_id",
value=user_id,
key="session_token",
value=session_token,
max_age=max_age,
httponly=True,
samesite="lax"
samesite="lax",
secure=not settings.debug,
)
china_now = get_china_now()
@@ -263,12 +267,13 @@ def _set_login_cookies(response: Response, user_id: str):
value=str(expire_at),
max_age=max_age,
httponly=False,
samesite="lax"
samesite="lax",
secure=not settings.debug,
)
def _generate_verification_code() -> str:
return f"{random.randint(0, 999999):06d}"
return f"{secrets.randbelow(1000000):06d}"
def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]:
@@ -456,6 +461,7 @@ async def send_email_verification_code(request: EmailSendCodeRequest):
"code": code,
"expires_at": expires_at,
"last_sent_at": now,
"attempts": 0,
}
logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}")
@@ -493,6 +499,10 @@ async def email_register(request: EmailRegisterRequest, response: Response):
raise HTTPException(status_code=400, detail="验证码已过期,请重新发送")
if cached["code"] != code:
cached["attempts"] = cached.get("attempts", 0) + 1
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
_email_verification_storage.pop(_get_verification_storage_key("register", email), None)
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
raise HTTPException(status_code=400, detail="验证码错误")
existing_user = await _find_user_by_email(email)
@@ -541,6 +551,10 @@ async def email_login(request: EmailLoginRequest, response: Response):
raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送")
if cached["code"] != code:
cached["attempts"] = cached.get("attempts", 0) + 1
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
_email_verification_storage.pop(storage_key, None)
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
raise HTTPException(status_code=400, detail="登录验证码错误")
_email_verification_storage.pop(storage_key, None)
@@ -588,6 +602,10 @@ async def email_reset_password(request: EmailResetPasswordRequest):
raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送")
if cached["code"] != code:
cached["attempts"] = cached.get("attempts", 0) + 1
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
_email_verification_storage.pop(storage_key, None)
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
raise HTTPException(status_code=400, detail="重置密码验证码错误")
await password_manager.set_password(user.user_id, email, request.new_password)
@@ -757,6 +775,7 @@ async def logout(request: Request, response: Response):
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
response.delete_cookie("user_id")
response.delete_cookie("session_token")
response.delete_cookie("session_expire_at")
return {"message": "退出登录成功"}
@@ -782,8 +801,6 @@ async def get_password_status(request: Request):
username = await password_manager.get_username(user.user_id)
default_password = None
if has_password and not has_custom:
default_password = f"{user.username}@666"
return PasswordStatusResponse(
has_password=has_password,
+8 -2
View File
@@ -2,7 +2,7 @@
更新日志API
提供GitHub提交历史的缓存和代理服务
"""
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Request, Depends
from typing import List, Optional
import httpx
from datetime import datetime, timedelta
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def require_login(request: Request):
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="需要登录")
return request.state.user
# GitHub API配置
GITHUB_API_BASE = "https://api.github.com"
REPO_OWNER = "xiamuceer-j"
@@ -173,7 +179,7 @@ async def get_changelog(
@router.post("/changelog/refresh")
async def refresh_changelog():
async def refresh_changelog(user=Depends(require_login)):
"""
刷新更新日志缓存
+33 -11
View File
@@ -24,11 +24,22 @@ from app.user_manager import User
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
from app.services.mcp_test_service import mcp_test_service
from app.logger import get_logger
from app.security import validate_public_http_url
logger = get_logger(__name__)
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"}
def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]:
if plugin_type in HTTP_PLUGIN_TYPES:
if not server_url:
raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url")
return validate_public_http_url(server_url)
return server_url
def require_login(request: Request) -> User:
"""依赖:要求用户已登录"""
@@ -53,7 +64,8 @@ async def _register_plugin_background(
try:
logger.info(f"后台注册MCP插件: {plugin_name}")
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
server_url = _validate_mcp_server_url(plugin_type, server_url)
success = await mcp_client.register(MCPPluginConfig(
user_id=user_id,
plugin_name=plugin_name,
@@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
Returns:
是否注册成功
"""
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
return await mcp_client.register(MCPPluginConfig(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
url=server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers,
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
@@ -187,6 +200,10 @@ async def create_plugin(
# 创建插件数据
plugin_data = data.model_dump()
plugin_data["server_url"] = _validate_mcp_server_url(
plugin_data.get("plugin_type", "http"),
plugin_data.get("server_url")
)
# 如果没有提供display_name,使用plugin_name作为默认值
if not plugin_data.get("display_name"):
@@ -278,13 +295,10 @@ async def create_plugin_simple(
"sort_order": 0
}
if server_type in ["http", "streamable_http", "sse"]:
plugin_data["server_url"] = server_config.get("url")
if server_type in HTTP_PLUGIN_TYPES:
plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url"))
plugin_data["headers"] = server_config.get("headers", {})
if not plugin_data["server_url"]:
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
elif server_type == "stdio":
plugin_data["command"] = server_config.get("command")
plugin_data["args"] = server_config.get("args", [])
@@ -415,6 +429,12 @@ async def update_plugin(
# 更新字段
update_data = data.model_dump(exclude_unset=True)
target_type = update_data.get("plugin_type", plugin.plugin_type)
if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES:
update_data["server_url"] = _validate_mcp_server_url(
target_type,
update_data.get("server_url", plugin.server_url)
)
for key, value in update_data.items():
setattr(plugin, key, value)
@@ -501,7 +521,8 @@ async def toggle_plugin(
if enabled:
# 启用:注册到统一门面
try:
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
server_url = _validate_mcp_server_url(plugin_type, server_url)
success = await mcp_client.register(MCPPluginConfig(
user_id=user.user_id,
plugin_name=plugin_name,
@@ -647,11 +668,12 @@ async def _ensure_plugin_registered(
"""
try:
# 使用ensure_registered方法,它会检查是否已注册
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
url=server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
+4 -1
View File
@@ -26,6 +26,7 @@ from app.logger import get_logger
from app.config import settings as app_settings, PROJECT_ROOT
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
from app.services.email_service import email_service
from app.security import validate_public_http_url
logger = get_logger(__name__)
@@ -452,7 +453,8 @@ async def delete_settings(
async def get_available_models(
api_key: str,
api_base_url: str,
provider: str = "openai"
provider: str = "openai",
user: User = Depends(require_login)
):
"""
从配置的 API 获取可用的模型列表
@@ -467,6 +469,7 @@ async def get_available_models(
"""
try:
provider = normalize_provider(provider)
api_base_url = validate_public_http_url(api_base_url)
async with httpx.AsyncClient(timeout=10.0) as client:
if provider == "openai" or provider == "azure" or provider == "custom":
# OpenAI 兼容接口获取模型列表
+11 -6
View File
@@ -32,7 +32,7 @@ class SetAdminRequest(BaseModel):
class ResetPasswordRequest(BaseModel):
user_id: str
new_password: Optional[str] = None # 如果为空则使用默认密码
new_password: Optional[str] = None # 如果为空则由系统生成临时密码
@router.get("/current")
@@ -140,7 +140,7 @@ async def reset_user_password(
重置用户密码(仅管理员)
如果提供了 new_password,则设置为指定密码
如果未提供 new_password,则重置为默认密码(username@666
如果未提供 new_password,则由系统生成临时密码
限制:
- 不能重置自己的密码(应该使用修改密码功能)
@@ -162,10 +162,15 @@ async def reset_user_password(
# 重置密码
try:
actual_password = await password_manager.set_password(
generated_password = data.new_password
if not generated_password:
import secrets
generated_password = secrets.token_urlsafe(12)
await password_manager.set_password(
target_user.user_id,
target_user.username,
data.new_password
generated_password
)
# 如果使用了默认密码,返回密码供管理员告知用户
@@ -177,8 +182,8 @@ async def reset_user_password(
}
if not data.new_password:
response_data["default_password"] = actual_password
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
response_data["temporary_password"] = generated_password
response_data["message"] = "密码已重置为系统生成的临时密码,请尽快通知用户修改"
return response_data
+20 -20
View File
@@ -26,6 +26,18 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
logger = get_logger(__name__)
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None:
if not project_id or not user_id:
return None
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id,
)
)
return result.scalar_one_or_none()
async def world_building_generator(
data: Dict[str, Any],
db: AsyncSession,
@@ -326,12 +338,9 @@ async def career_system_generator(
# 获取项目信息
yield await tracker.loading("加载项目信息...")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
project = await get_owned_project(db, project_id, user_id)
if not project:
yield await tracker.error("项目不存在", 404)
yield await tracker.error("项目不存在或无权访问", 404)
return
# 设置用户信息以启用MCP
@@ -599,12 +608,9 @@ async def characters_generator(
# 验证项目
yield await tracker.loading("验证项目...", 0.3)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
project = await get_owned_project(db, project_id, user_id)
if not project:
yield await tracker.error("项目不存在", 404)
yield await tracker.error("项目不存在或无权访问", 404)
return
project.wizard_step = 2
@@ -1270,12 +1276,9 @@ async def outline_generator(
# 获取项目信息
yield await tracker.loading("加载项目信息...", 0.3)
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
project = await get_owned_project(db, project_id, user_id)
if not project:
yield await tracker.error("项目不存在", 404)
yield await tracker.error("项目不存在或无权访问", 404)
return
# 获取角色信息
@@ -1551,12 +1554,9 @@ async def world_building_regenerate_generator(
# 获取项目信息
yield await tracker.loading("加载项目信息...")
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
project = await get_owned_project(db, project_id, user_id)
if not project:
yield await tracker.error("项目不存在", 404)
yield await tracker.error("项目不存在或无权访问", 404)
return
# 提取参数
+4
View File
@@ -275,15 +275,19 @@ async def get_project_styles(
@router.get("/{style_id}", response_model=WritingStyleResponse)
async def get_writing_style(
style_id: int,
request: Request,
db: AsyncSession = Depends(get_db)
):
"""获取单个写作风格详情"""
user_id = get_current_user_id(request)
result = await db.execute(
select(WritingStyle).where(WritingStyle.id == style_id)
)
style = result.scalar_one_or_none()
if not style:
raise HTTPException(status_code=404, detail="写作风格不存在")
if style.user_id is not None and style.user_id != user_id:
raise HTTPException(status_code=404, detail="写作风格不存在")
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound
result = await db.execute(
+2 -1
View File
@@ -29,7 +29,7 @@ class Settings(BaseSettings):
app_version: str = "1.0.0"
app_host: str = "0.0.0.0"
app_port: int = 8000
debug: bool = True
debug: bool = False
# 日志配置
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
@@ -106,6 +106,7 @@ class Settings(BaseSettings):
# 会话配置
SESSION_EXPIRE_MINUTES: int = 120 # 会话过期时间(分钟),默认2小时
SESSION_REFRESH_THRESHOLD_MINUTES: int = 30 # 会话刷新阈值(分钟),剩余时间少于此值时可刷新
SESSION_SECRET_KEY: Optional[str] = None # 会话签名密钥,生产环境必须配置为高强度随机值
# 系统 SMTP 默认配置(可被管理员系统设置覆盖)
SMTP_PROVIDER: str = "qq"
+16 -4
View File
@@ -1,5 +1,5 @@
"""FastAPI应用主入口"""
from fastapi import FastAPI, Request, status
from fastapi import FastAPI, Request, status, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse, FileResponse
@@ -106,7 +106,7 @@ async def health_check():
@app.get("/health/db-sessions")
async def db_session_stats():
async def db_session_stats(request: Request):
"""
数据库会话统计(监控连接泄漏)
@@ -118,6 +118,8 @@ async def db_session_stats():
- generator_exits: SSE断开次数
- last_check: 最后检查时间
"""
if not getattr(request.state, "is_admin", False):
raise HTTPException(status_code=403, detail="需要管理员权限")
return {
"status": "ok",
"session_stats": _session_stats,
@@ -176,8 +178,18 @@ if static_dir.exists():
)
file_path = static_dir / full_path
if file_path.is_file():
return FileResponse(file_path)
try:
resolved_file = file_path.resolve()
resolved_static = static_dir.resolve()
resolved_file.relative_to(resolved_static)
except ValueError:
return JSONResponse(
status_code=404,
content={"detail": "页面不存在"}
)
if resolved_file.is_file():
return FileResponse(resolved_file)
index_file = static_dir / "index.html"
if index_file.exists():
+3 -2
View File
@@ -6,6 +6,7 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.user_manager import user_manager
from app.logger import get_logger
from app.security import verify_session_token
logger = get_logger(__name__)
@@ -46,8 +47,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
request.state.is_proxy_request = False
request.state.proxy_instance_id = None
# 从 Cookie 中获取用户 ID
user_id = request.cookies.get("user_id")
# 优先验证签名会话 Cookie;不再信任客户端可伪造的明文 user_id。
user_id = verify_session_token(request.cookies.get("session_token"))
if user_id:
user = await user_manager.get_user(user_id)
+1 -1
View File
@@ -41,7 +41,7 @@ class UserPassword(Base):
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
username = Column(String(100), nullable=False, comment="用户名")
password_hash = Column(String(64), nullable=False, comment="密码哈希SHA256")
password_hash = Column(String(255), nullable=False, comment="密码哈希")
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
+108
View File
@@ -0,0 +1,108 @@
"""Security helpers for sessions and outbound URL validation."""
import base64
import hashlib
import hmac
import ipaddress
import json
import secrets
import socket
import time
from typing import Iterable
from urllib.parse import urlparse
from fastapi import HTTPException
from app.config import settings
def _session_secret() -> bytes:
secret = (
getattr(settings, "SESSION_SECRET_KEY", None)
or getattr(settings, "session_secret_key", None)
or settings.LINUXDO_CLIENT_SECRET
or settings.LOCAL_AUTH_PASSWORD
or settings.openai_api_key
)
if not secret:
secret = "mumuainovel-development-session-secret"
return str(secret).encode("utf-8")
def create_session_token(user_id: str, max_age_seconds: int) -> str:
payload = {
"uid": user_id,
"exp": int(time.time()) + max_age_seconds,
"nonce": secrets.token_urlsafe(16),
}
payload_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8")
payload_b64 = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode("ascii")
signature = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii")
return f"{payload_b64}.{signature_b64}"
def verify_session_token(token: str) -> str | None:
if not token or "." not in token:
return None
payload_b64, signature_b64 = token.split(".", 1)
expected = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
try:
provided = base64.urlsafe_b64decode(signature_b64 + "=" * (-len(signature_b64) % 4))
except Exception:
return None
if not hmac.compare_digest(expected, provided):
return None
try:
payload_raw = base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4))
payload = json.loads(payload_raw.decode("utf-8"))
except Exception:
return None
if int(payload.get("exp", 0)) < int(time.time()):
return None
user_id = payload.get("uid")
return user_id if isinstance(user_id, str) and user_id else None
def _is_forbidden_ip(ip: ipaddress._BaseAddress) -> bool:
return any([
ip.is_private,
ip.is_loopback,
ip.is_link_local,
ip.is_multicast,
ip.is_reserved,
ip.is_unspecified,
])
def validate_public_http_url(raw_url: str, *, allowed_schemes: Iterable[str] = ("https", "http")) -> str:
"""Validate an outbound URL to reduce SSRF risk."""
if not raw_url or not isinstance(raw_url, str):
raise HTTPException(status_code=400, detail="URL不能为空")
parsed = urlparse(raw_url.strip())
if parsed.scheme not in set(allowed_schemes):
raise HTTPException(status_code=400, detail="仅支持 HTTP/HTTPS URL")
if not parsed.hostname:
raise HTTPException(status_code=400, detail="URL缺少主机名")
if parsed.username or parsed.password:
raise HTTPException(status_code=400, detail="URL不允许包含认证信息")
host = parsed.hostname.strip().rstrip(".")
if host.lower() in {"localhost", "localhost.localdomain"}:
raise HTTPException(status_code=400, detail="URL不允许指向本机地址")
try:
ip = ipaddress.ip_address(host)
if _is_forbidden_ip(ip):
raise HTTPException(status_code=400, detail="URL不允许指向内网或保留地址")
except ValueError:
try:
infos = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80), type=socket.SOCK_STREAM)
except socket.gaierror:
raise HTTPException(status_code=400, detail="URL主机名无法解析")
for info in infos:
resolved_ip = ipaddress.ip_address(info[4][0])
if _is_forbidden_ip(resolved_ip):
raise HTTPException(status_code=400, detail="URL解析到内网或保留地址")
return raw_url.strip().rstrip("/")
+1 -1
View File
@@ -38,7 +38,7 @@ class WorkshopClient:
url = f"{self.base_url}/api/prompt-workshop{path}"
try:
async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.request(
method=method,
url=url,
+30 -3
View File
@@ -3,6 +3,8 @@
"""
import asyncio
import hashlib
import hmac
import secrets
from typing import Optional
from datetime import datetime
from sqlalchemy import select
@@ -34,7 +36,28 @@ class UserPasswordManager:
def _hash_password(self, password: str) -> str:
"""密码哈希"""
return hashlib.sha256(password.encode()).hexdigest()
salt = secrets.token_hex(16)
iterations = 260000
digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), iterations).hex()
return f"pbkdf2_sha256${iterations}${salt}${digest}"
def _verify_hash(self, password: str, stored_hash: str) -> bool:
if stored_hash.startswith("pbkdf2_sha256$"):
try:
_, iterations, salt, digest = stored_hash.split("$", 3)
candidate = hashlib.pbkdf2_hmac(
"sha256",
password.encode(),
salt.encode(),
int(iterations),
).hex()
return hmac.compare_digest(candidate, digest)
except Exception:
return False
# Legacy unsalted SHA-256 hash support for existing deployments.
legacy_hash = hashlib.sha256(password.encode()).hexdigest()
return hmac.compare_digest(legacy_hash, stored_hash)
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
"""
@@ -104,8 +127,12 @@ class UserPasswordManager:
if not pwd_record:
return False
password_hash = self._hash_password(password)
return pwd_record.password_hash == password_hash
verified = self._verify_hash(password, pwd_record.password_hash)
if verified and not pwd_record.password_hash.startswith("pbkdf2_sha256$"):
pwd_record.password_hash = self._hash_password(password)
pwd_record.updated_at = datetime.now()
await session.commit()
return verified
async def has_password(self, user_id: str) -> bool:
"""