update:1.更新根据分析建议重新生成章节内容

This commit is contained in:
xiamuceer
2025-11-11 19:50:12 +08:00
parent 5b46d657f3
commit 913edd0cce
30 changed files with 3896 additions and 1928 deletions
+212 -18
View File
@@ -9,6 +9,7 @@ import hashlib
from datetime import datetime, timedelta, timezone
from app.services.oauth_service import LinuxDOOAuthService
from app.user_manager import user_manager
from app.user_password import password_manager
from app.database import init_db
from app.logger import get_logger
from app.config import settings
@@ -49,6 +50,25 @@ class LocalLoginResponse(BaseModel):
user: Optional[dict] = None
class SetPasswordRequest(BaseModel):
"""设置密码请求"""
password: str
class SetPasswordResponse(BaseModel):
"""设置密码响应"""
success: bool
message: str
class PasswordStatusResponse(BaseModel):
"""密码状态响应"""
has_password: bool
has_custom_password: bool
username: Optional[str] = None
default_password: Optional[str] = None
@router.get("/config")
async def get_auth_config():
"""获取认证配置信息"""
@@ -60,30 +80,77 @@ async def get_auth_config():
@router.post("/local/login", response_model=LocalLoginResponse)
async def local_login(request: LocalLoginRequest, response: Response):
"""本地账户登录"""
"""本地账户登录(支持.env配置的管理员账号和Linux DO授权后绑定的账号)"""
# 检查是否启用本地登录
if not settings.LOCAL_AUTH_ENABLED:
raise HTTPException(status_code=403, detail="本地账户登录未启用")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=500, detail="本地账户未配置")
logger.info(f"[本地登录] 尝试登录用户名: {request.username}")
# 验证用户名和密码
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 首先尝试查找 Linux DO 授权后绑定的账号
all_users = await user_manager.get_all_users()
target_user = None
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
for user in all_users:
# 同时检查 users 表的 username 和 user_passwords 表的 username
password_username = await password_manager.get_username(user.user_id)
if user.username == request.username or password_username == request.username:
target_user = user
logger.info(f"[本地登录] 找到 Linux DO 授权用户: {user.user_id}")
break
# 创建或更新本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 如果找到了 Linux DO 授权的用户
if target_user:
# 检查是否有密码
if not await password_manager.has_password(target_user.user_id):
logger.warning(f"[本地登录] 用户 {target_user.user_id} 没有设置密码")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 验证密码
if not await password_manager.verify_password(target_user.user_id, request.password):
logger.warning(f"[本地登录] 用户 {target_user.user_id} 密码验证失败")
raise HTTPException(status_code=401, detail="用户名或密码错误")
logger.info(f"[本地登录] Linux DO 授权用户 {target_user.user_id} 登录成功")
user = target_user
else:
# 没有找到 Linux DO 用户,尝试 .env 配置的管理员账号
logger.info(f"[本地登录] 未找到 Linux DO 用户,检查 .env 管理员账号")
# 检查是否配置了本地账户
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 生成本地用户ID(使用用户名的hash)
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
# 检查用户是否存在
user = await user_manager.get_user(user_id)
# 如果用户不存在,使用.env中的默认密码验证
if not user:
# 验证用户名和密码(使用.env配置)
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 创建本地用户
user = await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_id,
username=request.username,
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
avatar_url=None,
trust_level=9 # 本地用户给予高信任级别
)
# 为新用户设置默认密码到数据库
await password_manager.set_password(user.user_id, request.username, request.password)
logger.info(f"[本地登录] 管理员用户 {user.user_id} 初始密码已设置到数据库")
else:
# 用户已存在,使用数据库中的密码验证
if not await password_manager.verify_password(user.user_id, request.password):
raise HTTPException(status_code=401, detail="用户名或密码错误")
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
# 初始化用户数据库
try:
@@ -189,6 +256,11 @@ async def _handle_callback(
trust_level=trust_level
)
# 3.1. 自动绑定密码(如果还没有设置)
if not await password_manager.has_password(user.user_id):
default_password = await password_manager.set_password(user.user_id, username)
logger.info(f"用户 {user.user_id} ({username}) 自动绑定默认密码: {default_password}")
# 3.5. 初始化用户数据库(如果是新用户)
try:
await init_db(user.user_id)
@@ -337,4 +409,126 @@ async def get_current_user(request: Request):
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
return request.state.user.dict()
return request.state.user.dict()
@router.get("/password/status", response_model=PasswordStatusResponse)
async def get_password_status(request: Request):
"""获取当前用户的密码状态"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
user = request.state.user
has_password = await password_manager.has_password(user.user_id)
has_custom = await password_manager.has_custom_password(user.user_id)
username = await password_manager.get_username(user.user_id)
# 如果使用默认密码,返回默认密码供用户查看
default_password = None
if has_password and not has_custom:
default_password = f"{user.username}@666"
return PasswordStatusResponse(
has_password=has_password,
has_custom_password=has_custom,
username=username or user.username,
default_password=default_password
)
@router.post("/password/set", response_model=SetPasswordResponse)
async def set_user_password(request: Request, password_req: SetPasswordRequest):
"""设置当前用户的密码"""
if not hasattr(request.state, "user") or not request.state.user:
raise HTTPException(status_code=401, detail="未登录")
user = request.state.user
# 验证密码强度(至少6个字符)
if len(password_req.password) < 6:
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
# 设置密码
await password_manager.set_password(user.user_id, user.username, password_req.password)
logger.info(f"用户 {user.user_id} ({user.username}) 设置了自定义密码")
return SetPasswordResponse(
success=True,
message="密码设置成功"
)
@router.post("/bind/login", response_model=LocalLoginResponse)
async def bind_account_login(request: LocalLoginRequest, response: Response):
"""使用绑定的账号密码登录(LinuxDO授权后绑定的账号)"""
# 查找用户
all_users = await user_manager.get_all_users()
target_user = None
logger.info(f"[绑定账号登录] 尝试登录用户名: {request.username}")
logger.info(f"[绑定账号登录] 当前共有 {len(all_users)} 个用户")
for user in all_users:
# 同时检查 users 表的 username 和 user_passwords 表的 username
password_username = await password_manager.get_username(user.user_id)
logger.info(f"[绑定账号登录] 检查用户 {user.user_id}: users.username={user.username}, passwords.username={password_username}")
if user.username == request.username or password_username == request.username:
target_user = user
logger.info(f"[绑定账号登录] 找到匹配用户: {user.user_id}")
break
if not target_user:
logger.warning(f"[绑定账号登录] 用户名 {request.username} 未找到")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 检查是否有密码记录
has_pwd = await password_manager.has_password(target_user.user_id)
if not has_pwd:
logger.warning(f"[绑定账号登录] 用户 {target_user.user_id} 没有设置密码")
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 验证密码
is_valid = await password_manager.verify_password(target_user.user_id, request.password)
logger.info(f"[绑定账号登录] 用户 {target_user.user_id} 密码验证结果: {is_valid}")
if not is_valid:
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 初始化用户数据库
try:
await init_db(target_user.user_id)
logger.info(f"绑定账号用户 {target_user.user_id} 数据库初始化成功")
except Exception as e:
logger.error(f"绑定账号用户 {target_user.user_id} 数据库初始化失败: {e}")
# 设置 Cookie2小时有效)
max_age = settings.SESSION_EXPIRE_MINUTES * 60
response.set_cookie(
key="user_id",
value=target_user.user_id,
max_age=max_age,
httponly=True,
samesite="lax"
)
# 设置过期时间戳 Cookie(用于前端判断)
china_now = get_china_now()
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
expire_at = int(expire_time.timestamp())
logger.info(f"✅ [绑定账号登录] 用户 {target_user.user_id} ({request.username}) 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
response.set_cookie(
key="session_expire_at",
value=str(expire_at),
max_age=max_age,
httponly=False, # 前端需要读取
samesite="lax"
)
return LocalLoginResponse(
success=True,
message="登录成功",
user=target_user.dict()
)
+296 -10
View File
@@ -1,6 +1,5 @@
"""章节管理API"""
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import json
@@ -19,6 +18,7 @@ from app.models.writing_style import WritingStyle
from app.models.analysis_task import AnalysisTask
from app.models.memory import PlotAnalysis, StoryMemory
from app.models.batch_generation_task import BatchGenerationTask
from app.models.regeneration_task import RegenerationTask
from app.schemas.chapter import (
ChapterCreate,
ChapterUpdate,
@@ -29,12 +29,19 @@ from app.schemas.chapter import (
BatchGenerateResponse,
BatchGenerateStatusResponse
)
from app.schemas.regeneration import (
ChapterRegenerateRequest,
RegenerationTaskResponse,
RegenerationTaskStatus
)
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.services.plot_analyzer import PlotAnalyzer
from app.services.memory_service import memory_service
from app.services.chapter_regenerator import ChapterRegenerator
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import create_sse_response
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@@ -1284,15 +1291,7 @@ async def generate_chapter_content_stream(
except:
pass
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
return create_sse_response(event_generator())
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
@@ -2293,3 +2292,290 @@ async def generate_single_chapter_for_batch(
await db_session.refresh(chapter)
logger.info(f"✅ 单章节生成完成: 第{chapter.chapter_number}章,共 {new_word_count}")
# ==================== 章节重新生成相关API ====================
@router.post("/{chapter_id}/regenerate-stream", summary="流式重新生成章节内容")
async def regenerate_chapter_stream(
chapter_id: str,
request: Request,
regenerate_request: ChapterRegenerateRequest,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
根据分析建议或自定义指令重新生成章节内容(流式返回)
工作流程:
1. 验证章节和分析结果
2. 创建重新生成任务
3. 构建修改指令
4. 流式生成新内容
5. 保存为版本历史
6. 可选自动应用
"""
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
# 验证章节存在
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
if not chapter.content or chapter.content.strip() == "":
raise HTTPException(status_code=400, detail="章节内容为空,无法重新生成")
# 验证用户权限
await verify_project_access(chapter.project_id, user_id, db)
# 获取分析结果(如果使用分析建议)
analysis = None
if regenerate_request.modification_source in ['analysis_suggestions', 'mixed']:
analysis_result = await db.execute(
select(PlotAnalysis)
.where(PlotAnalysis.chapter_id == chapter_id)
.order_by(PlotAnalysis.created_at.desc())
.limit(1)
)
analysis = analysis_result.scalar_one_or_none()
if not analysis:
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
# 预先获取项目上下文数据
async for temp_db in get_db(request):
try:
# 获取项目信息
project_result = await temp_db.execute(
select(Project).where(Project.id == chapter.project_id)
)
project = project_result.scalar_one_or_none()
# 获取角色信息
characters_result = await temp_db.execute(
select(Character).where(Character.project_id == chapter.project_id)
)
characters = characters_result.scalars().all()
# 获取章节大纲
outline_result = await temp_db.execute(
select(Outline)
.where(Outline.project_id == chapter.project_id)
.where(Outline.order_index == chapter.chapter_number)
)
outline = outline_result.scalar_one_or_none()
# 构建项目上下文
project_context = {
'project_title': project.title if project else '未知',
'genre': project.genre if project else '未设定',
'theme': project.theme if project else '未设定',
'narrative_perspective': project.narrative_perspective if project else '第三人称',
'time_period': project.world_time_period if project else '未设定',
'location': project.world_location if project else '未设定',
'atmosphere': project.world_atmosphere if project else '未设定',
'characters_info': "\n".join([
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
for c in characters
]) if characters else '暂无角色信息',
'chapter_outline': outline.content if outline else chapter.summary or '暂无大纲',
'previous_context': '' # 可以后续扩展添加前置章节上下文
}
finally:
await temp_db.close()
break
async def event_generator():
"""流式生成事件生成器"""
db_session = None
db_committed = False
try:
# 创建独立数据库会话
async for db_session in get_db(request):
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
# 创建重新生成任务
regen_task = RegenerationTask(
chapter_id=chapter_id,
analysis_id=analysis.id if analysis else None,
user_id=user_id,
project_id=chapter.project_id,
modification_instructions="", # 稍后填充
original_suggestions=analysis.suggestions if analysis else None,
selected_suggestion_indices=regenerate_request.selected_suggestion_indices,
custom_instructions=regenerate_request.custom_instructions,
style_id=regenerate_request.style_id,
target_word_count=regenerate_request.target_word_count,
focus_areas=regenerate_request.focus_areas,
preserve_elements=regenerate_request.preserve_elements.model_dump() if regenerate_request.preserve_elements else None,
status='running',
original_content=chapter.content,
original_word_count=chapter.word_count or len(chapter.content),
version_note=regenerate_request.version_note,
started_at=datetime.now()
)
db_session.add(regen_task)
await db_session.commit()
await db_session.refresh(regen_task)
task_id = regen_task.id
logger.info(f"📝 创建重新生成任务: {task_id}")
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
# 初始化重新生成器
regenerator = ChapterRegenerator(user_ai_service)
# 流式生成新内容
full_content = ""
async for event in regenerator.regenerate_with_feedback(
chapter=chapter,
analysis=analysis,
regenerate_request=regenerate_request,
project_context=project_context
):
# 处理不同类型的事件
if event['type'] == 'chunk':
# 内容块
chunk = event['content']
full_content += chunk
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
elif event['type'] == 'progress':
# 进度更新
progress_data = {
'type': 'progress',
'progress': event.get('progress', 0),
'message': event.get('message', ''),
'word_count': event.get('word_count', 0)
}
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
await asyncio.sleep(0)
# 更新任务状态
regen_task.status = 'completed'
regen_task.regenerated_content = full_content
regen_task.regenerated_word_count = len(full_content)
regen_task.completed_at = datetime.now()
# 计算差异统计
diff_stats = regenerator.calculate_content_diff(chapter.content, full_content)
await db_session.commit()
db_committed = True
# 先发送结果数据
result_data = {
'type': 'result',
'data': {
'task_id': task_id,
'word_count': len(full_content),
'version_number': regen_task.version_number,
'auto_applied': regenerate_request.auto_apply,
'diff_stats': diff_stats
}
}
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
# 再发送完成事件
completion_data = {
'type': 'done',
'message': '重新生成完成'
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
break
except Exception as e:
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
# 更新任务状态为失败
if db_session and not db_committed:
try:
task_result = await db_session.execute(
select(RegenerationTask).where(RegenerationTask.chapter_id == chapter_id)
.order_by(RegenerationTask.created_at.desc()).limit(1)
)
task = task_result.scalar_one_or_none()
if task:
task.status = 'failed'
task.error_message = str(e)[:500]
task.completed_at = datetime.now()
await db_session.commit()
except Exception as update_error:
logger.error(f"更新任务失败状态失败: {str(update_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
finally:
if db_session:
try:
if not db_committed and db_session.in_transaction():
await db_session.rollback()
await db_session.close()
except Exception as close_error:
logger.error(f"关闭数据库会话失败: {str(close_error)}")
return create_sse_response(event_generator())
@router.get("/{chapter_id}/regeneration/tasks", summary="获取章节的重新生成任务列表")
async def get_regeneration_tasks(
chapter_id: str,
request: Request,
limit: int = Query(10, ge=1, le=50),
db: AsyncSession = Depends(get_db)
):
"""获取指定章节的重新生成任务历史"""
user_id = getattr(request.state, 'user_id', None)
# 验证章节存在和权限
chapter_result = await db.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
chapter = chapter_result.scalar_one_or_none()
if not chapter:
raise HTTPException(status_code=404, detail="章节不存在")
await verify_project_access(chapter.project_id, user_id, db)
# 获取任务列表
result = await db.execute(
select(RegenerationTask)
.where(RegenerationTask.chapter_id == chapter_id)
.order_by(RegenerationTask.created_at.desc())
.limit(limit)
)
tasks = result.scalars().all()
return {
"chapter_id": chapter_id,
"total": len(tasks),
"tasks": [
{
"task_id": task.id,
"status": task.status,
"version_number": task.version_number,
"version_note": task.version_note,
"original_word_count": task.original_word_count,
"regenerated_word_count": task.regenerated_word_count,
"created_at": task.created_at.isoformat() if task.created_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None
}
for task in tasks
]
}
-383
View File
@@ -1135,386 +1135,3 @@ async def generate_outline_stream(
"""
return create_sse_response(outline_generator(data, db, user_ai_service))
async def update_world_building_generator(
project_id: str,
data: Dict[str, Any],
db: AsyncSession
) -> AsyncGenerator[str, None]:
"""更新世界观流式生成器"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始更新世界观...", 10)
# 获取项目
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
yield await SSEResponse.send_progress("验证数据...", 30)
# 更新世界观字段
if "time_period" in data:
project.world_time_period = data["time_period"]
if "location" in data:
project.world_location = data["location"]
if "atmosphere" in data:
project.world_atmosphere = data["atmosphere"]
if "rules" in data:
project.world_rules = data["rules"]
yield await SSEResponse.send_progress("保存到数据库...", 70)
await db.commit()
db_committed = True
await db.refresh(project)
# 发送结果
yield await SSEResponse.send_result({
"project_id": project.id,
"time_period": project.world_time_period,
"location": project.world_location,
"atmosphere": project.world_atmosphere,
"rules": project.world_rules
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("更新世界观生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("更新世界观事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"更新世界观失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("更新世界观事务已回滚(异常)")
yield await SSEResponse.send_error(f"更新失败: {str(e)}")
@router.post("/world-building/{project_id}", summary="流式更新世界观")
async def update_world_building_stream(
project_id: str,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db)
):
"""
使用SSE流式更新项目的世界观信息
请求体格式:
{
"time_period": "时间背景",
"location": "地理位置",
"atmosphere": "氛围基调",
"rules": "世界规则"
}
"""
return create_sse_response(update_world_building_generator(project_id, data, db))
async def regenerate_world_building_generator(
project_id: str,
data: Dict[str, Any],
db: AsyncSession,
user_ai_service: AIService
) -> AsyncGenerator[str, None]:
"""重新生成世界观流式生成器 - 支持MCP工具增强"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
# 获取项目
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
provider = data.get("provider")
model = data.get("model")
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
user_id = data.get("user_id") # 从中间件注入
# 获取基础提示词
yield await SSEResponse.send_progress("准备AI提示词...", 15)
base_prompt = prompt_service.get_world_building_prompt(
title=project.title,
theme=project.theme or "",
genre=project.genre or ""
)
# MCP工具增强:收集参考资料
reference_materials = ""
if enable_mcp and user_id:
try:
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
# 直接调用MCP增强的AI,内部会自动检查和加载工具
# 构建资料收集提示词
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
【小说信息】
- 题材:{project.genre or '未设定'}
- 主题:{project.theme or '未设定'}
【任务】
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
你可以查询:
1. 历史背景(如果是历史题材)
2. 地理环境和文化特征
3. 相关领域的专业知识
4. 类似作品的设定参考
请根据题材特点,有针对性地查询2-3个关键问题。"""
# 调用MCP增强的AI(非流式,最多2轮工具调用)
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
)
# 提取参考资料
if planning_result.get("tool_calls_made", 0) > 0:
yield await SSEResponse.send_progress(
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
25
)
reference_materials = planning_result.get("content", "")
else:
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
except Exception as e:
logger.warning(f"MCP工具调用失败(降级处理): {e}")
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
# 构建增强提示词
if reference_materials:
enhanced_prompt = f"""{base_prompt}
【参考资料】
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
{reference_materials}
请结合上述资料,生成符合历史/现实的世界观设定。"""
final_prompt = enhanced_prompt
yield await SSEResponse.send_progress("💡 已整合参考资料,开始重新生成世界观...", 30)
else:
final_prompt = base_prompt
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
# 流式生成世界观
accumulated_text = ""
chunk_count = 0
async for chunk in user_ai_service.generate_text_stream(
prompt=final_prompt,
provider=provider,
model=model
):
chunk_count += 1
accumulated_text += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
progress = min(30 + (chunk_count // 5), 70)
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
# 每20个块发送心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
# 解析结果
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
world_data = {}
try:
cleaned_text = accumulated_text.strip()
# 移除markdown代码块标记
if cleaned_text.startswith('```json'):
cleaned_text = cleaned_text[7:].lstrip('\n\r')
elif cleaned_text.startswith('```'):
cleaned_text = cleaned_text[3:].lstrip('\n\r')
if cleaned_text.endswith('```'):
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
cleaned_text = cleaned_text.strip()
world_data = json.loads(cleaned_text)
except json.JSONDecodeError as e:
logger.error(f"AI返回非JSON格式: {e}")
logger.info(world_data)
world_data = {
"time_period": "AI返回格式错误,请重试",
"location": "AI返回格式错误,请重试",
"atmosphere": "AI返回格式错误,请重试",
"rules": "AI返回格式错误,请重试"
}
# 更新项目世界观
yield await SSEResponse.send_progress("保存到数据库...", 90)
project.world_time_period = world_data.get("time_period")
project.world_location = world_data.get("location")
project.world_atmosphere = world_data.get("atmosphere")
project.world_rules = world_data.get("rules")
await db.commit()
db_committed = True
await db.refresh(project)
# 发送结果
yield await SSEResponse.send_result({
"project_id": project.id,
"time_period": project.world_time_period,
"location": project.world_location,
"atmosphere": project.world_atmosphere,
"rules": project.world_rules
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("重新生成世界观生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("重新生成世界观事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"重新生成世界观失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("重新生成世界观事务已回滚(异常)")
yield await SSEResponse.send_error(f"重新生成失败: {str(e)}")
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
async def regenerate_world_building_stream(
request: Request,
project_id: str,
data: Dict[str, Any],
db: AsyncSession = Depends(get_db),
user_ai_service: AIService = Depends(get_user_ai_service)
):
"""
使用SSE流式重新生成项目的世界观
请求体格式:
{
"provider": "AI提供商(可选)",
"model": "模型名称(可选)"
}
"""
# 从中间件注入user_id到data中
if hasattr(request.state, 'user_id'):
data['user_id'] = request.state.user_id
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
async def cleanup_wizard_data_generator(
project_id: str,
db: AsyncSession
) -> AsyncGenerator[str, None]:
"""清理向导数据流式生成器"""
db_committed = False
try:
yield await SSEResponse.send_progress("开始清理向导数据...", 10)
# 获取项目
result = await db.execute(
select(Project).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
yield await SSEResponse.send_error("项目不存在", 404)
return
# 删除相关的角色
yield await SSEResponse.send_progress("删除角色数据...", 30)
characters = await db.execute(
select(Character).where(Character.project_id == project_id)
)
char_count = 0
for character in characters.scalars():
await db.delete(character)
char_count += 1
# 删除相关的大纲
yield await SSEResponse.send_progress("删除大纲数据...", 50)
outlines = await db.execute(
select(Outline).where(Outline.project_id == project_id)
)
outline_count = 0
for outline in outlines.scalars():
await db.delete(outline)
outline_count += 1
# 删除相关的章节
yield await SSEResponse.send_progress("删除章节数据...", 70)
chapters = await db.execute(
select(Chapter).where(Chapter.project_id == project_id)
)
chapter_count = 0
for chapter in chapters.scalars():
await db.delete(chapter)
chapter_count += 1
# 删除项目
yield await SSEResponse.send_progress("删除项目...", 85)
await db.delete(project)
yield await SSEResponse.send_progress("提交数据库更改...", 95)
await db.commit()
db_committed = True
# 发送结果
yield await SSEResponse.send_result({
"message": "项目及相关数据已清理",
"deleted": {
"characters": char_count,
"outlines": outline_count,
"chapters": chapter_count
}
})
yield await SSEResponse.send_progress("完成!", 100, "success")
yield await SSEResponse.send_done()
except GeneratorExit:
logger.warning("清理向导数据生成器被提前关闭")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("清理向导数据事务已回滚(GeneratorExit")
except Exception as e:
logger.error(f"清理数据失败: {str(e)}")
if not db_committed and db.in_transaction():
await db.rollback()
logger.info("清理向导数据事务已回滚(异常)")
yield await SSEResponse.send_error(f"清理失败: {str(e)}")
@router.post("/cleanup/{project_id}", summary="流式清理向导数据")
async def cleanup_wizard_data_stream(
project_id: str,
db: AsyncSession = Depends(get_db)
):
"""
使用SSE流式清理向导过程中创建的项目及相关数据
用于返回上一步时清理已生成的内容
"""
return create_sse_response(cleanup_wizard_data_generator(project_id, db))
+1
View File
@@ -110,6 +110,7 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"
case_sensitive = False
extra = "ignore" # 忽略未定义的环境变量,避免验证错误
# 创建全局配置实例
+2 -1
View File
@@ -21,7 +21,8 @@ from app.models import (
Project, Outline, Character, Chapter, GenerationHistory,
Settings, WritingStyle, ProjectDefaultStyle,
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
RegenerationTask
)
# 引擎缓存:每个用户一个引擎
+6 -1
View File
@@ -12,6 +12,8 @@ from app.models.memory import StoryMemory, PlotAnalysis
from app.models.writing_style import WritingStyle
from app.models.project_default_style import ProjectDefaultStyle
from app.models.mcp_plugin import MCPPlugin
from app.models.user import User, UserPassword
from app.models.regeneration_task import RegenerationTask
__all__ = [
"Project",
@@ -30,5 +32,8 @@ __all__ = [
"PlotAnalysis",
"WritingStyle",
"ProjectDefaultStyle",
"MCPPlugin"
"MCPPlugin",
"User",
"UserPassword",
"RegenerationTask"
]
+51
View File
@@ -0,0 +1,51 @@
"""章节重新生成任务模型"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, JSON, Boolean
from sqlalchemy.sql import func
from app.database import Base
import uuid
class RegenerationTask(Base):
"""章节重新生成任务表"""
__tablename__ = "regeneration_tasks"
# 基本信息
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, index=True)
analysis_id = Column(String(36), nullable=True, comment="关联的分析结果ID")
user_id = Column(String(50), nullable=False, index=True)
project_id = Column(String(36), nullable=False, index=True)
# 修改指令
modification_instructions = Column(Text, nullable=False, comment="综合修改指令")
original_suggestions = Column(JSON, comment="来自分析的原始建议列表")
selected_suggestion_indices = Column(JSON, comment="用户选择的建议索引")
custom_instructions = Column(Text, comment="用户自定义修改意见")
# 生成参数
style_id = Column(Integer, nullable=True, comment="写作风格ID")
target_word_count = Column(Integer, default=3000, comment="目标字数")
focus_areas = Column(JSON, comment="重点优化方向")
preserve_elements = Column(JSON, comment="需要保留的元素配置")
# 状态跟踪
status = Column(String(20), default='pending', comment="pending/running/completed/failed")
progress = Column(Integer, default=0, comment="进度 0-100")
error_message = Column(Text, nullable=True)
# 内容版本
original_content = Column(Text, comment="原始章节内容快照")
original_word_count = Column(Integer, comment="原始字数")
regenerated_content = Column(Text, comment="重新生成的内容")
regenerated_word_count = Column(Integer, comment="新内容字数")
version_number = Column(Integer, default=1, comment="版本号")
version_note = Column(String(500), comment="版本说明")
# 时间戳
created_at = Column(DateTime, server_default=func.now())
started_at = Column(DateTime, nullable=True)
completed_at = Column(DateTime, nullable=True)
def __repr__(self):
return f"<RegenerationTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
+47
View File
@@ -0,0 +1,47 @@
"""
用户数据模型 - 存储用户基本信息
"""
from sqlalchemy import Column, String, Integer, Boolean, DateTime
from sqlalchemy.sql import func
from app.database import Base
class User(Base):
"""用户模型 - 存储OAuth和本地用户信息"""
__tablename__ = "users"
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID,格式:linuxdo_{id} 或 local_{id}")
username = Column(String(100), nullable=False, index=True, comment="用户名")
display_name = Column(String(200), nullable=False, comment="显示名称")
avatar_url = Column(String(500), nullable=True, comment="头像URL")
trust_level = Column(Integer, default=0, comment="信任等级(仅用于显示)")
is_admin = Column(Boolean, default=False, comment="是否为管理员")
linuxdo_id = Column(String(100), nullable=False, unique=True, index=True, comment="LinuxDO用户ID或本地用户ID")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
last_login = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后登录时间")
def to_dict(self):
"""转换为字典"""
return {
"user_id": self.user_id,
"username": self.username,
"display_name": self.display_name,
"avatar_url": self.avatar_url,
"trust_level": self.trust_level,
"is_admin": self.is_admin,
"linuxdo_id": self.linuxdo_id,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_login": self.last_login.isoformat() if self.last_login else None,
}
class UserPassword(Base):
"""用户密码模型 - 存储用户密码信息"""
__tablename__ = "user_passwords"
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
username = Column(String(100), nullable=False, comment="用户名")
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256")
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
+65
View File
@@ -0,0 +1,65 @@
"""章节重新生成相关的Schema定义"""
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
class PreserveElementsConfig(BaseModel):
"""保留元素配置"""
preserve_structure: bool = Field(False, description="是否保留整体结构")
preserve_dialogues: List[str] = Field(default_factory=list, description="需要保留的对话片段关键词")
preserve_plot_points: List[str] = Field(default_factory=list, description="需要保留的情节点关键词")
preserve_character_traits: bool = Field(True, description="保持角色性格一致")
class ChapterRegenerateRequest(BaseModel):
"""章节重新生成请求"""
# 修改来源
modification_source: str = Field("custom", description="修改来源: custom/analysis_suggestions/mixed")
# 基于分析建议
selected_suggestion_indices: Optional[List[int]] = Field(None, description="选中的建议索引列表")
# 自定义修改指令
custom_instructions: Optional[str] = Field(None, description="用户自定义的修改要求")
# 保留配置
preserve_elements: Optional[PreserveElementsConfig] = Field(None, description="保留元素配置")
# 生成参数
style_id: Optional[int] = Field(None, description="写作风格ID")
target_word_count: int = Field(3000, description="目标字数", ge=500, le=10000)
focus_areas: List[str] = Field(default_factory=list, description="重点优化方向")
# 版本管理
save_as_version: bool = Field(True, description="是否保存为新版本")
version_note: Optional[str] = Field(None, description="版本说明", max_length=500)
auto_apply: bool = Field(False, description="是否自动应用(替换当前内容)")
class RegenerationTaskResponse(BaseModel):
"""重新生成任务响应"""
task_id: str
chapter_id: str
status: str
message: str
estimated_time_seconds: int = 120
class RegenerationTaskStatus(BaseModel):
"""重新生成任务状态"""
task_id: str
chapter_id: str
status: str
progress: int
error_message: Optional[str] = None
created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
# 结果信息
original_word_count: Optional[int] = None
regenerated_word_count: Optional[int] = None
version_number: Optional[int] = None
+308
View File
@@ -0,0 +1,308 @@
"""章节重新生成服务"""
from typing import Dict, Any, AsyncGenerator, Optional, List
from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service
from app.models.chapter import Chapter
from app.models.memory import PlotAnalysis
from app.schemas.regeneration import ChapterRegenerateRequest, PreserveElementsConfig
from app.logger import get_logger
import difflib
logger = get_logger(__name__)
class ChapterRegenerator:
"""章节重新生成服务"""
def __init__(self, ai_service: AIService):
self.ai_service = ai_service
logger.info("✅ ChapterRegenerator初始化成功")
async def regenerate_with_feedback(
self,
chapter: Chapter,
analysis: Optional[PlotAnalysis],
regenerate_request: ChapterRegenerateRequest,
project_context: Dict[str, Any]
) -> AsyncGenerator[Dict[str, Any], None]:
"""
根据反馈重新生成章节(流式)
Args:
chapter: 原始章节对象
analysis: 分析结果(可选)
regenerate_request: 重新生成请求参数
project_context: 项目上下文(项目信息、角色、大纲等)
Yields:
包含类型和数据的字典: {'type': 'progress'/'chunk', 'data': ...}
"""
try:
logger.info(f"🔄 开始重新生成章节: 第{chapter.chapter_number}")
# 1. 构建修改指令
yield {'type': 'progress', 'progress': 5, 'message': '正在构建修改指令...'}
modification_instructions = self._build_modification_instructions(
analysis=analysis,
regenerate_request=regenerate_request
)
logger.info(f"📝 修改指令构建完成,长度: {len(modification_instructions)}字符")
# 2. 构建完整提示词
yield {'type': 'progress', 'progress': 10, 'message': '正在构建生成提示词...'}
full_prompt = self._build_regeneration_prompt(
chapter=chapter,
modification_instructions=modification_instructions,
project_context=project_context,
regenerate_request=regenerate_request
)
logger.info(f"🎯 提示词构建完成,开始AI生成")
yield {'type': 'progress', 'progress': 15, 'message': '开始AI生成内容...'}
# 3. 流式生成新内容,同时跟踪进度
target_word_count = regenerate_request.target_word_count
accumulated_length = 0
async for chunk in self.ai_service.generate_text_stream(
prompt=full_prompt,
temperature=0.7
):
# 发送内容块
yield {'type': 'chunk', 'content': chunk}
# 更新累积字数并计算进度(15%-95%)
accumulated_length += len(chunk)
# 进度从15%开始,到95%结束,为后处理预留5%
generation_progress = min(15 + (accumulated_length / target_word_count) * 80, 95)
yield {'type': 'progress', 'progress': int(generation_progress), 'word_count': accumulated_length}
logger.info(f"✅ 章节重新生成完成,共生成 {accumulated_length}")
yield {'type': 'progress', 'progress': 100, 'message': '生成完成'}
except Exception as e:
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
raise
def _build_modification_instructions(
self,
analysis: Optional[PlotAnalysis],
regenerate_request: ChapterRegenerateRequest
) -> str:
"""构建修改指令"""
instructions = []
# 标题
instructions.append("# 章节修改指令\n")
# 1. 来自分析的建议
if (analysis and
regenerate_request.selected_suggestion_indices and
analysis.suggestions):
instructions.append("## 📋 需要改进的问题(来自AI分析):\n")
for idx in regenerate_request.selected_suggestion_indices:
if 0 <= idx < len(analysis.suggestions):
suggestion = analysis.suggestions[idx]
instructions.append(f"{idx + 1}. {suggestion}")
instructions.append("")
# 2. 用户自定义指令
if regenerate_request.custom_instructions:
instructions.append("## ✍️ 用户自定义修改要求:\n")
instructions.append(regenerate_request.custom_instructions)
instructions.append("")
# 3. 重点优化方向
if regenerate_request.focus_areas:
instructions.append("## 🎯 重点优化方向:\n")
focus_map = {
"pacing": "节奏把控 - 调整叙事速度,避免拖沓或过快",
"emotion": "情感渲染 - 深化人物情感表达,增强感染力",
"description": "场景描写 - 丰富环境细节,增强画面感",
"dialogue": "对话质量 - 让对话更自然真实,推动剧情",
"conflict": "冲突强度 - 强化矛盾冲突,提升戏剧张力"
}
for area in regenerate_request.focus_areas:
if area in focus_map:
instructions.append(f"- {focus_map[area]}")
instructions.append("")
# 4. 保留要求
if regenerate_request.preserve_elements:
preserve = regenerate_request.preserve_elements
instructions.append("## 🔒 必须保留的元素:\n")
if preserve.preserve_structure:
instructions.append("- 保持原章节的整体结构和情节框架")
if preserve.preserve_dialogues:
instructions.append("- 必须保留以下关键对话:")
for dialogue in preserve.preserve_dialogues:
instructions.append(f" * {dialogue}")
if preserve.preserve_plot_points:
instructions.append("- 必须保留以下关键情节点:")
for plot in preserve.preserve_plot_points:
instructions.append(f" * {plot}")
if preserve.preserve_character_traits:
instructions.append("- 保持所有角色的性格特征和行为模式一致")
instructions.append("")
return "\n".join(instructions)
def _build_regeneration_prompt(
self,
chapter: Chapter,
modification_instructions: str,
project_context: Dict[str, Any],
regenerate_request: ChapterRegenerateRequest
) -> str:
"""构建完整的重新生成提示词"""
prompt_parts = []
# 系统角色
prompt_parts.append("""你是一位经验丰富的专业小说编辑和作家。现在需要根据反馈意见重新创作一个章节。
你的任务是:
1. 仔细理解原章节的内容和意图
2. 认真分析所有的修改要求
3. 在保持故事连贯性的前提下,创作一个改进后的新版本
4. 确保新版本在艺术性和可读性上都有明显提升
---
""")
# 原始章节信息
prompt_parts.append(f"""## 📖 原始章节信息
**章节**:第{chapter.chapter_number}
**标题**{chapter.title}
**字数**{chapter.word_count}
**原始内容**
{chapter.content}
---
""")
# 修改指令
prompt_parts.append(modification_instructions)
prompt_parts.append("\n---\n")
# 项目背景信息
prompt_parts.append(f"""## 🌍 项目背景信息
**小说标题**{project_context.get('project_title', '未知')}
**题材**{project_context.get('genre', '未设定')}
**主题**{project_context.get('theme', '未设定')}
**叙事视角**{project_context.get('narrative_perspective', '第三人称')}
**世界观设定**
- 时代背景:{project_context.get('time_period', '未设定')}
- 地理位置:{project_context.get('location', '未设定')}
- 氛围基调:{project_context.get('atmosphere', '未设定')}
---
""")
# 角色信息
if project_context.get('characters_info'):
prompt_parts.append(f"""## 👥 角色信息
{project_context['characters_info']}
---
""")
# 章节大纲
if project_context.get('chapter_outline'):
prompt_parts.append(f"""## 📝 本章大纲
{project_context['chapter_outline']}
---
""")
# 前置章节上下文
if project_context.get('previous_context'):
prompt_parts.append(f"""## 📚 前置章节上下文
{project_context['previous_context']}
---
""")
# 创作要求
prompt_parts.append(f"""## ✨ 创作要求
1. **解决问题**:针对上述修改指令中提到的所有问题进行改进
2. **保持连贯**:确保与前后章节的情节、人物、风格保持一致
3. **提升质量**:在节奏、情感、描写等方面明显优于原版
4. **保留精华**:保持原章节中优秀的部分和关键情节
5. **字数控制**:目标字数约{regenerate_request.target_word_count}字(可适当浮动±20%
---
## 🎬 开始创作
请现在开始创作改进后的新版本章节内容。
**重要提示**
- 直接输出章节正文内容,从故事内容开始写
- **不要**输出章节标题(如"第X章""第X章:XXX"等)
- **不要**输出任何额外的说明、注释或元数据
- 只需要纯粹的故事正文内容
现在开始:
""")
return "\n".join(prompt_parts)
def calculate_content_diff(
self,
original_content: str,
new_content: str
) -> Dict[str, Any]:
"""
计算两个版本的差异
Returns:
差异统计信息
"""
# 基本统计
diff_stats = {
'original_length': len(original_content),
'new_length': len(new_content),
'length_change': len(new_content) - len(original_content),
'length_change_percent': round((len(new_content) - len(original_content)) / len(original_content) * 100, 2) if len(original_content) > 0 else 0
}
# 计算相似度
similarity = difflib.SequenceMatcher(None, original_content, new_content).ratio()
diff_stats['similarity'] = round(similarity * 100, 2)
diff_stats['difference'] = round((1 - similarity) * 100, 2)
# 段落统计
original_paragraphs = [p for p in original_content.split('\n\n') if p.strip()]
new_paragraphs = [p for p in new_content.split('\n\n') if p.strip()]
diff_stats['original_paragraph_count'] = len(original_paragraphs)
diff_stats['new_paragraph_count'] = len(new_paragraphs)
return diff_stats
# 全局实例
_regenerator_instance = None
def get_chapter_regenerator(ai_service: AIService) -> ChapterRegenerator:
"""获取章节重新生成器实例"""
global _regenerator_instance
if _regenerator_instance is None:
_regenerator_instance = ChapterRegenerator(ai_service)
return _regenerator_instance
+145 -213
View File
@@ -1,106 +1,49 @@
"""
用户管理模块 - 支持 LinuxDO OAuth2
用户管理模块 - 使用数据库存储
"""
import json
import os
import asyncio
from datetime import datetime
from typing import Optional, Dict, List
from typing import Optional, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from pydantic import BaseModel
from app.config import settings, DATA_DIR
from app.config import settings
class User(BaseModel):
"""用户模型"""
user_id: str # 格式: linuxdo_{linuxdo_id}
"""用户数据传输对象"""
user_id: str
username: str
display_name: str
avatar_url: Optional[str] = None
trust_level: int = 0 # 仅用于显示
is_admin: bool = False # 手动设置的管理员权限
linuxdo_id: str # LinuxDO 用户 ID
trust_level: int = 0
is_admin: bool = False
linuxdo_id: str
created_at: str
last_login: str
class UserManager:
"""用户管理器 - 线程安全版本"""
USERS_FILE = str(DATA_DIR / "users.json")
ADMINS_FILE = str(DATA_DIR / "admins.json")
"""用户管理器 - 使用数据库存储(PostgreSQL共享库)"""
def __init__(self):
"""初始化用户管理器"""
# DATA_DIR 已在 config.py 中创建,无需重复创建
# 添加文件锁保护并发读写
self._users_lock = asyncio.Lock()
self._admins_lock = asyncio.Lock()
self._ensure_files_exist()
pass
def _ensure_files_exist(self):
"""确保必要的文件存在"""
if not os.path.exists(self.USERS_FILE):
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
json.dump({}, f, ensure_ascii=False, indent=2)
async def _get_session(self) -> AsyncSession:
"""获取数据库会话 - 使用共享的PostgreSQL引擎"""
from app.database import get_engine
if not os.path.exists(self.ADMINS_FILE):
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
json.dump({"admins": []}, f, ensure_ascii=False, indent=2)
def _load_users_unsafe(self) -> Dict[str, dict]:
"""加载用户数据(不加锁,内部使用)"""
try:
with open(self.USERS_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
print(f"加载用户数据失败: {e}")
return {}
def _save_users_unsafe(self, users: Dict[str, dict]):
"""保存用户数据(不加锁,内部使用)"""
try:
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
json.dump(users, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存用户数据失败: {e}")
async def _load_users(self) -> Dict[str, dict]:
"""加载用户数据(加锁)"""
async with self._users_lock:
return self._load_users_unsafe()
async def _save_users(self, users: Dict[str, dict]):
"""保存用户数据(加锁)"""
async with self._users_lock:
self._save_users_unsafe(users)
def _load_admin_list_unsafe(self) -> List[str]:
"""加载管理员列表(不加锁,内部使用)"""
try:
with open(self.ADMINS_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
return data.get("admins", [])
except Exception as e:
print(f"加载管理员列表失败: {e}")
return []
def _save_admin_list_unsafe(self, admin_list: List[str]):
"""保存管理员列表(不加锁,内部使用)"""
try:
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
json.dump({"admins": admin_list}, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"保存管理员列表失败: {e}")
async def _load_admin_list(self) -> List[str]:
"""加载管理员列表(加锁)"""
async with self._admins_lock:
return self._load_admin_list_unsafe()
async def _save_admin_list(self, admin_list: List[str]):
"""保存管理员列表(加锁)"""
async with self._admins_lock:
self._save_admin_list_unsafe(admin_list)
# 使用共享的PostgreSQL引擎(user_id使用特殊标识)
engine = await get_engine("_global_users_")
session_maker = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
return session_maker()
async def create_or_update_from_linuxdo(
self,
@@ -111,106 +54,97 @@ class UserManager:
trust_level: int
) -> User:
"""
从 LinuxDO 用户信息创建或更新用户(线程安全)
从 LinuxDO 用户信息创建或更新用户
Args:
linuxdo_id: LinuxDO 用户 ID(本地用户时为 local_xxx 格式)
username: 用户名
display_name: 显示名称
avatar_url: 头像 URL
trust_level: 信任等级 (仅用于显示)
trust_level: 信任等级
Returns:
用户对象
"""
# 如果已经是 local_ 开头,直接使用;否则添加 linuxdo_ 前缀
from app.models.user import User as UserModel
# 生成 user_id
if linuxdo_id.startswith("local_"):
user_id = linuxdo_id
else:
user_id = f"linuxdo_{linuxdo_id}"
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
admin_list = self._load_admin_list_unsafe()
async with await self._get_session() as session:
# 查询用户是否存在
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
user = result.scalar_one_or_none()
# 检查是否为初始管理员或本地用户
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
is_local_user = user_id.startswith("local_")
is_admin = is_initial_admin or is_local_user
if user:
# 更新现有用户
user.username = username
user.display_name = display_name
user.avatar_url = avatar_url
user.trust_level = trust_level
user.last_login = datetime.now()
now = datetime.now().isoformat()
# 检查是否为初始管理员
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
# 检查是否为本地用户(所有 local_ 开头的用户默认为管理员)
is_local_user = user_id.startswith("local_")
if user_id in users:
# 更新现有用户
user_data = users[user_id]
user_data["username"] = username
user_data["display_name"] = display_name
user_data["avatar_url"] = avatar_url
user_data["trust_level"] = trust_level
user_data["last_login"] = now
# 如果是初始管理员或本地用户且还不在管理员列表中,添加进去
if (is_initial_admin or is_local_user) and user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
user_data["is_admin"] = True
else:
# 从管理员列表同步 is_admin 状态
user_data["is_admin"] = user_id in admin_list
else:
# 创建新用户(本地用户默认为管理员)
is_admin = is_initial_admin or is_local_user
if is_admin and user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
user_data = {
"user_id": user_id,
"username": username,
"display_name": display_name,
"avatar_url": avatar_url,
"trust_level": trust_level,
"is_admin": is_admin,
"linuxdo_id": linuxdo_id,
"created_at": now,
"last_login": now
}
users[user_id] = user_data
self._save_users_unsafe(users)
return User(**user_data)
# 更新管理员状态
if is_admin and not user.is_admin:
user.is_admin = True
else:
# 创建新用户
user = UserModel(
user_id=user_id,
username=username,
display_name=display_name,
avatar_url=avatar_url,
trust_level=trust_level,
is_admin=is_admin,
linuxdo_id=linuxdo_id,
created_at=datetime.now(),
last_login=datetime.now()
)
session.add(user)
await session.commit()
await session.refresh(user)
return User(**user.to_dict())
async def get_user(self, user_id: str) -> Optional[User]:
"""获取用户(线程安全)"""
users = await self._load_users()
user_data = users.get(user_id)
if user_data:
# 同步管理员状态
admin_list = await self._load_admin_list()
user_data["is_admin"] = user_id in admin_list
return User(**user_data)
return None
"""获取用户"""
from app.models.user import User as UserModel
async with await self._get_session() as session:
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
user = result.scalar_one_or_none()
if user:
return User(**user.to_dict())
return None
async def get_all_users(self) -> List[User]:
"""获取所有用户(线程安全)"""
users = await self._load_users()
admin_list = await self._load_admin_list()
"""获取所有用户"""
from app.models.user import User as UserModel
user_list = []
for user_data in users.values():
# 同步管理员状态
user_data["is_admin"] = user_data["user_id"] in admin_list
user_list.append(User(**user_data))
return user_list
async with await self._get_session() as session:
result = await session.execute(select(UserModel))
users = result.scalars().all()
return [User(**user.to_dict()) for user in users]
async def set_admin(self, user_id: str, is_admin: bool) -> bool:
"""
设置用户的管理员权限(线程安全)
设置用户的管理员权限
Args:
user_id: 用户 ID
@@ -219,38 +153,35 @@ class UserManager:
Returns:
是否成功
"""
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
if user_id not in users:
from app.models.user import User as UserModel
async with await self._get_session() as session:
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
user = result.scalar_one_or_none()
if not user:
return False
if not is_admin:
# 撤销管理员权限时,确保至少保留一个管理员
admin_result = await session.execute(
select(UserModel).where(UserModel.is_admin == True)
)
admin_count = len(admin_result.scalars().all())
if admin_count <= 1:
return False
admin_list = self._load_admin_list_unsafe()
if is_admin:
# 授予管理员权限
if user_id not in admin_list:
admin_list.append(user_id)
self._save_admin_list_unsafe(admin_list)
else:
# 撤销管理员权限
if user_id in admin_list:
# 确保至少保留一个管理员
if len(admin_list) <= 1:
return False
admin_list.remove(user_id)
self._save_admin_list_unsafe(admin_list)
# 更新用户数据中的 is_admin 字段
users[user_id]["is_admin"] = is_admin
self._save_users_unsafe(users)
return True
user.is_admin = is_admin
await session.commit()
return True
async def delete_user(self, user_id: str) -> bool:
"""
删除用户(线程安全)
删除用户
Args:
user_id: 用户 ID
@@ -258,36 +189,37 @@ class UserManager:
Returns:
是否成功
"""
# 使用锁保护整个读-改-写操作
async with self._users_lock:
async with self._admins_lock:
users = self._load_users_unsafe()
if user_id not in users:
return False
# 不能删除管理员
admin_list = self._load_admin_list_unsafe()
if user_id in admin_list:
return False
# 删除用户数据
del users[user_id]
self._save_users_unsafe(users)
from app.models.user import User as UserModel
# 删除用户数据库文件(在锁外执行,避免阻塞)
db_file = str(DATA_DIR / f"ai_story_user_{user_id}.db")
if os.path.exists(db_file):
try:
os.remove(db_file)
except Exception as e:
print(f"删除用户数据库文件失败: {e}")
return True
async with await self._get_session() as session:
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
user = result.scalar_one_or_none()
if not user:
return False
# 不能删除管理员
if user.is_admin:
return False
await session.delete(user)
await session.commit()
return True
async def is_admin(self, user_id: str) -> bool:
"""检查用户是否为管理员(线程安全)"""
admin_list = await self._load_admin_list()
return user_id in admin_list
"""检查用户是否为管理员"""
from app.models.user import User as UserModel
async with await self._get_session() as session:
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
user = result.scalar_one_or_none()
return user.is_admin if user else False
# 全局用户管理器实例
+178
View File
@@ -0,0 +1,178 @@
"""
用户密码管理模块 - 使用数据库存储
"""
import asyncio
import hashlib
from typing import Optional
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from app.config import settings
class UserPasswordManager:
"""用户密码管理器 - 使用数据库存储(PostgreSQL共享库)"""
def __init__(self):
"""初始化密码管理器"""
pass
async def _get_session(self) -> AsyncSession:
"""获取数据库会话 - 使用共享的PostgreSQL引擎"""
from app.database import get_engine
# 使用共享的PostgreSQL引擎(user_id使用特殊标识)
engine = await get_engine("_global_users_")
session_maker = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
return session_maker()
def _hash_password(self, password: str) -> str:
"""密码哈希"""
return hashlib.sha256(password.encode()).hexdigest()
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
"""
设置用户密码
Args:
user_id: 用户ID
username: 用户名
password: 密码,如果为None则使用默认密码(username+@666
Returns:
实际使用的密码(明文,仅用于首次设置时返回给用户)
"""
from app.models.user import UserPassword as UserPasswordModel
# 如果没有提供密码,使用默认密码
actual_password = password if password else f"{username}@666"
async with await self._get_session() as session:
# 查询密码记录是否存在
result = await session.execute(
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
if pwd_record:
# 更新现有密码
pwd_record.username = username
pwd_record.password_hash = self._hash_password(actual_password)
pwd_record.has_custom_password = password is not None
pwd_record.updated_at = datetime.now()
else:
# 创建新密码记录
pwd_record = UserPasswordModel(
user_id=user_id,
username=username,
password_hash=self._hash_password(actual_password),
has_custom_password=password is not None,
created_at=datetime.now(),
updated_at=datetime.now()
)
session.add(pwd_record)
await session.commit()
return actual_password
async def verify_password(self, user_id: str, password: str) -> bool:
"""
验证用户密码
Args:
user_id: 用户ID
password: 待验证的密码
Returns:
是否验证通过
"""
from app.models.user import UserPassword as UserPasswordModel
async with await self._get_session() as session:
result = await session.execute(
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
if not pwd_record:
return False
password_hash = self._hash_password(password)
return pwd_record.password_hash == password_hash
async def has_password(self, user_id: str) -> bool:
"""
检查用户是否已设置密码
Args:
user_id: 用户ID
Returns:
是否已设置密码
"""
from app.models.user import UserPassword as UserPasswordModel
async with await self._get_session() as session:
result = await session.execute(
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
return pwd_record is not None
async def has_custom_password(self, user_id: str) -> bool:
"""
检查用户是否设置了自定义密码(非默认密码)
Args:
user_id: 用户ID
Returns:
是否使用自定义密码
"""
from app.models.user import UserPassword as UserPasswordModel
async with await self._get_session() as session:
result = await session.execute(
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
if not pwd_record:
return False
return pwd_record.has_custom_password
async def get_username(self, user_id: str) -> Optional[str]:
"""
获取用户名
Args:
user_id: 用户ID
Returns:
用户名,如果不存在返回None
"""
from app.models.user import UserPassword as UserPasswordModel
async with await self._get_session() as session:
result = await session.execute(
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
)
pwd_record = result.scalar_one_or_none()
if not pwd_record:
return None
return pwd_record.username
# 全局密码管理器实例
password_manager = UserPasswordManager()
+11 -1
View File
@@ -158,8 +158,18 @@ def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingRespon
Returns:
StreamingResponse对象
"""
async def wrapper():
"""包装生成器以捕获StreamingResponse初始化时的GeneratorExit"""
try:
async for chunk in generator:
yield chunk
except GeneratorExit:
# StreamingResponse在初始化时会进行类型检查,导致GeneratorExit
# 这是正常行为,不需要记录警告
pass
return StreamingResponse(
generator,
wrapper(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -0,0 +1,73 @@
-- 创建章节重新生成任务表
-- 用于支持根据AI分析建议重新生成章节内容的功能
-- 创建重新生成任务表
CREATE TABLE IF NOT EXISTS regeneration_tasks (
id VARCHAR(36) PRIMARY KEY,
chapter_id VARCHAR(36) NOT NULL,
analysis_id VARCHAR(36),
user_id VARCHAR(100) NOT NULL,
project_id VARCHAR(36) NOT NULL,
-- 修改指令
modification_instructions TEXT NOT NULL,
original_suggestions JSON,
selected_suggestion_indices JSON,
custom_instructions TEXT,
-- 生成配置
style_id INTEGER,
target_word_count INTEGER DEFAULT 3000,
focus_areas JSON,
preserve_elements JSON,
-- 任务状态
status VARCHAR(20) DEFAULT 'pending',
progress INTEGER DEFAULT 0,
error_message TEXT,
-- 内容数据
original_content TEXT,
original_word_count INTEGER,
regenerated_content TEXT,
regenerated_word_count INTEGER,
-- 版本信息
version_number INTEGER DEFAULT 1,
version_note TEXT,
-- 时间戳
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
completed_at TIMESTAMP,
-- 外键约束
CONSTRAINT fk_regeneration_chapter FOREIGN KEY (chapter_id) REFERENCES chapters(id) ON DELETE CASCADE,
CONSTRAINT fk_regeneration_project FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
CONSTRAINT fk_regeneration_analysis FOREIGN KEY (analysis_id) REFERENCES analysis_tasks(id) ON DELETE SET NULL,
CONSTRAINT fk_regeneration_style FOREIGN KEY (style_id) REFERENCES writing_styles(id) ON DELETE SET NULL
);
-- 创建索引以提升查询性能
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_chapter ON regeneration_tasks(chapter_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_project ON regeneration_tasks(project_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_user ON regeneration_tasks(user_id);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_status ON regeneration_tasks(status);
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_created ON regeneration_tasks(created_at DESC);
-- 添加注释
COMMENT ON TABLE regeneration_tasks IS '章节重新生成任务表,记录每次根据AI建议重新生成章节的任务';
COMMENT ON COLUMN regeneration_tasks.modification_instructions IS '合并后的完整修改指令';
COMMENT ON COLUMN regeneration_tasks.original_suggestions IS '原始AI分析建议列表';
COMMENT ON COLUMN regeneration_tasks.selected_suggestion_indices IS '用户选择的建议索引';
COMMENT ON COLUMN regeneration_tasks.preserve_elements IS '需要保留的元素配置(JSON)';
COMMENT ON COLUMN regeneration_tasks.focus_areas IS '重点优化方向列表(JSON)';
-- 修复外键约束(合并自 fix_all_missing_columns.sql
-- 删除可能存在问题的外键约束
ALTER TABLE regeneration_tasks
DROP CONSTRAINT IF EXISTS fk_regeneration_analysis;
-- 完成提示
SELECT '✅ 重新生成任务表创建完成,外键约束已修复' AS status;
+224
View File
@@ -0,0 +1,224 @@
"""
用户数据迁移脚本 - 从JSON文件迁移到数据库
"""
import asyncio
import json
import os
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from app.user_manager import user_manager
from app.user_password import password_manager
from app.config import DATA_DIR
async def migrate_users():
"""迁移用户数据"""
users_file = DATA_DIR / "users.json"
if not users_file.exists():
print("❌ 用户数据文件不存在,跳过迁移")
return 0
try:
with open(users_file, "r", encoding="utf-8") as f:
users_data = json.load(f)
if not users_data:
print("️ 用户数据为空,跳过迁移")
return 0
migrated_count = 0
for user_id, user_info in users_data.items():
try:
# 迁移用户基本信息
await user_manager.create_or_update_from_linuxdo(
linuxdo_id=user_info["linuxdo_id"],
username=user_info["username"],
display_name=user_info["display_name"],
avatar_url=user_info.get("avatar_url"),
trust_level=user_info.get("trust_level", 0)
)
# 如果用户是管理员,设置管理员权限
if user_info.get("is_admin", False):
await user_manager.set_admin(user_id, True)
migrated_count += 1
print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
except Exception as e:
print(f"❌ 迁移用户 {user_id} 失败: {e}")
print(f"\n✅ 用户数据迁移完成: {migrated_count}/{len(users_data)} 个用户")
# 备份原文件
backup_file = DATA_DIR / "users.json.backup"
os.rename(users_file, backup_file)
print(f"📦 原文件已备份到: {backup_file}")
return migrated_count
except Exception as e:
print(f"❌ 迁移用户数据失败: {e}")
return 0
async def migrate_passwords():
"""迁移密码数据"""
passwords_file = DATA_DIR / "user_passwords.json"
if not passwords_file.exists():
print("❌ 密码数据文件不存在,跳过迁移")
return 0
try:
with open(passwords_file, "r", encoding="utf-8") as f:
passwords_data = json.load(f)
if not passwords_data:
print("️ 密码数据为空,跳过迁移")
return 0
migrated_count = 0
for user_id, pwd_info in passwords_data.items():
try:
# 直接插入密码记录(已经是哈希值)
from app.models.user import UserPassword
from app.user_password import password_manager as pm
async with await pm._get_session() as session:
from sqlalchemy import select
# 检查是否已存在
result = await session.execute(
select(UserPassword).where(UserPassword.user_id == user_id)
)
existing = result.scalar_one_or_none()
if existing:
print(f"️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
continue
# 创建密码记录
from datetime import datetime
pwd_record = UserPassword(
user_id=user_id,
username=pwd_info["username"],
password_hash=pwd_info["password_hash"],
has_custom_password=pwd_info.get("has_custom_password", False),
created_at=datetime.now(),
updated_at=datetime.now()
)
session.add(pwd_record)
await session.commit()
migrated_count += 1
print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
except Exception as e:
print(f"❌ 迁移密码 {user_id} 失败: {e}")
print(f"\n✅ 密码数据迁移完成: {migrated_count}/{len(passwords_data)} 个密码")
# 备份原文件
backup_file = DATA_DIR / "user_passwords.json.backup"
os.rename(passwords_file, backup_file)
print(f"📦 原文件已备份到: {backup_file}")
return migrated_count
except Exception as e:
print(f"❌ 迁移密码数据失败: {e}")
return 0
async def migrate_admins():
"""迁移管理员列表"""
admins_file = DATA_DIR / "admins.json"
if not admins_file.exists():
print("❌ 管理员数据文件不存在,跳过迁移")
return 0
try:
with open(admins_file, "r", encoding="utf-8") as f:
admins_data = json.load(f)
admin_list = admins_data.get("admins", [])
if not admin_list:
print("️ 管理员列表为空,跳过迁移")
return 0
migrated_count = 0
for user_id in admin_list:
try:
# 设置管理员权限
success = await user_manager.set_admin(user_id, True)
if success:
migrated_count += 1
print(f"✅ 设置管理员: {user_id}")
else:
print(f"⚠️ 用户不存在或已是管理员: {user_id}")
except Exception as e:
print(f"❌ 设置管理员 {user_id} 失败: {e}")
print(f"\n✅ 管理员数据迁移完成: {migrated_count}/{len(admin_list)} 个管理员")
# 备份原文件
backup_file = DATA_DIR / "admins.json.backup"
os.rename(admins_file, backup_file)
print(f"📦 原文件已备份到: {backup_file}")
return migrated_count
except Exception as e:
print(f"❌ 迁移管理员数据失败: {e}")
return 0
async def main():
"""主函数"""
print("=" * 60)
print("用户数据迁移工具 - JSON 到数据库")
print("=" * 60)
print()
# 迁移用户
print("📋 步骤 1/3: 迁移用户数据")
print("-" * 60)
user_count = await migrate_users()
print()
# 迁移密码
print("📋 步骤 2/3: 迁移密码数据")
print("-" * 60)
pwd_count = await migrate_passwords()
print()
# 迁移管理员
print("📋 步骤 3/3: 迁移管理员数据")
print("-" * 60)
admin_count = await migrate_admins()
print()
# 总结
print("=" * 60)
print("迁移完成")
print("=" * 60)
print(f"✅ 用户: {user_count}")
print(f"✅ 密码: {pwd_count}")
print(f"✅ 管理员: {admin_count}")
print()
print("💡 提示: 原文件已备份为 .backup 后缀")
print("💡 如需回滚,请删除数据库文件并恢复 .backup 文件")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,300 @@
"""
用户数据迁移脚本 - 从JSON文件迁移到PostgreSQL数据库
使用方法:
python migrate_users_to_postgres.py
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
"""
import asyncio
import json
import os
import sys
import argparse
from pathlib import Path
from datetime import datetime
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from app.config import settings, DATA_DIR
async def create_tables(engine):
"""创建用户相关表"""
from app.database import Base
from app.models.user import User, UserPassword
print("📋 创建数据库表...")
async with engine.begin() as conn:
# 只创建用户相关的表
await conn.run_sync(User.metadata.create_all)
await conn.run_sync(UserPassword.metadata.create_all)
print("✅ 表创建成功")
async def migrate_users(session):
"""迁移用户数据"""
from app.models.user import User as UserModel
users_file = DATA_DIR / "users.json"
if not users_file.exists():
print("️ 用户数据文件不存在,跳过迁移")
return 0
try:
with open(users_file, "r", encoding="utf-8") as f:
users_data = json.load(f)
if not users_data:
print("️ 用户数据为空,跳过迁移")
return 0
migrated_count = 0
for user_id, user_info in users_data.items():
try:
# 检查用户是否已存在
result = await session.execute(
select(UserModel).where(UserModel.user_id == user_id)
)
existing = result.scalar_one_or_none()
if existing:
print(f"️ 用户已存在,跳过: {user_info['username']} ({user_id})")
continue
# 创建用户记录
user = UserModel(
user_id=user_id,
username=user_info["username"],
display_name=user_info["display_name"],
avatar_url=user_info.get("avatar_url"),
trust_level=user_info.get("trust_level", 0),
is_admin=user_info.get("is_admin", False),
linuxdo_id=user_info["linuxdo_id"],
created_at=datetime.fromisoformat(user_info.get("created_at", datetime.now().isoformat())),
last_login=datetime.fromisoformat(user_info.get("last_login", datetime.now().isoformat()))
)
session.add(user)
migrated_count += 1
print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
except Exception as e:
print(f"❌ 迁移用户 {user_id} 失败: {e}")
await session.commit()
print(f"\n✅ 用户数据迁移完成: {migrated_count}/{len(users_data)} 个用户")
return migrated_count
except Exception as e:
print(f"❌ 迁移用户数据失败: {e}")
await session.rollback()
return 0
async def migrate_passwords(session):
"""迁移密码数据"""
from app.models.user import UserPassword
passwords_file = DATA_DIR / "user_passwords.json"
if not passwords_file.exists():
print("️ 密码数据文件不存在,跳过迁移")
return 0
try:
with open(passwords_file, "r", encoding="utf-8") as f:
passwords_data = json.load(f)
if not passwords_data:
print("️ 密码数据为空,跳过迁移")
return 0
migrated_count = 0
for user_id, pwd_info in passwords_data.items():
try:
# 检查密码是否已存在
result = await session.execute(
select(UserPassword).where(UserPassword.user_id == user_id)
)
existing = result.scalar_one_or_none()
if existing:
print(f"️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
continue
# 创建密码记录
pwd_record = UserPassword(
user_id=user_id,
username=pwd_info["username"],
password_hash=pwd_info["password_hash"],
has_custom_password=pwd_info.get("has_custom_password", False),
created_at=datetime.now(),
updated_at=datetime.now()
)
session.add(pwd_record)
migrated_count += 1
print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
except Exception as e:
print(f"❌ 迁移密码 {user_id} 失败: {e}")
await session.commit()
print(f"\n✅ 密码数据迁移完成: {migrated_count}/{len(passwords_data)} 个密码")
return migrated_count
except Exception as e:
print(f"❌ 迁移密码数据失败: {e}")
await session.rollback()
return 0
async def backup_json_files():
"""备份原始JSON文件"""
files_to_backup = ["users.json", "user_passwords.json", "admins.json"]
print("\n📦 备份原始文件...")
for filename in files_to_backup:
source = DATA_DIR / filename
if source.exists():
backup = DATA_DIR / f"{filename}.backup.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
import shutil
shutil.copy2(source, backup)
print(f"✅ 备份: {filename} -> {backup.name}")
async def main(db_url=None):
"""主函数
Args:
db_url: 可选的数据库URL,如果不提供则使用配置文件中的
"""
print("=" * 70)
print("用户数据迁移工具 - JSON 到 PostgreSQL")
print("=" * 70)
print()
# 确定使用的数据库URL
target_db_url = db_url if db_url else settings.database_url
# 检查数据库配置
if "postgresql" not in target_db_url:
print("❌ 错误: 未指定 PostgreSQL 数据库")
if not db_url:
print(f" 当前配置: {settings.database_url}")
print(" 请使用 --db-url 参数指定PostgreSQL数据库,或在 .env 中配置 DATABASE_URL")
else:
print(f" 提供的URL: {target_db_url}")
print()
print("示例:")
print(" python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname")
return
# 隐藏密码部分显示
display_url = target_db_url
if '@' in display_url:
parts = display_url.split('@')
if ':' in parts[0]:
user_part = parts[0].split(':')[0]
display_url = f"{user_part}:****@{parts[1]}"
print(f"📊 目标数据库: {display_url}")
print()
try:
# 创建数据库引擎
engine = create_async_engine(
target_db_url,
echo=False,
future=True,
pool_pre_ping=True,
)
# 创建表
await create_tables(engine)
print()
# 创建会话
async_session = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
# 迁移用户
print("📋 步骤 1/2: 迁移用户数据")
print("-" * 70)
async with async_session() as session:
user_count = await migrate_users(session)
print()
# 迁移密码
print("📋 步骤 2/2: 迁移密码数据")
print("-" * 70)
async with async_session() as session:
pwd_count = await migrate_passwords(session)
print()
# 备份原文件
await backup_json_files()
print()
# 总结
print("=" * 70)
print("迁移完成")
print("=" * 70)
print(f"✅ 用户: {user_count}")
print(f"✅ 密码: {pwd_count}")
print()
print("💡 提示:")
print(" - 原文件已备份(带时间戳)")
print(" - 可以安全删除 users.json 和 user_passwords.json")
print(" - 如需回滚,请从备份文件恢复")
print()
# 关闭引擎
await engine.dispose()
except Exception as e:
print(f"\n❌ 迁移过程出错: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
# 解析命令行参数
parser = argparse.ArgumentParser(
description="迁移用户数据从JSON到PostgreSQL数据库",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 使用 .env 配置的数据库
python migrate_users_to_postgres.py
# 指定数据库URL
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
# 使用环境变量
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/db python migrate_users_to_postgres.py
"""
)
parser.add_argument(
"--db-url",
type=str,
help="PostgreSQL数据库连接URL (格式: postgresql+asyncpg://user:password@host:port/database)",
default=None
)
args = parser.parse_args()
# 运行迁移
asyncio.run(main(db_url=args.db_url))