update: 修复基于长亭monkeycode扫描结果的12处安全漏洞
This commit is contained in:
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
import secrets
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
@@ -46,7 +47,7 @@ class ToggleStatusRequest(BaseModel):
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
"""重置密码请求"""
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则生成临时密码")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
@@ -308,10 +309,14 @@ async def reset_password(
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 重置密码
|
||||
actual_password = await password_manager.set_password(
|
||||
generated_password = data.new_password
|
||||
if not generated_password:
|
||||
generated_password = secrets.token_urlsafe(12)
|
||||
|
||||
await password_manager.set_password(
|
||||
user_id=user_id,
|
||||
username=target_user.username,
|
||||
password=data.new_password
|
||||
password=generated_password
|
||||
)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
|
||||
@@ -319,7 +324,7 @@ async def reset_password(
|
||||
return {
|
||||
"success": True,
|
||||
"message": "密码重置成功",
|
||||
"new_password": actual_password
|
||||
"temporary_password": generated_password if not data.new_password else None
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -385,4 +390,4 @@ async def delete_user(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||
|
||||
+25
-8
@@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
import random
|
||||
import secrets
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy import select
|
||||
@@ -21,6 +21,7 @@ from app.database import get_engine
|
||||
from app.models.user import User as UserModel
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
from app.services.email_service import email_service
|
||||
from app.security import create_session_token
|
||||
|
||||
# 中国时区 UTC+8
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
@@ -43,6 +44,7 @@ _state_storage = {}
|
||||
|
||||
# 邮箱验证码临时存储(生产环境应使用 Redis)
|
||||
_email_verification_storage = {}
|
||||
MAX_VERIFICATION_ATTEMPTS = 5
|
||||
|
||||
EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$")
|
||||
|
||||
@@ -246,12 +248,14 @@ def _validate_password(password: str):
|
||||
def _set_login_cookies(response: Response, user_id: str):
|
||||
"""设置登录 Cookie"""
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
session_token = create_session_token(user_id, max_age)
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user_id,
|
||||
key="session_token",
|
||||
value=session_token,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
samesite="lax",
|
||||
secure=not settings.debug,
|
||||
)
|
||||
|
||||
china_now = get_china_now()
|
||||
@@ -263,12 +267,13 @@ def _set_login_cookies(response: Response, user_id: str):
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False,
|
||||
samesite="lax"
|
||||
samesite="lax",
|
||||
secure=not settings.debug,
|
||||
)
|
||||
|
||||
|
||||
def _generate_verification_code() -> str:
|
||||
return f"{random.randint(0, 999999):06d}"
|
||||
return f"{secrets.randbelow(1000000):06d}"
|
||||
|
||||
|
||||
def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]:
|
||||
@@ -456,6 +461,7 @@ async def send_email_verification_code(request: EmailSendCodeRequest):
|
||||
"code": code,
|
||||
"expires_at": expires_at,
|
||||
"last_sent_at": now,
|
||||
"attempts": 0,
|
||||
}
|
||||
|
||||
logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}")
|
||||
@@ -493,6 +499,10 @@ async def email_register(request: EmailRegisterRequest, response: Response):
|
||||
raise HTTPException(status_code=400, detail="验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(_get_verification_storage_key("register", email), None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="验证码错误")
|
||||
|
||||
existing_user = await _find_user_by_email(email)
|
||||
@@ -541,6 +551,10 @@ async def email_login(request: EmailLoginRequest, response: Response):
|
||||
raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="登录验证码错误")
|
||||
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
@@ -588,6 +602,10 @@ async def email_reset_password(request: EmailResetPasswordRequest):
|
||||
raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="重置密码验证码错误")
|
||||
|
||||
await password_manager.set_password(user.user_id, email, request.new_password)
|
||||
@@ -757,6 +775,7 @@ async def logout(request: Request, response: Response):
|
||||
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
|
||||
|
||||
response.delete_cookie("user_id")
|
||||
response.delete_cookie("session_token")
|
||||
response.delete_cookie("session_expire_at")
|
||||
return {"message": "退出登录成功"}
|
||||
|
||||
@@ -782,8 +801,6 @@ async def get_password_status(request: Request):
|
||||
username = await password_manager.get_username(user.user_id)
|
||||
|
||||
default_password = None
|
||||
if has_password and not has_custom:
|
||||
default_password = f"{user.username}@666"
|
||||
|
||||
return PasswordStatusResponse(
|
||||
has_password=has_password,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
更新日志API
|
||||
提供GitHub提交历史的缓存和代理服务
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, Depends
|
||||
from typing import List, Optional
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def require_login(request: Request):
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
# GitHub API配置
|
||||
GITHUB_API_BASE = "https://api.github.com"
|
||||
REPO_OWNER = "xiamuceer-j"
|
||||
@@ -173,7 +179,7 @@ async def get_changelog(
|
||||
|
||||
|
||||
@router.post("/changelog/refresh")
|
||||
async def refresh_changelog():
|
||||
async def refresh_changelog(user=Depends(require_login)):
|
||||
"""
|
||||
刷新更新日志缓存
|
||||
|
||||
@@ -230,4 +236,4 @@ async def refresh_changelog():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"刷新缓存失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -24,11 +24,22 @@ from app.user_manager import User
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.logger import get_logger
|
||||
from app.security import validate_public_http_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||
|
||||
HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"}
|
||||
|
||||
|
||||
def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES:
|
||||
if not server_url:
|
||||
raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url")
|
||||
return validate_public_http_url(server_url)
|
||||
return server_url
|
||||
|
||||
|
||||
def require_login(request: Request) -> User:
|
||||
"""依赖:要求用户已登录"""
|
||||
@@ -53,7 +64,8 @@ async def _register_plugin_background(
|
||||
try:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
@@ -187,6 +200,10 @@ async def create_plugin(
|
||||
|
||||
# 创建插件数据
|
||||
plugin_data = data.model_dump()
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(
|
||||
plugin_data.get("plugin_type", "http"),
|
||||
plugin_data.get("server_url")
|
||||
)
|
||||
|
||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||
if not plugin_data.get("display_name"):
|
||||
@@ -278,12 +295,9 @@ async def create_plugin_simple(
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
if server_type in HTTP_PLUGIN_TYPES:
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url"))
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
@@ -415,6 +429,12 @@ async def update_plugin(
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
target_type = update_data.get("plugin_type", plugin.plugin_type)
|
||||
if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES:
|
||||
update_data["server_url"] = _validate_mcp_server_url(
|
||||
target_type,
|
||||
update_data.get("server_url", plugin.server_url)
|
||||
)
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
@@ -501,7 +521,8 @@ async def toggle_plugin(
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -647,11 +668,12 @@ async def _ensure_plugin_registered(
|
||||
"""
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
@@ -912,4 +934,4 @@ async def call_mcp_tool(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
|
||||
from app.services.email_service import email_service
|
||||
from app.security import validate_public_http_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -452,7 +453,8 @@ async def delete_settings(
|
||||
async def get_available_models(
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
provider: str = "openai"
|
||||
provider: str = "openai",
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
从配置的 API 获取可用的模型列表
|
||||
@@ -467,6 +469,7 @@ async def get_available_models(
|
||||
"""
|
||||
try:
|
||||
provider = normalize_provider(provider)
|
||||
api_base_url = validate_public_http_url(api_base_url)
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||
# OpenAI 兼容接口获取模型列表
|
||||
@@ -1291,4 +1294,4 @@ async def create_preset_from_current(
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
|
||||
return await create_preset(create_request, user, db)
|
||||
return await create_preset(create_request, user, db)
|
||||
|
||||
@@ -32,7 +32,7 @@ class SetAdminRequest(BaseModel):
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
user_id: str
|
||||
new_password: Optional[str] = None # 如果为空则使用默认密码
|
||||
new_password: Optional[str] = None # 如果为空则由系统生成临时密码
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
@@ -140,7 +140,7 @@ async def reset_user_password(
|
||||
重置用户密码(仅管理员)
|
||||
|
||||
如果提供了 new_password,则设置为指定密码
|
||||
如果未提供 new_password,则重置为默认密码(username@666)
|
||||
如果未提供 new_password,则由系统生成临时密码
|
||||
|
||||
限制:
|
||||
- 不能重置自己的密码(应该使用修改密码功能)
|
||||
@@ -162,10 +162,15 @@ async def reset_user_password(
|
||||
|
||||
# 重置密码
|
||||
try:
|
||||
actual_password = await password_manager.set_password(
|
||||
generated_password = data.new_password
|
||||
if not generated_password:
|
||||
import secrets
|
||||
generated_password = secrets.token_urlsafe(12)
|
||||
|
||||
await password_manager.set_password(
|
||||
target_user.user_id,
|
||||
target_user.username,
|
||||
data.new_password
|
||||
generated_password
|
||||
)
|
||||
|
||||
# 如果使用了默认密码,返回密码供管理员告知用户
|
||||
@@ -177,8 +182,8 @@ async def reset_user_password(
|
||||
}
|
||||
|
||||
if not data.new_password:
|
||||
response_data["default_password"] = actual_password
|
||||
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
|
||||
response_data["temporary_password"] = generated_password
|
||||
response_data["message"] = "密码已重置为系统生成的临时密码,请尽快通知用户修改"
|
||||
|
||||
return response_data
|
||||
|
||||
@@ -186,4 +191,4 @@ async def reset_user_password(
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"重置密码失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -26,6 +26,18 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None:
|
||||
if not project_id or not user_id:
|
||||
return None
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def world_building_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
@@ -326,12 +338,9 @@ async def career_system_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 设置用户信息以启用MCP
|
||||
@@ -599,12 +608,9 @@ async def characters_generator(
|
||||
|
||||
# 验证项目
|
||||
yield await tracker.loading("验证项目...", 0.3)
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
project.wizard_step = 2
|
||||
@@ -1270,12 +1276,9 @@ async def outline_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...", 0.3)
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 获取角色信息
|
||||
@@ -1551,12 +1554,9 @@ async def world_building_regenerate_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 提取参数
|
||||
|
||||
@@ -275,15 +275,19 @@ async def get_project_styles(
|
||||
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
||||
async def get_writing_style(
|
||||
style_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取单个写作风格详情"""
|
||||
user_id = get_current_user_id(request)
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
if style.user_id is not None and style.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||
result = await db.execute(
|
||||
@@ -501,4 +505,4 @@ async def initialize_default_styles(
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
|
||||
Reference in New Issue
Block a user