update: 修复基于长亭monkeycode扫描结果的12处安全漏洞
This commit is contained in:
+42
@@ -0,0 +1,42 @@
|
||||
"""expand password hash length
|
||||
|
||||
Revision ID: 9a1b2c3d4e5f
|
||||
Revises: 6eb27fce64de
|
||||
Create Date: 2026-04-24 10:06:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9a1b2c3d4e5f'
|
||||
down_revision: Union[str, None] = '6eb27fce64de'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
'user_passwords',
|
||||
'password_hash',
|
||||
existing_type=sa.String(length=64),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=False,
|
||||
existing_comment='密码哈希(SHA256)',
|
||||
comment='密码哈希',
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
'user_passwords',
|
||||
'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(length=64),
|
||||
existing_nullable=False,
|
||||
existing_comment='密码哈希',
|
||||
comment='密码哈希(SHA256)',
|
||||
)
|
||||
+38
@@ -0,0 +1,38 @@
|
||||
"""expand password hash length
|
||||
|
||||
Revision ID: ab12cd34ef56
|
||||
Revises: 6ff45db05863
|
||||
Create Date: 2026-04-24 10:06:00
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ab12cd34ef56'
|
||||
down_revision: Union[str, None] = '6ff45db05863'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'password_hash',
|
||||
existing_type=sa.String(length=64),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table('user_passwords', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(length=64),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
import secrets
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
@@ -46,7 +47,7 @@ class ToggleStatusRequest(BaseModel):
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
"""重置密码请求"""
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则重置为默认密码")
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,留空则生成临时密码")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
@@ -308,10 +309,14 @@ async def reset_password(
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 重置密码
|
||||
actual_password = await password_manager.set_password(
|
||||
generated_password = data.new_password
|
||||
if not generated_password:
|
||||
generated_password = secrets.token_urlsafe(12)
|
||||
|
||||
await password_manager.set_password(
|
||||
user_id=user_id,
|
||||
username=target_user.username,
|
||||
password=data.new_password
|
||||
password=generated_password
|
||||
)
|
||||
|
||||
logger.info(f"管理员 {admin.user_id} 重置了用户 {user_id} 的密码")
|
||||
@@ -319,7 +324,7 @@ async def reset_password(
|
||||
return {
|
||||
"success": True,
|
||||
"message": "密码重置成功",
|
||||
"new_password": actual_password
|
||||
"temporary_password": generated_password if not data.new_password else None
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -385,4 +390,4 @@ async def delete_user(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {str(e)}")
|
||||
|
||||
+25
-8
@@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
import random
|
||||
import secrets
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy import select
|
||||
@@ -21,6 +21,7 @@ from app.database import get_engine
|
||||
from app.models.user import User as UserModel
|
||||
from app.models.settings import Settings as SettingsModel
|
||||
from app.services.email_service import email_service
|
||||
from app.security import create_session_token
|
||||
|
||||
# 中国时区 UTC+8
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
@@ -43,6 +44,7 @@ _state_storage = {}
|
||||
|
||||
# 邮箱验证码临时存储(生产环境应使用 Redis)
|
||||
_email_verification_storage = {}
|
||||
MAX_VERIFICATION_ATTEMPTS = 5
|
||||
|
||||
EMAIL_REGEX = re.compile(r"^[^\s@]+@[^\s@]+\.[^\s@]+$")
|
||||
|
||||
@@ -246,12 +248,14 @@ def _validate_password(password: str):
|
||||
def _set_login_cookies(response: Response, user_id: str):
|
||||
"""设置登录 Cookie"""
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
session_token = create_session_token(user_id, max_age)
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user_id,
|
||||
key="session_token",
|
||||
value=session_token,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
samesite="lax",
|
||||
secure=not settings.debug,
|
||||
)
|
||||
|
||||
china_now = get_china_now()
|
||||
@@ -263,12 +267,13 @@ def _set_login_cookies(response: Response, user_id: str):
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False,
|
||||
samesite="lax"
|
||||
samesite="lax",
|
||||
secure=not settings.debug,
|
||||
)
|
||||
|
||||
|
||||
def _generate_verification_code() -> str:
|
||||
return f"{random.randint(0, 999999):06d}"
|
||||
return f"{secrets.randbelow(1000000):06d}"
|
||||
|
||||
|
||||
def _build_verification_mail_content(scene: str, code: str, ttl_minutes: int) -> tuple[str, str, str]:
|
||||
@@ -456,6 +461,7 @@ async def send_email_verification_code(request: EmailSendCodeRequest):
|
||||
"code": code,
|
||||
"expires_at": expires_at,
|
||||
"last_sent_at": now,
|
||||
"attempts": 0,
|
||||
}
|
||||
|
||||
logger.info(f"[邮箱验证码] 场景={scene} 已发送到 {email}")
|
||||
@@ -493,6 +499,10 @@ async def email_register(request: EmailRegisterRequest, response: Response):
|
||||
raise HTTPException(status_code=400, detail="验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(_get_verification_storage_key("register", email), None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="验证码错误")
|
||||
|
||||
existing_user = await _find_user_by_email(email)
|
||||
@@ -541,6 +551,10 @@ async def email_login(request: EmailLoginRequest, response: Response):
|
||||
raise HTTPException(status_code=400, detail="登录验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="登录验证码错误")
|
||||
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
@@ -588,6 +602,10 @@ async def email_reset_password(request: EmailResetPasswordRequest):
|
||||
raise HTTPException(status_code=400, detail="重置密码验证码已过期,请重新发送")
|
||||
|
||||
if cached["code"] != code:
|
||||
cached["attempts"] = cached.get("attempts", 0) + 1
|
||||
if cached["attempts"] >= MAX_VERIFICATION_ATTEMPTS:
|
||||
_email_verification_storage.pop(storage_key, None)
|
||||
raise HTTPException(status_code=429, detail="验证码错误次数过多,请重新发送")
|
||||
raise HTTPException(status_code=400, detail="重置密码验证码错误")
|
||||
|
||||
await password_manager.set_password(user.user_id, email, request.new_password)
|
||||
@@ -757,6 +775,7 @@ async def logout(request: Request, response: Response):
|
||||
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
|
||||
|
||||
response.delete_cookie("user_id")
|
||||
response.delete_cookie("session_token")
|
||||
response.delete_cookie("session_expire_at")
|
||||
return {"message": "退出登录成功"}
|
||||
|
||||
@@ -782,8 +801,6 @@ async def get_password_status(request: Request):
|
||||
username = await password_manager.get_username(user.user_id)
|
||||
|
||||
default_password = None
|
||||
if has_password and not has_custom:
|
||||
default_password = f"{user.username}@666"
|
||||
|
||||
return PasswordStatusResponse(
|
||||
has_password=has_password,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
更新日志API
|
||||
提供GitHub提交历史的缓存和代理服务
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, Depends
|
||||
from typing import List, Optional
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
@@ -13,6 +13,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def require_login(request: Request):
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="需要登录")
|
||||
return request.state.user
|
||||
|
||||
# GitHub API配置
|
||||
GITHUB_API_BASE = "https://api.github.com"
|
||||
REPO_OWNER = "xiamuceer-j"
|
||||
@@ -173,7 +179,7 @@ async def get_changelog(
|
||||
|
||||
|
||||
@router.post("/changelog/refresh")
|
||||
async def refresh_changelog():
|
||||
async def refresh_changelog(user=Depends(require_login)):
|
||||
"""
|
||||
刷新更新日志缓存
|
||||
|
||||
@@ -230,4 +236,4 @@ async def refresh_changelog():
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"刷新缓存失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -24,11 +24,22 @@ from app.user_manager import User
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.logger import get_logger
|
||||
from app.security import validate_public_http_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||
|
||||
HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"}
|
||||
|
||||
|
||||
def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES:
|
||||
if not server_url:
|
||||
raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url")
|
||||
return validate_public_http_url(server_url)
|
||||
return server_url
|
||||
|
||||
|
||||
def require_login(request: Request) -> User:
|
||||
"""依赖:要求用户已登录"""
|
||||
@@ -53,7 +64,8 @@ async def _register_plugin_background(
|
||||
try:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
@@ -187,6 +200,10 @@ async def create_plugin(
|
||||
|
||||
# 创建插件数据
|
||||
plugin_data = data.model_dump()
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(
|
||||
plugin_data.get("plugin_type", "http"),
|
||||
plugin_data.get("server_url")
|
||||
)
|
||||
|
||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||
if not plugin_data.get("display_name"):
|
||||
@@ -278,12 +295,9 @@ async def create_plugin_simple(
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
if server_type in HTTP_PLUGIN_TYPES:
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url"))
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
@@ -415,6 +429,12 @@ async def update_plugin(
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
target_type = update_data.get("plugin_type", plugin.plugin_type)
|
||||
if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES:
|
||||
update_data["server_url"] = _validate_mcp_server_url(
|
||||
target_type,
|
||||
update_data.get("server_url", plugin.server_url)
|
||||
)
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
@@ -501,7 +521,8 @@ async def toggle_plugin(
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -647,11 +668,12 @@ async def _ensure_plugin_registered(
|
||||
"""
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
@@ -912,4 +934,4 @@ async def call_mcp_tool(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.logger import get_logger
|
||||
from app.config import settings as app_settings, PROJECT_ROOT
|
||||
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp, normalize_provider
|
||||
from app.services.email_service import email_service
|
||||
from app.security import validate_public_http_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -452,7 +453,8 @@ async def delete_settings(
|
||||
async def get_available_models(
|
||||
api_key: str,
|
||||
api_base_url: str,
|
||||
provider: str = "openai"
|
||||
provider: str = "openai",
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
从配置的 API 获取可用的模型列表
|
||||
@@ -467,6 +469,7 @@ async def get_available_models(
|
||||
"""
|
||||
try:
|
||||
provider = normalize_provider(provider)
|
||||
api_base_url = validate_public_http_url(api_base_url)
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
if provider == "openai" or provider == "azure" or provider == "custom":
|
||||
# OpenAI 兼容接口获取模型列表
|
||||
@@ -1291,4 +1294,4 @@ async def create_preset_from_current(
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 从当前配置创建预设: {name}")
|
||||
return await create_preset(create_request, user, db)
|
||||
return await create_preset(create_request, user, db)
|
||||
|
||||
@@ -32,7 +32,7 @@ class SetAdminRequest(BaseModel):
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
user_id: str
|
||||
new_password: Optional[str] = None # 如果为空则使用默认密码
|
||||
new_password: Optional[str] = None # 如果为空则由系统生成临时密码
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
@@ -140,7 +140,7 @@ async def reset_user_password(
|
||||
重置用户密码(仅管理员)
|
||||
|
||||
如果提供了 new_password,则设置为指定密码
|
||||
如果未提供 new_password,则重置为默认密码(username@666)
|
||||
如果未提供 new_password,则由系统生成临时密码
|
||||
|
||||
限制:
|
||||
- 不能重置自己的密码(应该使用修改密码功能)
|
||||
@@ -162,10 +162,15 @@ async def reset_user_password(
|
||||
|
||||
# 重置密码
|
||||
try:
|
||||
actual_password = await password_manager.set_password(
|
||||
generated_password = data.new_password
|
||||
if not generated_password:
|
||||
import secrets
|
||||
generated_password = secrets.token_urlsafe(12)
|
||||
|
||||
await password_manager.set_password(
|
||||
target_user.user_id,
|
||||
target_user.username,
|
||||
data.new_password
|
||||
generated_password
|
||||
)
|
||||
|
||||
# 如果使用了默认密码,返回密码供管理员告知用户
|
||||
@@ -177,8 +182,8 @@ async def reset_user_password(
|
||||
}
|
||||
|
||||
if not data.new_password:
|
||||
response_data["default_password"] = actual_password
|
||||
response_data["message"] = f"密码已重置为默认密码: {actual_password}"
|
||||
response_data["temporary_password"] = generated_password
|
||||
response_data["message"] = "密码已重置为系统生成的临时密码,请尽快通知用户修改"
|
||||
|
||||
return response_data
|
||||
|
||||
@@ -186,4 +191,4 @@ async def reset_user_password(
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"重置密码失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -26,6 +26,18 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None:
|
||||
if not project_id or not user_id:
|
||||
return None
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def world_building_generator(
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
@@ -326,12 +338,9 @@ async def career_system_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 设置用户信息以启用MCP
|
||||
@@ -599,12 +608,9 @@ async def characters_generator(
|
||||
|
||||
# 验证项目
|
||||
yield await tracker.loading("验证项目...", 0.3)
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
project.wizard_step = 2
|
||||
@@ -1270,12 +1276,9 @@ async def outline_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...", 0.3)
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 获取角色信息
|
||||
@@ -1551,12 +1554,9 @@ async def world_building_regenerate_generator(
|
||||
|
||||
# 获取项目信息
|
||||
yield await tracker.loading("加载项目信息...")
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
project = await get_owned_project(db, project_id, user_id)
|
||||
if not project:
|
||||
yield await tracker.error("项目不存在", 404)
|
||||
yield await tracker.error("项目不存在或无权访问", 404)
|
||||
return
|
||||
|
||||
# 提取参数
|
||||
|
||||
@@ -275,15 +275,19 @@ async def get_project_styles(
|
||||
@router.get("/{style_id}", response_model=WritingStyleResponse)
|
||||
async def get_writing_style(
|
||||
style_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取单个写作风格详情"""
|
||||
user_id = get_current_user_id(request)
|
||||
result = await db.execute(
|
||||
select(WritingStyle).where(WritingStyle.id == style_id)
|
||||
)
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
if style.user_id is not None and style.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="写作风格不存在")
|
||||
|
||||
# 检查是否有项目将其设置为默认风格(一个风格可能被多个项目使用,使用 first() 避免 MultipleResultsFound)
|
||||
result = await db.execute(
|
||||
@@ -501,4 +505,4 @@ async def initialize_default_styles(
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 直接返回项目可用的所有风格(全局预设 + 用户自定义)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
|
||||
@@ -29,7 +29,7 @@ class Settings(BaseSettings):
|
||||
app_version: str = "1.0.0"
|
||||
app_host: str = "0.0.0.0"
|
||||
app_port: int = 8000
|
||||
debug: bool = True
|
||||
debug: bool = False
|
||||
|
||||
# 日志配置
|
||||
log_level: str = "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
@@ -106,6 +106,7 @@ class Settings(BaseSettings):
|
||||
# 会话配置
|
||||
SESSION_EXPIRE_MINUTES: int = 120 # 会话过期时间(分钟),默认2小时
|
||||
SESSION_REFRESH_THRESHOLD_MINUTES: int = 30 # 会话刷新阈值(分钟),剩余时间少于此值时可刷新
|
||||
SESSION_SECRET_KEY: Optional[str] = None # 会话签名密钥,生产环境必须配置为高强度随机值
|
||||
|
||||
# 系统 SMTP 默认配置(可被管理员系统设置覆盖)
|
||||
SMTP_PROVIDER: str = "qq"
|
||||
|
||||
+17
-5
@@ -1,5 +1,5 @@
|
||||
"""FastAPI应用主入口"""
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi import FastAPI, Request, status, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
@@ -106,7 +106,7 @@ async def health_check():
|
||||
|
||||
|
||||
@app.get("/health/db-sessions")
|
||||
async def db_session_stats():
|
||||
async def db_session_stats(request: Request):
|
||||
"""
|
||||
数据库会话统计(监控连接泄漏)
|
||||
|
||||
@@ -118,6 +118,8 @@ async def db_session_stats():
|
||||
- generator_exits: SSE断开次数
|
||||
- last_check: 最后检查时间
|
||||
"""
|
||||
if not getattr(request.state, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
return {
|
||||
"status": "ok",
|
||||
"session_stats": _session_stats,
|
||||
@@ -176,8 +178,18 @@ if static_dir.exists():
|
||||
)
|
||||
|
||||
file_path = static_dir / full_path
|
||||
if file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
try:
|
||||
resolved_file = file_path.resolve()
|
||||
resolved_static = static_dir.resolve()
|
||||
resolved_file.relative_to(resolved_static)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"detail": "页面不存在"}
|
||||
)
|
||||
|
||||
if resolved_file.is_file():
|
||||
return FileResponse(resolved_file)
|
||||
|
||||
index_file = static_dir / "index.html"
|
||||
if index_file.exists():
|
||||
@@ -207,4 +219,4 @@ if __name__ == "__main__":
|
||||
host=config_settings.app_host,
|
||||
port=config_settings.app_port,
|
||||
reload=config_settings.debug
|
||||
)
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.user_manager import user_manager
|
||||
from app.logger import get_logger
|
||||
from app.security import verify_session_token
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,8 +47,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
request.state.is_proxy_request = False
|
||||
request.state.proxy_instance_id = None
|
||||
|
||||
# 从 Cookie 中获取用户 ID
|
||||
user_id = request.cookies.get("user_id")
|
||||
# 优先验证签名会话 Cookie;不再信任客户端可伪造的明文 user_id。
|
||||
user_id = verify_session_token(request.cookies.get("session_token"))
|
||||
|
||||
if user_id:
|
||||
user = await user_manager.get_user(user_id)
|
||||
@@ -77,4 +78,4 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# 继续处理请求
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return response
|
||||
|
||||
@@ -41,7 +41,7 @@ class UserPassword(Base):
|
||||
|
||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
|
||||
username = Column(String(100), nullable=False, comment="用户名")
|
||||
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256)")
|
||||
password_hash = Column(String(255), nullable=False, comment="密码哈希")
|
||||
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Security helpers for sessions and outbound URL validation."""
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import ipaddress
|
||||
import json
|
||||
import secrets
|
||||
import socket
|
||||
import time
|
||||
from typing import Iterable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _session_secret() -> bytes:
|
||||
secret = (
|
||||
getattr(settings, "SESSION_SECRET_KEY", None)
|
||||
or getattr(settings, "session_secret_key", None)
|
||||
or settings.LINUXDO_CLIENT_SECRET
|
||||
or settings.LOCAL_AUTH_PASSWORD
|
||||
or settings.openai_api_key
|
||||
)
|
||||
if not secret:
|
||||
secret = "mumuainovel-development-session-secret"
|
||||
return str(secret).encode("utf-8")
|
||||
|
||||
|
||||
def create_session_token(user_id: str, max_age_seconds: int) -> str:
|
||||
payload = {
|
||||
"uid": user_id,
|
||||
"exp": int(time.time()) + max_age_seconds,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
}
|
||||
payload_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
payload_b64 = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode("ascii")
|
||||
signature = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
|
||||
signature_b64 = base64.urlsafe_b64encode(signature).rstrip(b"=").decode("ascii")
|
||||
return f"{payload_b64}.{signature_b64}"
|
||||
|
||||
|
||||
def verify_session_token(token: str) -> str | None:
|
||||
if not token or "." not in token:
|
||||
return None
|
||||
payload_b64, signature_b64 = token.split(".", 1)
|
||||
expected = hmac.new(_session_secret(), payload_b64.encode("ascii"), hashlib.sha256).digest()
|
||||
try:
|
||||
provided = base64.urlsafe_b64decode(signature_b64 + "=" * (-len(signature_b64) % 4))
|
||||
except Exception:
|
||||
return None
|
||||
if not hmac.compare_digest(expected, provided):
|
||||
return None
|
||||
try:
|
||||
payload_raw = base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4))
|
||||
payload = json.loads(payload_raw.decode("utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
if int(payload.get("exp", 0)) < int(time.time()):
|
||||
return None
|
||||
user_id = payload.get("uid")
|
||||
return user_id if isinstance(user_id, str) and user_id else None
|
||||
|
||||
|
||||
def _is_forbidden_ip(ip: ipaddress._BaseAddress) -> bool:
|
||||
return any([
|
||||
ip.is_private,
|
||||
ip.is_loopback,
|
||||
ip.is_link_local,
|
||||
ip.is_multicast,
|
||||
ip.is_reserved,
|
||||
ip.is_unspecified,
|
||||
])
|
||||
|
||||
|
||||
def validate_public_http_url(raw_url: str, *, allowed_schemes: Iterable[str] = ("https", "http")) -> str:
|
||||
"""Validate an outbound URL to reduce SSRF risk."""
|
||||
if not raw_url or not isinstance(raw_url, str):
|
||||
raise HTTPException(status_code=400, detail="URL不能为空")
|
||||
|
||||
parsed = urlparse(raw_url.strip())
|
||||
if parsed.scheme not in set(allowed_schemes):
|
||||
raise HTTPException(status_code=400, detail="仅支持 HTTP/HTTPS URL")
|
||||
if not parsed.hostname:
|
||||
raise HTTPException(status_code=400, detail="URL缺少主机名")
|
||||
if parsed.username or parsed.password:
|
||||
raise HTTPException(status_code=400, detail="URL不允许包含认证信息")
|
||||
|
||||
host = parsed.hostname.strip().rstrip(".")
|
||||
if host.lower() in {"localhost", "localhost.localdomain"}:
|
||||
raise HTTPException(status_code=400, detail="URL不允许指向本机地址")
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if _is_forbidden_ip(ip):
|
||||
raise HTTPException(status_code=400, detail="URL不允许指向内网或保留地址")
|
||||
except ValueError:
|
||||
try:
|
||||
infos = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80), type=socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
raise HTTPException(status_code=400, detail="URL主机名无法解析")
|
||||
for info in infos:
|
||||
resolved_ip = ipaddress.ip_address(info[4][0])
|
||||
if _is_forbidden_ip(resolved_ip):
|
||||
raise HTTPException(status_code=400, detail="URL解析到内网或保留地址")
|
||||
|
||||
return raw_url.strip().rstrip("/")
|
||||
@@ -38,7 +38,7 @@ class WorkshopClient:
|
||||
url = f"{self.base_url}/api/prompt-workshop{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
@@ -173,4 +173,4 @@ class WorkshopClient:
|
||||
|
||||
|
||||
# 全局客户端实例
|
||||
workshop_client = WorkshopClient()
|
||||
workshop_client = WorkshopClient()
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select
|
||||
@@ -34,7 +36,28 @@ class UserPasswordManager:
|
||||
|
||||
def _hash_password(self, password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
salt = secrets.token_hex(16)
|
||||
iterations = 260000
|
||||
digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), iterations).hex()
|
||||
return f"pbkdf2_sha256${iterations}${salt}${digest}"
|
||||
|
||||
def _verify_hash(self, password: str, stored_hash: str) -> bool:
|
||||
if stored_hash.startswith("pbkdf2_sha256$"):
|
||||
try:
|
||||
_, iterations, salt, digest = stored_hash.split("$", 3)
|
||||
candidate = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
password.encode(),
|
||||
salt.encode(),
|
||||
int(iterations),
|
||||
).hex()
|
||||
return hmac.compare_digest(candidate, digest)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Legacy unsalted SHA-256 hash support for existing deployments.
|
||||
legacy_hash = hashlib.sha256(password.encode()).hexdigest()
|
||||
return hmac.compare_digest(legacy_hash, stored_hash)
|
||||
|
||||
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
|
||||
"""
|
||||
@@ -104,8 +127,12 @@ class UserPasswordManager:
|
||||
if not pwd_record:
|
||||
return False
|
||||
|
||||
password_hash = self._hash_password(password)
|
||||
return pwd_record.password_hash == password_hash
|
||||
verified = self._verify_hash(password, pwd_record.password_hash)
|
||||
if verified and not pwd_record.password_hash.startswith("pbkdf2_sha256$"):
|
||||
pwd_record.password_hash = self._hash_password(password)
|
||||
pwd_record.updated_at = datetime.now()
|
||||
await session.commit()
|
||||
return verified
|
||||
|
||||
async def has_password(self, user_id: str) -> bool:
|
||||
"""
|
||||
@@ -175,4 +202,4 @@ class UserPasswordManager:
|
||||
|
||||
|
||||
# 全局密码管理器实例
|
||||
password_manager = UserPasswordManager()
|
||||
password_manager = UserPasswordManager()
|
||||
|
||||
Reference in New Issue
Block a user