update: 修复基于长亭monkeycode扫描结果的12处安全漏洞
This commit is contained in:
+42
@@ -0,0 +1,42 @@
|
|||||||
|
"""expand password hash length
|
||||||
|
|
||||||
|
Revision ID: 9a1b2c3d4e5f
|
||||||
|
Revises: 6eb27fce64de
|
||||||
|
Create Date: 2026-04-24 10:06:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '9a1b2c3d4e5f'
|
||||||
|
down_revision: Union[str, None] = '6eb27fce64de'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
'user_passwords',
|
||||||
|
'password_hash',
|
||||||
|
existing_type=sa.String(length=64),
|
||||||
|
type_=sa.String(length=255),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_comment='密码哈希(SHA256)',
|
||||||
|
comment='密码哈希',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
'user_passwords',
|
||||||
|
'password_hash',
|
||||||
|
existing_type=sa.String(length=255),
|
||||||
|
type_=sa.String(length=64),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_comment='密码哈希',
|
||||||
|
comment='密码哈希(SHA256)',
|
||||||
|
)
|
||||||
+38
@@ -0,0 +1,38 @@
|
|||||||
|
"""expand password hash length
|
||||||
|
|
||||||
|
Revision ID: ab12cd34ef56
|
||||||
|
Revises: 6ff45db05863
|
||||||
|
Create Date: 2026-04-24 10:06:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'ab12cd34ef56'
|
||||||
|
down_revision: Union[str, None] = '6ff45db05863'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
'password_hash',
|
||||||
|
existing_type=sa.String(length=64),
|
||||||
|
type_=sa.String(length=255),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column(
|
||||||
|
'password_hash',
|
||||||
|
existing_type=sa.String(length=255),
|
||||||
|
type_=sa.String(length=64),
|
||||||
|
existing_nullable=False,
|
||||||
|
)
|
||||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import secrets
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
@@ -46,7 +47,7 @@ class ToggleStatusRequest(BaseModel):
|
|||||||
|
|
||||||
class ResetPasswordRequest(BaseModel):
|
class ResetPasswordRequest(BaseModel):
|
||||||
"""重置密码请求"""
|
"""重置密码请求"""
|
||||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
|
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则生成临时密码")
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
@@ -308,10 +309,14 @@ async def reset_password(
|
|||||||
raise HTTPException(status_code=404, detail="用户不存在")
|
raise HTTPException(status_code=404, detail="用户不存在")
|
||||||
|
|
||||||
# 重置密码
|
# 重置密码
|
||||||
actual_password = await password_manager.set_password(
|
generated_password = data.new_password
|
||||||
|
if not generated_password:
|
||||||
|
generated_password = secrets.token_urlsafe(12)
|
||||||
|
|
||||||
|
await password_manager.set_password(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=target_user.username,
|
username=target_user.username,
|
||||||
password=data.new_password
|
password=generated_password
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
|
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
|
||||||
@@ -319,7 +324,7 @@ async def reset_password(
|
|||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "密码重置成功",
|
"message": "密码重置成功",
|
||||||
"new_password": actual_password
|
"temporary_password": generated_password if not data.new_password else None
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -385,4 +390,4 @@ async def delete_user(
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
|
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||||
|
|||||||
+25
-8
@@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import secrets
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -21,6 +21,7 @@ from app.database import get_engine
|
|||||||
from app.models.user import User as UserModel
|
from app.models.user import User as UserModel
|
||||||
from app.models.settings import Settings as SettingsModel
|
from app.models.settings import Settings as SettingsModel
|
||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
|
from app.security import create_session_token
|
||||||
|
|
||||||
# 中国时区 UTC+8
|
# 中国时区 UTC+8
|
||||||
CHINA_TZ = timezone(timedelta(hours=8))
|
CHINA_TZ = timezone(timedelta(hours=8))
|
||||||
@@ -43,6 +44,7 @@ _state_storage = {}
|
|||||||
|
|
||||||
# 邮箱验证码临时存储(生产环境应使用 Redis)
|
# 邮箱验证码临时存储(生产环境应使用 Redis)
|
||||||
_email_verification_storage = {}
|
_email_verification_storage = {}
|
||||||
|
MAX_VERIFICATION_ATTEMPTS = 5
|
||||||
|
|
||||||
EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$")
|
EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$")
|
||||||
|
|
||||||
@@ -246,12 +248,14 @@ def _validate_password(password: str):
|
|||||||
def _set_login_cookies(response: Response, user_id: str):
|
def _set_login_cookies(response: Response, user_id: str):
|
||||||
"""设置登录 Cookie"""
|
"""设置登录 Cookie"""
|
||||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||||
|
session_token = create_session_token(user_id, max_age)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="user_id",
|
key="session_token",
|
||||||
value=user_id,
|
value=session_token,
|
||||||
max_age=max_age,
|
max_age=max_age,
|
||||||
httponly=True,
|
httponly=True,
|
||||||
samesite="lax"
|
samesite="lax",
|
||||||
|
secure=not settings.debug,
|
||||||
)
|
)
|
||||||
|
|
||||||
china_now = get_china_now()
|
china_now = get_china_now()
|
||||||
@@ -263,12 +267,13 @@ def _set_login_cookies(response: Response, user_id: str):
|
|||||||
value=str(expire_at),
|
value=str(expire_at),
|
||||||
max_age=max_age,
|
max_age=max_age,
|
||||||
httponly=False,
|
httponly=False,
|
||||||
samesite="lax"
|
samesite="lax",
|
||||||
|
secure=not settings.debug,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_verification_code() -> str:
|
def _generate_verification_code() -> str:
|
||||||
return f"{random.randint(0, 999999):06d}"
|
return f"{secrets.randbelow(1000000):06d}"
|
||||||
|
|
||||||
|
|
||||||
def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]:
|
def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]:
|
||||||
@@ -456,6 +461,7 @@ async def send_email_verification_code(request: EmailSendCodeRequest):
|
|||||||
"code": code,
|
"code": code,
|
||||||
"expires_at": expires_at,
|
"expires_at": expires_at,
|
||||||
"last_sent_at": now,
|
"last_sent_at": now,
|
||||||
|
"attempts": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}")
|
logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}")
|
||||||
@@ -493,6 +499,10 @@ async def email_register(request: EmailRegisterRequest, response: Response):
|
|||||||
raise HTTPException(status_code=400, detail="验证码已过期,请重新发送")
|
raise HTTPException(status_code=400, detail="验证码已过期,请重新发送")
|
||||||
|
|
||||||
if cached["code"] != code:
|
if cached["code"] != code:
|
||||||
|
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||||
|
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||||
|
_email_verification_storage.pop(_get_verification_storage_key("register", email), None)
|
||||||
|
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||||
raise HTTPException(status_code=400, detail="验证码错误")
|
raise HTTPException(status_code=400, detail="验证码错误")
|
||||||
|
|
||||||
existing_user = await _find_user_by_email(email)
|
existing_user = await _find_user_by_email(email)
|
||||||
@@ -541,6 +551,10 @@ async def email_login(request: EmailLoginRequest, response: Response):
|
|||||||
raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送")
|
raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送")
|
||||||
|
|
||||||
if cached["code"] != code:
|
if cached["code"] != code:
|
||||||
|
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||||
|
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||||
|
_email_verification_storage.pop(storage_key, None)
|
||||||
|
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||||
raise HTTPException(status_code=400, detail="登录验证码错误")
|
raise HTTPException(status_code=400, detail="登录验证码错误")
|
||||||
|
|
||||||
_email_verification_storage.pop(storage_key, None)
|
_email_verification_storage.pop(storage_key, None)
|
||||||
@@ -588,6 +602,10 @@ async def email_reset_password(request: EmailResetPasswordRequest):
|
|||||||
raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送")
|
raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送")
|
||||||
|
|
||||||
if cached["code"] != code:
|
if cached["code"] != code:
|
||||||
|
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||||
|
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||||
|
_email_verification_storage.pop(storage_key, None)
|
||||||
|
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||||
raise HTTPException(status_code=400, detail="重置密码验证码错误")
|
raise HTTPException(status_code=400, detail="重置密码验证码错误")
|
||||||
|
|
||||||
await password_manager.set_password(user.user_id, email, request.new_password)
|
await password_manager.set_password(user.user_id, email, request.new_password)
|
||||||
@@ -757,6 +775,7 @@ async def logout(request: Request, response: Response):
|
|||||||
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
|
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
|
||||||
|
|
||||||
response.delete_cookie("user_id")
|
response.delete_cookie("user_id")
|
||||||
|
response.delete_cookie("session_token")
|
||||||
response.delete_cookie("session_expire_at")
|
response.delete_cookie("session_expire_at")
|
||||||
return {"message": "退出登录成功"}
|
return {"message": "退出登录成功"}
|
||||||
|
|
||||||
@@ -782,8 +801,6 @@ async def get_password_status(request: Request):
|
|||||||
username = await password_manager.get_username(user.user_id)
|
username = await password_manager.get_username(user.user_id)
|
||||||
|
|
||||||
default_password = None
|
default_password = None
|
||||||
if has_password and not has_custom:
|
|
||||||
default_password = f"{user.username}@666"
|
|
||||||
|
|
||||||
return PasswordStatusResponse(
|
return PasswordStatusResponse(
|
||||||
has_password=has_password,
|
has_password=has_password,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
更新日志API
|
更新日志API
|
||||||
提供GitHub提交历史的缓存和代理服务
|
提供GitHub提交历史的缓存和代理服务
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query, Request, Depends
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import httpx
|
import httpx
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def require_login(request: Request):
|
||||||
|
if not hasattr(request.state, "user") or not request.state.user:
|
||||||
|
raise HTTPException(status_code=401, detail="需要登录")
|
||||||
|
return request.state.user
|
||||||
|
|
||||||
# GitHub API配置
|
# GitHub API配置
|
||||||
GITHUB_API_BASE = "https://api.github.com"
|
GITHUB_API_BASE = "https://api.github.com"
|
||||||
REPO_OWNER = "xiamuceer-j"
|
REPO_OWNER = "xiamuceer-j"
|
||||||
@@ -173,7 +179,7 @@ async def get_changelog(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/changelog/refresh")
|
@router.post("/changelog/refresh")
|
||||||
async def refresh_changelog():
|
async def refresh_changelog(user=Depends(require_login)):
|
||||||
"""
|
"""
|
||||||
刷新更新日志缓存
|
刷新更新日志缓存
|
||||||
|
|
||||||
@@ -230,4 +236,4 @@ async def refresh_changelog():
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"刷新缓存失败: {str(e)}"
|
detail=f"刷新缓存失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,11 +24,22 @@ from app.user_manager import User
|
|||||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||||
from app.services.mcp_test_service import mcp_test_service
|
from app.services.mcp_test_service import mcp_test_service
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.security import validate_public_http_url
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||||
|
|
||||||
|
HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]:
|
||||||
|
if plugin_type in HTTP_PLUGIN_TYPES:
|
||||||
|
if not server_url:
|
||||||
|
raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url")
|
||||||
|
return validate_public_http_url(server_url)
|
||||||
|
return server_url
|
||||||
|
|
||||||
|
|
||||||
def require_login(request: Request) -> User:
|
def require_login(request: Request) -> User:
|
||||||
"""依赖:要求用户已登录"""
|
"""依赖:要求用户已登录"""
|
||||||
@@ -53,7 +64,8 @@ async def _register_plugin_background(
|
|||||||
try:
|
try:
|
||||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||||
|
|
||||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||||
|
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||||
success = await mcp_client.register(MCPPluginConfig(
|
success = await mcp_client.register(MCPPluginConfig(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_name=plugin_name,
|
plugin_name=plugin_name,
|
||||||
@@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
是否注册成功
|
是否注册成功
|
||||||
"""
|
"""
|
||||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||||
|
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||||
return await mcp_client.register(MCPPluginConfig(
|
return await mcp_client.register(MCPPluginConfig(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_name=plugin.plugin_name,
|
plugin_name=plugin.plugin_name,
|
||||||
url=plugin.server_url,
|
url=server_url,
|
||||||
plugin_type=plugin.plugin_type,
|
plugin_type=plugin.plugin_type,
|
||||||
headers=plugin.headers,
|
headers=plugin.headers,
|
||||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||||
@@ -187,6 +200,10 @@ async def create_plugin(
|
|||||||
|
|
||||||
# 创建插件数据
|
# 创建插件数据
|
||||||
plugin_data = data.model_dump()
|
plugin_data = data.model_dump()
|
||||||
|
plugin_data["server_url"] = _validate_mcp_server_url(
|
||||||
|
plugin_data.get("plugin_type", "http"),
|
||||||
|
plugin_data.get("server_url")
|
||||||
|
)
|
||||||
|
|
||||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||||
if not plugin_data.get("display_name"):
|
if not plugin_data.get("display_name"):
|
||||||
@@ -278,12 +295,9 @@ async def create_plugin_simple(
|
|||||||
"sort_order": 0
|
"sort_order": 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if server_type in ["http", "streamable_http", "sse"]:
|
if server_type in HTTP_PLUGIN_TYPES:
|
||||||
plugin_data["server_url"] = server_config.get("url")
|
plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url"))
|
||||||
plugin_data["headers"] = server_config.get("headers", {})
|
plugin_data["headers"] = server_config.get("headers", {})
|
||||||
|
|
||||||
if not plugin_data["server_url"]:
|
|
||||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
|
||||||
|
|
||||||
elif server_type == "stdio":
|
elif server_type == "stdio":
|
||||||
plugin_data["command"] = server_config.get("command")
|
plugin_data["command"] = server_config.get("command")
|
||||||
@@ -415,6 +429,12 @@ async def update_plugin(
|
|||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
target_type = update_data.get("plugin_type", plugin.plugin_type)
|
||||||
|
if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES:
|
||||||
|
update_data["server_url"] = _validate_mcp_server_url(
|
||||||
|
target_type,
|
||||||
|
update_data.get("server_url", plugin.server_url)
|
||||||
|
)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(plugin, key, value)
|
setattr(plugin, key, value)
|
||||||
|
|
||||||
@@ -501,7 +521,8 @@ async def toggle_plugin(
|
|||||||
if enabled:
|
if enabled:
|
||||||
# 启用:注册到统一门面
|
# 启用:注册到统一门面
|
||||||
try:
|
try:
|
||||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||||
|
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||||
success = await mcp_client.register(MCPPluginConfig(
|
success = await mcp_client.register(MCPPluginConfig(
|
||||||
user_id=user.user_id,
|
user_id=user.user_id,
|
||||||
plugin_name=plugin_name,
|
plugin_name=plugin_name,
|
||||||
@@ -647,11 +668,12 @@ async def _ensure_plugin_registered(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用ensure_registered方法,它会检查是否已注册
|
# 使用ensure_registered方法,它会检查是否已注册
|
||||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||||
|
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||||
return await mcp_client.ensure_registered(
|
return await mcp_client.ensure_registered(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_name=plugin.plugin_name,
|
plugin_name=plugin.plugin_name,
|
||||||
url=plugin.server_url,
|
url=server_url,
|
||||||
plugin_type=plugin.plugin_type,
|
plugin_type=plugin.plugin_type,
|
||||||
headers=plugin.headers
|
headers=plugin.headers
|
||||||
)
|
)
|
||||||
@@ -912,4 +934,4 @@ async def call_mcp_tool(
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from app.logger import get_logger
|
|||||||
from app.config import settings as app_settings, PROJECT_ROOT
|
from app.config import settings as app_settings, PROJECT_ROOT
|
||||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
|
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
|
||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
|
from app.security import validate_public_http_url
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -452,7 +453,8 @@ async def delete_settings(
|
|||||||
async def get_available_models(
|
async def get_available_models(
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_base_url: str,
|
api_base_url: str,
|
||||||
provider: str = "openai"
|
provider: str = "openai",
|
||||||
|
user: User = Depends(require_login)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
从配置的 API 获取可用的模型列表
|
从配置的 API 获取可用的模型列表
|
||||||
@@ -467,6 +469,7 @@ async def get_available_models(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
provider = normalize_provider(provider)
|
provider = normalize_provider(provider)
|
||||||
|
api_base_url = validate_public_http_url(api_base_url)
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
if provider == "openai" or provider == "azure" or provider == "custom":
|
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||||
# OpenAI 兼容接口获取模型列表
|
# OpenAI 兼容接口获取模型列表
|
||||||
@@ -1291,4 +1294,4 @@ async def create_preset_from_current(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
|
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
|
||||||
return await create_preset(create_request, user, db)
|
return await create_preset(create_request, user, db)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class SetAdminRequest(BaseModel):
|
|||||||
|
|
||||||
class ResetPasswordRequest(BaseModel):
|
class ResetPasswordRequest(BaseModel):
|
||||||
user_id: str
|
user_id: str
|
||||||
new_password: Optional[str] = None # 如果为空则使用默认密码
|
new_password: Optional[str] = None # 如果为空则由系统生成临时密码
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current")
|
@router.get("/current")
|
||||||
@@ -140,7 +140,7 @@ async def reset_user_password(
|
|||||||
重置用户密码(仅管理员)
|
重置用户密码(仅管理员)
|
||||||
|
|
||||||
如果提供了 new_password,则设置为指定密码
|
如果提供了 new_password,则设置为指定密码
|
||||||
如果未提供 new_password,则重置为默认密码(username@666)
|
如果未提供 new_password,则由系统生成临时密码
|
||||||
|
|
||||||
限制:
|
限制:
|
||||||
- 不能重置自己的密码(应该使用修改密码功能)
|
- 不能重置自己的密码(应该使用修改密码功能)
|
||||||
@@ -162,10 +162,15 @@ async def reset_user_password(
|
|||||||
|
|
||||||
# 重置密码
|
# 重置密码
|
||||||
try:
|
try:
|
||||||
actual_password = await password_manager.set_password(
|
generated_password = data.new_password
|
||||||
|
if not generated_password:
|
||||||
|
import secrets
|
||||||
|
generated_password = secrets.token_urlsafe(12)
|
||||||
|
|
||||||
|
await password_manager.set_password(
|
||||||
target_user.user_id,
|
target_user.user_id,
|
||||||
target_user.username,
|
target_user.username,
|
||||||
data.new_password
|
generated_password
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果使用了默认密码,返回密码供管理员告知用户
|
# 如果使用了默认密码,返回密码供管理员告知用户
|
||||||
@@ -177,8 +182,8 @@ async def reset_user_password(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if not data.new_password:
|
if not data.new_password:
|
||||||
response_data["default_password"] = actual_password
|
response_data["temporary_password"] = generated_password
|
||||||
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
|
response_data["message"] = "密码已重置为系统生成的临时密码,请尽快通知用户修改"
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
|
|
||||||
@@ -186,4 +191,4 @@ async def reset_user_password(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail=f"重置密码失败: {str(e)}"
|
detail=f"重置密码失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,6 +26,18 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None:
|
||||||
|
if not project_id or not user_id:
|
||||||
|
return None
|
||||||
|
result = await db.execute(
|
||||||
|
select(Project).where(
|
||||||
|
Project.id == project_id,
|
||||||
|
Project.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def world_building_generator(
|
async def world_building_generator(
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -326,12 +338,9 @@ async def career_system_generator(
|
|||||||
|
|
||||||
# 获取项目信息
|
# 获取项目信息
|
||||||
yield await tracker.loading("加载项目信息...")
|
yield await tracker.loading("加载项目信息...")
|
||||||
result = await db.execute(
|
project = await get_owned_project(db, project_id, user_id)
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
if not project:
|
||||||
yield await tracker.error("项目不存在", 404)
|
yield await tracker.error("项目不存在或无权访问", 404)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 设置用户信息以启用MCP
|
# 设置用户信息以启用MCP
|
||||||
@@ -599,12 +608,9 @@ async def characters_generator(
|
|||||||
|
|
||||||
# 验证项目
|
# 验证项目
|
||||||
yield await tracker.loading("验证项目...", 0.3)
|
yield await tracker.loading("验证项目...", 0.3)
|
||||||
result = await db.execute(
|
project = await get_owned_project(db, project_id, user_id)
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
if not project:
|
||||||
yield await tracker.error("项目不存在", 404)
|
yield await tracker.error("项目不存在或无权访问", 404)
|
||||||
return
|
return
|
||||||
|
|
||||||
project.wizard_step = 2
|
project.wizard_step = 2
|
||||||
@@ -1270,12 +1276,9 @@ async def outline_generator(
|
|||||||
|
|
||||||
# 获取项目信息
|
# 获取项目信息
|
||||||
yield await tracker.loading("加载项目信息...", 0.3)
|
yield await tracker.loading("加载项目信息...", 0.3)
|
||||||
result = await db.execute(
|
project = await get_owned_project(db, project_id, user_id)
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
if not project:
|
||||||
yield await tracker.error("项目不存在", 404)
|
yield await tracker.error("项目不存在或无权访问", 404)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 获取角色信息
|
# 获取角色信息
|
||||||
@@ -1551,12 +1554,9 @@ async def world_building_regenerate_generator(
|
|||||||
|
|
||||||
# 获取项目信息
|
# 获取项目信息
|
||||||
yield await tracker.loading("加载项目信息...")
|
yield await tracker.loading("加载项目信息...")
|
||||||
result = await db.execute(
|
project = await get_owned_project(db, project_id, user_id)
|
||||||
select(Project).where(Project.id == project_id)
|
|
||||||
)
|
|
||||||
project = result.scalar_one_or_none()
|
|
||||||
if not project:
|
if not project:
|
||||||
yield await tracker.error("项目不存在", 404)
|
yield await tracker.error("项目不存在或无权访问", 404)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 提取参数
|
# 提取参数
|
||||||
|
|||||||
@@ -275,15 +275,19 @@ async def get_project_styles(
|
|||||||
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
||||||
async def get_writing_style(
|
async def get_writing_style(
|
||||||
style_id: int,
|
style_id: int,
|
||||||
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""获取单个写作风格详情"""
|
"""获取单个写作风格详情"""
|
||||||
|
user_id = get_current_user_id(request)
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||||
)
|
)
|
||||||
style = result.scalar_one_or_none()
|
style = result.scalar_one_or_none()
|
||||||
if not style:
|
if not style:
|
||||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||||
|
if style.user_id is not None and style.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||||
|
|
||||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -501,4 +505,4 @@ async def initialize_default_styles(
|
|||||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||||
"""
|
"""
|
||||||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||||||
return await get_project_styles(project_id, request, db)
|
return await get_project_styles(project_id, request, db)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class Settings(BaseSettings):
|
|||||||
app_version: str = "1.0.0"
|
app_version: str = "1.0.0"
|
||||||
app_host: str = "0.0.0.0"
|
app_host: str = "0.0.0.0"
|
||||||
app_port: int = 8000
|
app_port: int = 8000
|
||||||
debug: bool = True
|
debug: bool = False
|
||||||
|
|
||||||
# 日志配置
|
# 日志配置
|
||||||
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||||
@@ -106,6 +106,7 @@ class Settings(BaseSettings):
|
|||||||
# 会话配置
|
# 会话配置
|
||||||
SESSION_EXPIRE_MINUTES: int = 120 # 会话过期时间(分钟),默认2小时
|
SESSION_EXPIRE_MINUTES: int = 120 # 会话过期时间(分钟),默认2小时
|
||||||
SESSION_REFRESH_THRESHOLD_MINUTES: int = 30 # 会话刷新阈值(分钟),剩余时间少于此值时可刷新
|
SESSION_REFRESH_THRESHOLD_MINUTES: int = 30 # 会话刷新阈值(分钟),剩余时间少于此值时可刷新
|
||||||
|
SESSION_SECRET_KEY: Optional[str] = None # 会话签名密钥,生产环境必须配置为高强度随机值
|
||||||
|
|
||||||
# 系统 SMTP 默认配置(可被管理员系统设置覆盖)
|
# 系统 SMTP 默认配置(可被管理员系统设置覆盖)
|
||||||
SMTP_PROVIDER: str = "qq"
|
SMTP_PROVIDER: str = "qq"
|
||||||
|
|||||||
+17
-5
@@ -1,5 +1,5 @@
|
|||||||
"""FastAPI应用主入口"""
|
"""FastAPI应用主入口"""
|
||||||
from fastapi import FastAPI, Request, status
|
from fastapi import FastAPI, Request, status, HTTPException, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import JSONResponse, FileResponse
|
from fastapi.responses import JSONResponse, FileResponse
|
||||||
@@ -106,7 +106,7 @@ async def health_check():
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/health/db-sessions")
|
@app.get("/health/db-sessions")
|
||||||
async def db_session_stats():
|
async def db_session_stats(request: Request):
|
||||||
"""
|
"""
|
||||||
数据库会话统计(监控连接泄漏)
|
数据库会话统计(监控连接泄漏)
|
||||||
|
|
||||||
@@ -118,6 +118,8 @@ async def db_session_stats():
|
|||||||
- generator_exits: SSE断开次数
|
- generator_exits: SSE断开次数
|
||||||
- last_check: 最后检查时间
|
- last_check: 最后检查时间
|
||||||
"""
|
"""
|
||||||
|
if not getattr(request.state, "is_admin", False):
|
||||||
|
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"session_stats": _session_stats,
|
"session_stats": _session_stats,
|
||||||
@@ -176,8 +178,18 @@ if static_dir.exists():
|
|||||||
)
|
)
|
||||||
|
|
||||||
file_path = static_dir / full_path
|
file_path = static_dir / full_path
|
||||||
if file_path.is_file():
|
try:
|
||||||
return FileResponse(file_path)
|
resolved_file = file_path.resolve()
|
||||||
|
resolved_static = static_dir.resolve()
|
||||||
|
resolved_file.relative_to(resolved_static)
|
||||||
|
except ValueError:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={"detail": "页面不存在"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if resolved_file.is_file():
|
||||||
|
return FileResponse(resolved_file)
|
||||||
|
|
||||||
index_file = static_dir / "index.html"
|
index_file = static_dir / "index.html"
|
||||||
if index_file.exists():
|
if index_file.exists():
|
||||||
@@ -207,4 +219,4 @@ if __name__ == "__main__":
|
|||||||
host=config_settings.app_host,
|
host=config_settings.app_host,
|
||||||
port=config_settings.app_port,
|
port=config_settings.app_port,
|
||||||
reload=config_settings.debug
|
reload=config_settings.debug
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from fastapi import Request
|
|||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from app.user_manager import user_manager
|
from app.user_manager import user_manager
|
||||||
from app.logger import get_logger
|
from app.logger import get_logger
|
||||||
|
from app.security import verify_session_token
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -46,8 +47,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
request.state.is_proxy_request = False
|
request.state.is_proxy_request = False
|
||||||
request.state.proxy_instance_id = None
|
request.state.proxy_instance_id = None
|
||||||
|
|
||||||
# 从 Cookie 中获取用户 ID
|
# 优先验证签名会话 Cookie;不再信任客户端可伪造的明文 user_id。
|
||||||
user_id = request.cookies.get("user_id")
|
user_id = verify_session_token(request.cookies.get("session_token"))
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
user = await user_manager.get_user(user_id)
|
user = await user_manager.get_user(user_id)
|
||||||
@@ -77,4 +78,4 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# 继续处理请求
|
# 继续处理请求
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class UserPassword(Base):
|
|||||||
|
|
||||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
|
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
|
||||||
username = Column(String(100), nullable=False, comment="用户名")
|
username = Column(String(100), nullable=False, comment="用户名")
|
||||||
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256)")
|
password_hash = Column(String(255), nullable=False, comment="密码哈希")
|
||||||
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
|
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
"""Security helpers for sessions and outbound URL validation."""
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import socket
|
||||||
|
import time
|
||||||
|
from typing import Iterable
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _session_secret() -> bytes:
|
||||||
|
secret = (
|
||||||
|
getattr(settings, "SESSION_SECRET_KEY", None)
|
||||||
|
or getattr(settings, "session_secret_key", None)
|
||||||
|
or settings.LINUXDO_CLIENT_SECRET
|
||||||
|
or settings.LOCAL_AUTH_PASSWORD
|
||||||
|
or settings.openai_api_key
|
||||||
|
)
|
||||||
|
if not secret:
|
||||||
|
secret = "mumuainovel-development-session-secret"
|
||||||
|
return str(secret).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_token(user_id: str, max_age_seconds: int) -> str:
|
||||||
|
payload = {
|
||||||
|
"uid": user_id,
|
||||||
|
"exp": int(time.time()) + max_age_seconds,
|
||||||
|
"nonce": secrets.token_urlsafe(16),
|
||||||
|
}
|
||||||
|
payload_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||||
|
payload_b64 = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode("ascii")
|
||||||
|
signature = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
|
||||||
|
signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii")
|
||||||
|
return f"{payload_b64}.{signature_b64}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_session_token(token: str) -> str | None:
|
||||||
|
if not token or "." not in token:
|
||||||
|
return None
|
||||||
|
payload_b64, signature_b64 = token.split(".", 1)
|
||||||
|
expected = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
|
||||||
|
try:
|
||||||
|
provided = base64.urlsafe_b64decode(signature_b64 + "=" * (-len(signature_b64) % 4))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if not hmac.compare_digest(expected, provided):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
payload_raw = base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4))
|
||||||
|
payload = json.loads(payload_raw.decode("utf-8"))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if int(payload.get("exp", 0)) < int(time.time()):
|
||||||
|
return None
|
||||||
|
user_id = payload.get("uid")
|
||||||
|
return user_id if isinstance(user_id, str) and user_id else None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_forbidden_ip(ip: ipaddress._BaseAddress) -> bool:
|
||||||
|
return any([
|
||||||
|
ip.is_private,
|
||||||
|
ip.is_loopback,
|
||||||
|
ip.is_link_local,
|
||||||
|
ip.is_multicast,
|
||||||
|
ip.is_reserved,
|
||||||
|
ip.is_unspecified,
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_public_http_url(raw_url: str, *, allowed_schemes: Iterable[str] = ("https", "http")) -> str:
|
||||||
|
"""Validate an outbound URL to reduce SSRF risk."""
|
||||||
|
if not raw_url or not isinstance(raw_url, str):
|
||||||
|
raise HTTPException(status_code=400, detail="URL不能为空")
|
||||||
|
|
||||||
|
parsed = urlparse(raw_url.strip())
|
||||||
|
if parsed.scheme not in set(allowed_schemes):
|
||||||
|
raise HTTPException(status_code=400, detail="仅支持 HTTP/HTTPS URL")
|
||||||
|
if not parsed.hostname:
|
||||||
|
raise HTTPException(status_code=400, detail="URL缺少主机名")
|
||||||
|
if parsed.username or parsed.password:
|
||||||
|
raise HTTPException(status_code=400, detail="URL不允许包含认证信息")
|
||||||
|
|
||||||
|
host = parsed.hostname.strip().rstrip(".")
|
||||||
|
if host.lower() in {"localhost", "localhost.localdomain"}:
|
||||||
|
raise HTTPException(status_code=400, detail="URL不允许指向本机地址")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(host)
|
||||||
|
if _is_forbidden_ip(ip):
|
||||||
|
raise HTTPException(status_code=400, detail="URL不允许指向内网或保留地址")
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
infos = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80), type=socket.SOCK_STREAM)
|
||||||
|
except socket.gaierror:
|
||||||
|
raise HTTPException(status_code=400, detail="URL主机名无法解析")
|
||||||
|
for info in infos:
|
||||||
|
resolved_ip = ipaddress.ip_address(info[4][0])
|
||||||
|
if _is_forbidden_ip(resolved_ip):
|
||||||
|
raise HTTPException(status_code=400, detail="URL解析到内网或保留地址")
|
||||||
|
|
||||||
|
return raw_url.strip().rstrip("/")
|
||||||
@@ -38,7 +38,7 @@ class WorkshopClient:
|
|||||||
url = f"{self.base_url}/api/prompt-workshop{path}"
|
url = f"{self.base_url}/api/prompt-workshop{path}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
response = await client.request(
|
response = await client.request(
|
||||||
method=method,
|
method=method,
|
||||||
url=url,
|
url=url,
|
||||||
@@ -173,4 +173,4 @@ class WorkshopClient:
|
|||||||
|
|
||||||
|
|
||||||
# 全局客户端实例
|
# 全局客户端实例
|
||||||
workshop_client = WorkshopClient()
|
workshop_client = WorkshopClient()
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import secrets
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -34,7 +36,28 @@ class UserPasswordManager:
|
|||||||
|
|
||||||
def _hash_password(self, password: str) -> str:
|
def _hash_password(self, password: str) -> str:
|
||||||
"""密码哈希"""
|
"""密码哈希"""
|
||||||
return hashlib.sha256(password.encode()).hexdigest()
|
salt = secrets.token_hex(16)
|
||||||
|
iterations = 260000
|
||||||
|
digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), iterations).hex()
|
||||||
|
return f"pbkdf2_sha256${iterations}${salt}${digest}"
|
||||||
|
|
||||||
|
def _verify_hash(self, password: str, stored_hash: str) -> bool:
|
||||||
|
if stored_hash.startswith("pbkdf2_sha256$"):
|
||||||
|
try:
|
||||||
|
_, iterations, salt, digest = stored_hash.split("$", 3)
|
||||||
|
candidate = hashlib.pbkdf2_hmac(
|
||||||
|
"sha256",
|
||||||
|
password.encode(),
|
||||||
|
salt.encode(),
|
||||||
|
int(iterations),
|
||||||
|
).hex()
|
||||||
|
return hmac.compare_digest(candidate, digest)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Legacy unsalted SHA-256 hash support for existing deployments.
|
||||||
|
legacy_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||||
|
return hmac.compare_digest(legacy_hash, stored_hash)
|
||||||
|
|
||||||
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
|
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -104,8 +127,12 @@ class UserPasswordManager:
|
|||||||
if not pwd_record:
|
if not pwd_record:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
password_hash = self._hash_password(password)
|
verified = self._verify_hash(password, pwd_record.password_hash)
|
||||||
return pwd_record.password_hash == password_hash
|
if verified and not pwd_record.password_hash.startswith("pbkdf2_sha256$"):
|
||||||
|
pwd_record.password_hash = self._hash_password(password)
|
||||||
|
pwd_record.updated_at = datetime.now()
|
||||||
|
await session.commit()
|
||||||
|
return verified
|
||||||
|
|
||||||
async def has_password(self, user_id: str) -> bool:
|
async def has_password(self, user_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -175,4 +202,4 @@ class UserPasswordManager:
|
|||||||
|
|
||||||
|
|
||||||
# 全局密码管理器实例
|
# 全局密码管理器实例
|
||||||
password_manager = UserPasswordManager()
|
password_manager = UserPasswordManager()
|
||||||
|
|||||||
Reference in New Issue
Block a user