update:1.更新根据分析建议重新生成章节内容
This commit is contained in:
+212
-18
@@ -9,6 +9,7 @@ import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.user_password import password_manager
|
||||
from app.database import init_db
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
@@ -49,6 +50,25 @@ class LocalLoginResponse(BaseModel):
|
||||
user: Optional[dict] = None
|
||||
|
||||
|
||||
class SetPasswordRequest(BaseModel):
|
||||
"""设置密码请求"""
|
||||
password: str
|
||||
|
||||
|
||||
class SetPasswordResponse(BaseModel):
|
||||
"""设置密码响应"""
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PasswordStatusResponse(BaseModel):
|
||||
"""密码状态响应"""
|
||||
has_password: bool
|
||||
has_custom_password: bool
|
||||
username: Optional[str] = None
|
||||
default_password: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_auth_config():
|
||||
"""获取认证配置信息"""
|
||||
@@ -60,30 +80,77 @@ async def get_auth_config():
|
||||
|
||||
@router.post("/local/login", response_model=LocalLoginResponse)
|
||||
async def local_login(request: LocalLoginRequest, response: Response):
|
||||
"""本地账户登录"""
|
||||
"""本地账户登录(支持.env配置的管理员账号和Linux DO授权后绑定的账号)"""
|
||||
# 检查是否启用本地登录
|
||||
if not settings.LOCAL_AUTH_ENABLED:
|
||||
raise HTTPException(status_code=403, detail="本地账户登录未启用")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=500, detail="本地账户未配置")
|
||||
logger.info(f"[本地登录] 尝试登录用户名: {request.username}")
|
||||
|
||||
# 验证用户名和密码
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
# 首先尝试查找 Linux DO 授权后绑定的账号
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[本地登录] 找到 Linux DO 授权用户: {user.user_id}")
|
||||
break
|
||||
|
||||
# 创建或更新本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
# 如果找到了 Linux DO 授权的用户
|
||||
if target_user:
|
||||
# 检查是否有密码
|
||||
if not await password_manager.has_password(target_user.user_id):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
if not await password_manager.verify_password(target_user.user_id, request.password):
|
||||
logger.warning(f"[本地登录] 用户 {target_user.user_id} 密码验证失败")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] Linux DO 授权用户 {target_user.user_id} 登录成功")
|
||||
user = target_user
|
||||
else:
|
||||
# 没有找到 Linux DO 用户,尝试 .env 配置的管理员账号
|
||||
logger.info(f"[本地登录] 未找到 Linux DO 用户,检查 .env 管理员账号")
|
||||
|
||||
# 检查是否配置了本地账户
|
||||
if not settings.LOCAL_AUTH_USERNAME or not settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 生成本地用户ID(使用用户名的hash)
|
||||
user_id = f"local_{hashlib.md5(request.username.encode()).hexdigest()[:16]}"
|
||||
|
||||
# 检查用户是否存在
|
||||
user = await user_manager.get_user(user_id)
|
||||
|
||||
# 如果用户不存在,使用.env中的默认密码验证
|
||||
if not user:
|
||||
# 验证用户名和密码(使用.env配置)
|
||||
if request.username != settings.LOCAL_AUTH_USERNAME or request.password != settings.LOCAL_AUTH_PASSWORD:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 创建本地用户
|
||||
user = await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_id,
|
||||
username=request.username,
|
||||
display_name=settings.LOCAL_AUTH_DISPLAY_NAME,
|
||||
avatar_url=None,
|
||||
trust_level=9 # 本地用户给予高信任级别
|
||||
)
|
||||
|
||||
# 为新用户设置默认密码到数据库
|
||||
await password_manager.set_password(user.user_id, request.username, request.password)
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 初始密码已设置到数据库")
|
||||
else:
|
||||
# 用户已存在,使用数据库中的密码验证
|
||||
if not await password_manager.verify_password(user.user_id, request.password):
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
logger.info(f"[本地登录] 管理员用户 {user.user_id} 登录成功")
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
@@ -189,6 +256,11 @@ async def _handle_callback(
|
||||
trust_level=trust_level
|
||||
)
|
||||
|
||||
# 3.1. 自动绑定密码(如果还没有设置)
|
||||
if not await password_manager.has_password(user.user_id):
|
||||
default_password = await password_manager.set_password(user.user_id, username)
|
||||
logger.info(f"用户 {user.user_id} ({username}) 自动绑定默认密码: {default_password}")
|
||||
|
||||
# 3.5. 初始化用户数据库(如果是新用户)
|
||||
try:
|
||||
await init_db(user.user_id)
|
||||
@@ -337,4 +409,126 @@ async def get_current_user(request: Request):
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
return request.state.user.dict()
|
||||
return request.state.user.dict()
|
||||
|
||||
|
||||
@router.get("/password/status", response_model=PasswordStatusResponse)
|
||||
async def get_password_status(request: Request):
|
||||
"""获取当前用户的密码状态"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
has_password = await password_manager.has_password(user.user_id)
|
||||
has_custom = await password_manager.has_custom_password(user.user_id)
|
||||
username = await password_manager.get_username(user.user_id)
|
||||
|
||||
# 如果使用默认密码,返回默认密码供用户查看
|
||||
default_password = None
|
||||
if has_password and not has_custom:
|
||||
default_password = f"{user.username}@666"
|
||||
|
||||
return PasswordStatusResponse(
|
||||
has_password=has_password,
|
||||
has_custom_password=has_custom,
|
||||
username=username or user.username,
|
||||
default_password=default_password
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password/set", response_model=SetPasswordResponse)
|
||||
async def set_user_password(request: Request, password_req: SetPasswordRequest):
|
||||
"""设置当前用户的密码"""
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
user = request.state.user
|
||||
|
||||
# 验证密码强度(至少6个字符)
|
||||
if len(password_req.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="密码长度至少为6个字符")
|
||||
|
||||
# 设置密码
|
||||
await password_manager.set_password(user.user_id, user.username, password_req.password)
|
||||
logger.info(f"用户 {user.user_id} ({user.username}) 设置了自定义密码")
|
||||
|
||||
return SetPasswordResponse(
|
||||
success=True,
|
||||
message="密码设置成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bind/login", response_model=LocalLoginResponse)
|
||||
async def bind_account_login(request: LocalLoginRequest, response: Response):
|
||||
"""使用绑定的账号密码登录(LinuxDO授权后绑定的账号)"""
|
||||
# 查找用户
|
||||
all_users = await user_manager.get_all_users()
|
||||
target_user = None
|
||||
|
||||
logger.info(f"[绑定账号登录] 尝试登录用户名: {request.username}")
|
||||
logger.info(f"[绑定账号登录] 当前共有 {len(all_users)} 个用户")
|
||||
|
||||
for user in all_users:
|
||||
# 同时检查 users 表的 username 和 user_passwords 表的 username
|
||||
password_username = await password_manager.get_username(user.user_id)
|
||||
logger.info(f"[绑定账号登录] 检查用户 {user.user_id}: users.username={user.username}, passwords.username={password_username}")
|
||||
|
||||
if user.username == request.username or password_username == request.username:
|
||||
target_user = user
|
||||
logger.info(f"[绑定账号登录] 找到匹配用户: {user.user_id}")
|
||||
break
|
||||
|
||||
if not target_user:
|
||||
logger.warning(f"[绑定账号登录] 用户名 {request.username} 未找到")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 检查是否有密码记录
|
||||
has_pwd = await password_manager.has_password(target_user.user_id)
|
||||
if not has_pwd:
|
||||
logger.warning(f"[绑定账号登录] 用户 {target_user.user_id} 没有设置密码")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
is_valid = await password_manager.verify_password(target_user.user_id, request.password)
|
||||
logger.info(f"[绑定账号登录] 用户 {target_user.user_id} 密码验证结果: {is_valid}")
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 初始化用户数据库
|
||||
try:
|
||||
await init_db(target_user.user_id)
|
||||
logger.info(f"绑定账号用户 {target_user.user_id} 数据库初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"绑定账号用户 {target_user.user_id} 数据库初始化失败: {e}")
|
||||
|
||||
# 设置 Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=target_user.user_id,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 设置过期时间戳 Cookie(用于前端判断)
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"✅ [绑定账号登录] 用户 {target_user.user_id} ({request.username}) 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
|
||||
|
||||
response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
user=target_user.dict()
|
||||
)
|
||||
+296
-10
@@ -1,6 +1,5 @@
|
||||
"""章节管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Query, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
import json
|
||||
@@ -19,6 +18,7 @@ from app.models.writing_style import WritingStyle
|
||||
from app.models.analysis_task import AnalysisTask
|
||||
from app.models.memory import PlotAnalysis, StoryMemory
|
||||
from app.models.batch_generation_task import BatchGenerationTask
|
||||
from app.models.regeneration_task import RegenerationTask
|
||||
from app.schemas.chapter import (
|
||||
ChapterCreate,
|
||||
ChapterUpdate,
|
||||
@@ -29,12 +29,19 @@ from app.schemas.chapter import (
|
||||
BatchGenerateResponse,
|
||||
BatchGenerateStatusResponse
|
||||
)
|
||||
from app.schemas.regeneration import (
|
||||
ChapterRegenerateRequest,
|
||||
RegenerationTaskResponse,
|
||||
RegenerationTaskStatus
|
||||
)
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.services.plot_analyzer import PlotAnalyzer
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.chapter_regenerator import ChapterRegenerator
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.utils.sse_response import create_sse_response
|
||||
|
||||
router = APIRouter(prefix="/chapters", tags=["章节管理"])
|
||||
logger = get_logger(__name__)
|
||||
@@ -1284,15 +1291,7 @@ async def generate_chapter_content_stream(
|
||||
except:
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
return create_sse_response(event_generator())
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||
@@ -2293,3 +2292,290 @@ async def generate_single_chapter_for_batch(
|
||||
await db_session.refresh(chapter)
|
||||
|
||||
logger.info(f"✅ 单章节生成完成: 第{chapter.chapter_number}章,共 {new_word_count} 字")
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 章节重新生成相关API ====================
|
||||
|
||||
@router.post("/{chapter_id}/regenerate-stream", summary="流式重新生成章节内容")
|
||||
async def regenerate_chapter_stream(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
regenerate_request: ChapterRegenerateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
根据分析建议或自定义指令重新生成章节内容(流式返回)
|
||||
|
||||
工作流程:
|
||||
1. 验证章节和分析结果
|
||||
2. 创建重新生成任务
|
||||
3. 构建修改指令
|
||||
4. 流式生成新内容
|
||||
5. 保存为版本历史
|
||||
6. 可选自动应用
|
||||
"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证章节存在
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
if not chapter.content or chapter.content.strip() == "":
|
||||
raise HTTPException(status_code=400, detail="章节内容为空,无法重新生成")
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果(如果使用分析建议)
|
||||
analysis = None
|
||||
if regenerate_request.modification_source in ['analysis_suggestions', 'mixed']:
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
.where(PlotAnalysis.chapter_id == chapter_id)
|
||||
.order_by(PlotAnalysis.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
analysis = analysis_result.scalar_one_or_none()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(status_code=404, detail="该章节暂无分析结果")
|
||||
|
||||
# 预先获取项目上下文数据
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
# 获取项目信息
|
||||
project_result = await temp_db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
|
||||
# 获取角色信息
|
||||
characters_result = await temp_db.execute(
|
||||
select(Character).where(Character.project_id == chapter.project_id)
|
||||
)
|
||||
characters = characters_result.scalars().all()
|
||||
|
||||
# 获取章节大纲
|
||||
outline_result = await temp_db.execute(
|
||||
select(Outline)
|
||||
.where(Outline.project_id == chapter.project_id)
|
||||
.where(Outline.order_index == chapter.chapter_number)
|
||||
)
|
||||
outline = outline_result.scalar_one_or_none()
|
||||
|
||||
# 构建项目上下文
|
||||
project_context = {
|
||||
'project_title': project.title if project else '未知',
|
||||
'genre': project.genre if project else '未设定',
|
||||
'theme': project.theme if project else '未设定',
|
||||
'narrative_perspective': project.narrative_perspective if project else '第三人称',
|
||||
'time_period': project.world_time_period if project else '未设定',
|
||||
'location': project.world_location if project else '未设定',
|
||||
'atmosphere': project.world_atmosphere if project else '未设定',
|
||||
'characters_info': "\n".join([
|
||||
f"- {c.name}({'组织' if c.is_organization else '角色'}, {c.role_type}): {c.personality[:100] if c.personality else ''}"
|
||||
for c in characters
|
||||
]) if characters else '暂无角色信息',
|
||||
'chapter_outline': outline.content if outline else chapter.summary or '暂无大纲',
|
||||
'previous_context': '' # 可以后续扩展添加前置章节上下文
|
||||
}
|
||||
finally:
|
||||
await temp_db.close()
|
||||
break
|
||||
|
||||
async def event_generator():
|
||||
"""流式生成事件生成器"""
|
||||
db_session = None
|
||||
db_committed = False
|
||||
|
||||
try:
|
||||
# 创建独立数据库会话
|
||||
async for db_session in get_db(request):
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 创建重新生成任务
|
||||
regen_task = RegenerationTask(
|
||||
chapter_id=chapter_id,
|
||||
analysis_id=analysis.id if analysis else None,
|
||||
user_id=user_id,
|
||||
project_id=chapter.project_id,
|
||||
modification_instructions="", # 稍后填充
|
||||
original_suggestions=analysis.suggestions if analysis else None,
|
||||
selected_suggestion_indices=regenerate_request.selected_suggestion_indices,
|
||||
custom_instructions=regenerate_request.custom_instructions,
|
||||
style_id=regenerate_request.style_id,
|
||||
target_word_count=regenerate_request.target_word_count,
|
||||
focus_areas=regenerate_request.focus_areas,
|
||||
preserve_elements=regenerate_request.preserve_elements.model_dump() if regenerate_request.preserve_elements else None,
|
||||
status='running',
|
||||
original_content=chapter.content,
|
||||
original_word_count=chapter.word_count or len(chapter.content),
|
||||
version_note=regenerate_request.version_note,
|
||||
started_at=datetime.now()
|
||||
)
|
||||
db_session.add(regen_task)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(regen_task)
|
||||
|
||||
task_id = regen_task.id
|
||||
logger.info(f"📝 创建重新生成任务: {task_id}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 初始化重新生成器
|
||||
regenerator = ChapterRegenerator(user_ai_service)
|
||||
|
||||
# 流式生成新内容
|
||||
full_content = ""
|
||||
async for event in regenerator.regenerate_with_feedback(
|
||||
chapter=chapter,
|
||||
analysis=analysis,
|
||||
regenerate_request=regenerate_request,
|
||||
project_context=project_context
|
||||
):
|
||||
# 处理不同类型的事件
|
||||
if event['type'] == 'chunk':
|
||||
# 内容块
|
||||
chunk = event['content']
|
||||
full_content += chunk
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
elif event['type'] == 'progress':
|
||||
# 进度更新
|
||||
progress_data = {
|
||||
'type': 'progress',
|
||||
'progress': event.get('progress', 0),
|
||||
'message': event.get('message', ''),
|
||||
'word_count': event.get('word_count', 0)
|
||||
}
|
||||
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 更新任务状态
|
||||
regen_task.status = 'completed'
|
||||
regen_task.regenerated_content = full_content
|
||||
regen_task.regenerated_word_count = len(full_content)
|
||||
regen_task.completed_at = datetime.now()
|
||||
|
||||
# 计算差异统计
|
||||
diff_stats = regenerator.calculate_content_diff(chapter.content, full_content)
|
||||
|
||||
await db_session.commit()
|
||||
db_committed = True
|
||||
|
||||
# 先发送结果数据
|
||||
result_data = {
|
||||
'type': 'result',
|
||||
'data': {
|
||||
'task_id': task_id,
|
||||
'word_count': len(full_content),
|
||||
'version_number': regen_task.version_number,
|
||||
'auto_applied': regenerate_request.auto_apply,
|
||||
'diff_stats': diff_stats
|
||||
}
|
||||
}
|
||||
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 再发送完成事件
|
||||
completion_data = {
|
||||
'type': 'done',
|
||||
'message': '重新生成完成'
|
||||
}
|
||||
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
|
||||
|
||||
# 更新任务状态为失败
|
||||
if db_session and not db_committed:
|
||||
try:
|
||||
task_result = await db_session.execute(
|
||||
select(RegenerationTask).where(RegenerationTask.chapter_id == chapter_id)
|
||||
.order_by(RegenerationTask.created_at.desc()).limit(1)
|
||||
)
|
||||
task = task_result.scalar_one_or_none()
|
||||
if task:
|
||||
task.status = 'failed'
|
||||
task.error_message = str(e)[:500]
|
||||
task.completed_at = datetime.now()
|
||||
await db_session.commit()
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新任务失败状态失败: {str(update_error)}")
|
||||
|
||||
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
finally:
|
||||
if db_session:
|
||||
try:
|
||||
if not db_committed and db_session.in_transaction():
|
||||
await db_session.rollback()
|
||||
await db_session.close()
|
||||
except Exception as close_error:
|
||||
logger.error(f"关闭数据库会话失败: {str(close_error)}")
|
||||
|
||||
return create_sse_response(event_generator())
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/regeneration/tasks", summary="获取章节的重新生成任务列表")
|
||||
async def get_regeneration_tasks(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
limit: int = Query(10, ge=1, le=50),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定章节的重新生成任务历史"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证章节存在和权限
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取任务列表
|
||||
result = await db.execute(
|
||||
select(RegenerationTask)
|
||||
.where(RegenerationTask.chapter_id == chapter_id)
|
||||
.order_by(RegenerationTask.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return {
|
||||
"chapter_id": chapter_id,
|
||||
"total": len(tasks),
|
||||
"tasks": [
|
||||
{
|
||||
"task_id": task.id,
|
||||
"status": task.status,
|
||||
"version_number": task.version_number,
|
||||
"version_note": task.version_note,
|
||||
"original_word_count": task.original_word_count,
|
||||
"regenerated_word_count": task.regenerated_word_count,
|
||||
"created_at": task.created_at.isoformat() if task.created_at else None,
|
||||
"completed_at": task.completed_at.isoformat() if task.completed_at else None
|
||||
}
|
||||
for task in tasks
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1135,386 +1135,3 @@ async def generate_outline_stream(
|
||||
"""
|
||||
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||
|
||||
|
||||
async def update_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""更新世界观流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始更新世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
yield await SSEResponse.send_progress("验证数据...", 30)
|
||||
|
||||
# 更新世界观字段
|
||||
if "time_period" in data:
|
||||
project.world_time_period = data["time_period"]
|
||||
if "location" in data:
|
||||
project.world_location = data["location"]
|
||||
if "atmosphere" in data:
|
||||
project.world_atmosphere = data["atmosphere"]
|
||||
if "rules" in data:
|
||||
project.world_rules = data["rules"]
|
||||
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 70)
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("更新世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"更新世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("更新世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"更新失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}", summary="流式更新世界观")
|
||||
async def update_world_building_stream(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式更新项目的世界观信息
|
||||
请求体格式:
|
||||
{
|
||||
"time_period": "时间背景",
|
||||
"location": "地理位置",
|
||||
"atmosphere": "氛围基调",
|
||||
"rules": "世界规则"
|
||||
}
|
||||
"""
|
||||
return create_sse_response(update_world_building_generator(project_id, data, db))
|
||||
|
||||
|
||||
async def regenerate_world_building_generator(
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession,
|
||||
user_ai_service: AIService
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""重新生成世界观流式生成器 - 支持MCP工具增强"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始重新生成世界观...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
provider = data.get("provider")
|
||||
model = data.get("model")
|
||||
enable_mcp = data.get("enable_mcp", True) # 默认启用MCP
|
||||
user_id = data.get("user_id") # 从中间件注入
|
||||
|
||||
# 获取基础提示词
|
||||
yield await SSEResponse.send_progress("准备AI提示词...", 15)
|
||||
base_prompt = prompt_service.get_world_building_prompt(
|
||||
title=project.title,
|
||||
theme=project.theme or "",
|
||||
genre=project.genre or ""
|
||||
)
|
||||
|
||||
# MCP工具增强:收集参考资料
|
||||
reference_materials = ""
|
||||
if enable_mcp and user_id:
|
||||
try:
|
||||
yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18)
|
||||
|
||||
# 直接调用MCP增强的AI,内部会自动检查和加载工具
|
||||
# 构建资料收集提示词
|
||||
planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。
|
||||
|
||||
【小说信息】
|
||||
- 题材:{project.genre or '未设定'}
|
||||
- 主题:{project.theme or '未设定'}
|
||||
|
||||
【任务】
|
||||
请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。
|
||||
你可以查询:
|
||||
1. 历史背景(如果是历史题材)
|
||||
2. 地理环境和文化特征
|
||||
3. 相关领域的专业知识
|
||||
4. 类似作品的设定参考
|
||||
|
||||
请根据题材特点,有针对性地查询2-3个关键问题。"""
|
||||
|
||||
# 调用MCP增强的AI(非流式,最多2轮工具调用)
|
||||
planning_result = await user_ai_service.generate_text_with_mcp(
|
||||
prompt=planning_prompt,
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=2,
|
||||
tool_choice="auto",
|
||||
provider=None,
|
||||
model=None
|
||||
)
|
||||
|
||||
# 提取参考资料
|
||||
if planning_result.get("tool_calls_made", 0) > 0:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)",
|
||||
25
|
||||
)
|
||||
reference_materials = planning_result.get("content", "")
|
||||
else:
|
||||
yield await SSEResponse.send_progress("ℹ️ 未使用MCP工具(无可用工具或不需要)", 25)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP工具调用失败(降级处理): {e}")
|
||||
yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25)
|
||||
|
||||
# 构建增强提示词
|
||||
if reference_materials:
|
||||
enhanced_prompt = f"""{base_prompt}
|
||||
|
||||
【参考资料】
|
||||
以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观:
|
||||
|
||||
{reference_materials}
|
||||
|
||||
请结合上述资料,生成符合历史/现实的世界观设定。"""
|
||||
final_prompt = enhanced_prompt
|
||||
yield await SSEResponse.send_progress("💡 已整合参考资料,开始重新生成世界观...", 30)
|
||||
else:
|
||||
final_prompt = base_prompt
|
||||
yield await SSEResponse.send_progress("正在调用AI生成...", 30)
|
||||
|
||||
# 流式生成世界观
|
||||
accumulated_text = ""
|
||||
chunk_count = 0
|
||||
|
||||
async for chunk in user_ai_service.generate_text_stream(
|
||||
prompt=final_prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
):
|
||||
chunk_count += 1
|
||||
accumulated_text += chunk
|
||||
|
||||
# 发送内容块
|
||||
yield await SSEResponse.send_chunk(chunk)
|
||||
|
||||
# 定期更新进度
|
||||
if chunk_count % 5 == 0:
|
||||
progress = min(30 + (chunk_count // 5), 70)
|
||||
yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress)
|
||||
|
||||
# 每20个块发送心跳
|
||||
if chunk_count % 20 == 0:
|
||||
yield await SSEResponse.send_heartbeat()
|
||||
|
||||
# 解析结果
|
||||
yield await SSEResponse.send_progress("解析AI返回结果...", 80)
|
||||
|
||||
world_data = {}
|
||||
try:
|
||||
cleaned_text = accumulated_text.strip()
|
||||
# 移除markdown代码块标记
|
||||
if cleaned_text.startswith('```json'):
|
||||
cleaned_text = cleaned_text[7:].lstrip('\n\r')
|
||||
elif cleaned_text.startswith('```'):
|
||||
cleaned_text = cleaned_text[3:].lstrip('\n\r')
|
||||
if cleaned_text.endswith('```'):
|
||||
cleaned_text = cleaned_text[:-3].rstrip('\n\r')
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
world_data = json.loads(cleaned_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"AI返回非JSON格式: {e}")
|
||||
logger.info(world_data)
|
||||
world_data = {
|
||||
"time_period": "AI返回格式错误,请重试",
|
||||
"location": "AI返回格式错误,请重试",
|
||||
"atmosphere": "AI返回格式错误,请重试",
|
||||
"rules": "AI返回格式错误,请重试"
|
||||
}
|
||||
|
||||
# 更新项目世界观
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
|
||||
project.world_time_period = world_data.get("time_period")
|
||||
project.world_location = world_data.get("location")
|
||||
project.world_atmosphere = world_data.get("atmosphere")
|
||||
project.world_rules = world_data.get("rules")
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
await db.refresh(project)
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"project_id": project.id,
|
||||
"time_period": project.world_time_period,
|
||||
"location": project.world_location,
|
||||
"atmosphere": project.world_atmosphere,
|
||||
"rules": project.world_rules
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("重新生成世界观生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"重新生成世界观失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("重新生成世界观事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"重新生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观")
|
||||
async def regenerate_world_building_stream(
|
||||
request: Request,
|
||||
project_id: str,
|
||||
data: Dict[str, Any],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
"""
|
||||
使用SSE流式重新生成项目的世界观
|
||||
请求体格式:
|
||||
{
|
||||
"provider": "AI提供商(可选)",
|
||||
"model": "模型名称(可选)"
|
||||
}
|
||||
"""
|
||||
# 从中间件注入user_id到data中
|
||||
if hasattr(request.state, 'user_id'):
|
||||
data['user_id'] = request.state.user_id
|
||||
|
||||
return create_sse_response(regenerate_world_building_generator(project_id, data, db, user_ai_service))
|
||||
|
||||
|
||||
async def cleanup_wizard_data_generator(
|
||||
project_id: str,
|
||||
db: AsyncSession
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""清理向导数据流式生成器"""
|
||||
db_committed = False
|
||||
try:
|
||||
yield await SSEResponse.send_progress("开始清理向导数据...", 10)
|
||||
|
||||
# 获取项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
yield await SSEResponse.send_error("项目不存在", 404)
|
||||
return
|
||||
|
||||
# 删除相关的角色
|
||||
yield await SSEResponse.send_progress("删除角色数据...", 30)
|
||||
characters = await db.execute(
|
||||
select(Character).where(Character.project_id == project_id)
|
||||
)
|
||||
char_count = 0
|
||||
for character in characters.scalars():
|
||||
await db.delete(character)
|
||||
char_count += 1
|
||||
|
||||
# 删除相关的大纲
|
||||
yield await SSEResponse.send_progress("删除大纲数据...", 50)
|
||||
outlines = await db.execute(
|
||||
select(Outline).where(Outline.project_id == project_id)
|
||||
)
|
||||
outline_count = 0
|
||||
for outline in outlines.scalars():
|
||||
await db.delete(outline)
|
||||
outline_count += 1
|
||||
|
||||
# 删除相关的章节
|
||||
yield await SSEResponse.send_progress("删除章节数据...", 70)
|
||||
chapters = await db.execute(
|
||||
select(Chapter).where(Chapter.project_id == project_id)
|
||||
)
|
||||
chapter_count = 0
|
||||
for chapter in chapters.scalars():
|
||||
await db.delete(chapter)
|
||||
chapter_count += 1
|
||||
|
||||
# 删除项目
|
||||
yield await SSEResponse.send_progress("删除项目...", 85)
|
||||
await db.delete(project)
|
||||
|
||||
yield await SSEResponse.send_progress("提交数据库更改...", 95)
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
|
||||
# 发送结果
|
||||
yield await SSEResponse.send_result({
|
||||
"message": "项目及相关数据已清理",
|
||||
"deleted": {
|
||||
"characters": char_count,
|
||||
"outlines": outline_count,
|
||||
"chapters": chapter_count
|
||||
}
|
||||
})
|
||||
|
||||
yield await SSEResponse.send_progress("完成!", 100, "success")
|
||||
yield await SSEResponse.send_done()
|
||||
|
||||
except GeneratorExit:
|
||||
logger.warning("清理向导数据生成器被提前关闭")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(GeneratorExit)")
|
||||
except Exception as e:
|
||||
logger.error(f"清理数据失败: {str(e)}")
|
||||
if not db_committed and db.in_transaction():
|
||||
await db.rollback()
|
||||
logger.info("清理向导数据事务已回滚(异常)")
|
||||
yield await SSEResponse.send_error(f"清理失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/cleanup/{project_id}", summary="流式清理向导数据")
|
||||
async def cleanup_wizard_data_stream(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
使用SSE流式清理向导过程中创建的项目及相关数据
|
||||
用于返回上一步时清理已生成的内容
|
||||
"""
|
||||
return create_sse_response(cleanup_wizard_data_generator(project_id, db))
|
||||
@@ -110,6 +110,7 @@ class Settings(BaseSettings):
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
extra = "ignore" # 忽略未定义的环境变量,避免验证错误
|
||||
|
||||
|
||||
# 创建全局配置实例
|
||||
|
||||
@@ -21,7 +21,8 @@ from app.models import (
|
||||
Project, Outline, Character, Chapter, GenerationHistory,
|
||||
Settings, WritingStyle, ProjectDefaultStyle,
|
||||
RelationshipType, CharacterRelationship, Organization, OrganizationMember,
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask
|
||||
StoryMemory, PlotAnalysis, AnalysisTask, BatchGenerationTask,
|
||||
RegenerationTask
|
||||
)
|
||||
|
||||
# 引擎缓存:每个用户一个引擎
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.writing_style import WritingStyle
|
||||
from app.models.project_default_style import ProjectDefaultStyle
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.user import User, UserPassword
|
||||
from app.models.regeneration_task import RegenerationTask
|
||||
|
||||
__all__ = [
|
||||
"Project",
|
||||
@@ -30,5 +32,8 @@ __all__ = [
|
||||
"PlotAnalysis",
|
||||
"WritingStyle",
|
||||
"ProjectDefaultStyle",
|
||||
"MCPPlugin"
|
||||
"MCPPlugin",
|
||||
"User",
|
||||
"UserPassword",
|
||||
"RegenerationTask"
|
||||
]
|
||||
@@ -0,0 +1,51 @@
|
||||
"""章节重新生成任务模型"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, JSON, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class RegenerationTask(Base):
|
||||
"""章节重新生成任务表"""
|
||||
__tablename__ = "regeneration_tasks"
|
||||
|
||||
# 基本信息
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
chapter_id = Column(String(36), ForeignKey('chapters.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
analysis_id = Column(String(36), nullable=True, comment="关联的分析结果ID")
|
||||
user_id = Column(String(50), nullable=False, index=True)
|
||||
project_id = Column(String(36), nullable=False, index=True)
|
||||
|
||||
# 修改指令
|
||||
modification_instructions = Column(Text, nullable=False, comment="综合修改指令")
|
||||
original_suggestions = Column(JSON, comment="来自分析的原始建议列表")
|
||||
selected_suggestion_indices = Column(JSON, comment="用户选择的建议索引")
|
||||
custom_instructions = Column(Text, comment="用户自定义修改意见")
|
||||
|
||||
# 生成参数
|
||||
style_id = Column(Integer, nullable=True, comment="写作风格ID")
|
||||
target_word_count = Column(Integer, default=3000, comment="目标字数")
|
||||
focus_areas = Column(JSON, comment="重点优化方向")
|
||||
preserve_elements = Column(JSON, comment="需要保留的元素配置")
|
||||
|
||||
# 状态跟踪
|
||||
status = Column(String(20), default='pending', comment="pending/running/completed/failed")
|
||||
progress = Column(Integer, default=0, comment="进度 0-100")
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# 内容版本
|
||||
original_content = Column(Text, comment="原始章节内容快照")
|
||||
original_word_count = Column(Integer, comment="原始字数")
|
||||
regenerated_content = Column(Text, comment="重新生成的内容")
|
||||
regenerated_word_count = Column(Integer, comment="新内容字数")
|
||||
version_number = Column(Integer, default=1, comment="版本号")
|
||||
version_note = Column(String(500), comment="版本说明")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, server_default=func.now())
|
||||
started_at = Column(DateTime, nullable=True)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<RegenerationTask(id={self.id[:8]}..., chapter_id={self.chapter_id[:8]}..., status={self.status})>"
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
用户数据模型 - 存储用户基本信息
|
||||
"""
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户模型 - 存储OAuth和本地用户信息"""
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID,格式:linuxdo_{id} 或 local_{id}")
|
||||
username = Column(String(100), nullable=False, index=True, comment="用户名")
|
||||
display_name = Column(String(200), nullable=False, comment="显示名称")
|
||||
avatar_url = Column(String(500), nullable=True, comment="头像URL")
|
||||
trust_level = Column(Integer, default=0, comment="信任等级(仅用于显示)")
|
||||
is_admin = Column(Boolean, default=False, comment="是否为管理员")
|
||||
linuxdo_id = Column(String(100), nullable=False, unique=True, index=True, comment="LinuxDO用户ID或本地用户ID")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
last_login = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后登录时间")
|
||||
|
||||
def to_dict(self):
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"username": self.username,
|
||||
"display_name": self.display_name,
|
||||
"avatar_url": self.avatar_url,
|
||||
"trust_level": self.trust_level,
|
||||
"is_admin": self.is_admin,
|
||||
"linuxdo_id": self.linuxdo_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||
}
|
||||
|
||||
|
||||
class UserPassword(Base):
|
||||
"""用户密码模型 - 存储用户密码信息"""
|
||||
__tablename__ = "user_passwords"
|
||||
|
||||
user_id = Column(String(100), primary_key=True, index=True, comment="用户ID")
|
||||
username = Column(String(100), nullable=False, comment="用户名")
|
||||
password_hash = Column(String(64), nullable=False, comment="密码哈希(SHA256)")
|
||||
has_custom_password = Column(Boolean, default=False, comment="是否为自定义密码")
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
@@ -0,0 +1,65 @@
|
||||
"""章节重新生成相关的Schema定义"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class PreserveElementsConfig(BaseModel):
|
||||
"""保留元素配置"""
|
||||
preserve_structure: bool = Field(False, description="是否保留整体结构")
|
||||
preserve_dialogues: List[str] = Field(default_factory=list, description="需要保留的对话片段关键词")
|
||||
preserve_plot_points: List[str] = Field(default_factory=list, description="需要保留的情节点关键词")
|
||||
preserve_character_traits: bool = Field(True, description="保持角色性格一致")
|
||||
|
||||
|
||||
class ChapterRegenerateRequest(BaseModel):
|
||||
"""章节重新生成请求"""
|
||||
|
||||
# 修改来源
|
||||
modification_source: str = Field("custom", description="修改来源: custom/analysis_suggestions/mixed")
|
||||
|
||||
# 基于分析建议
|
||||
selected_suggestion_indices: Optional[List[int]] = Field(None, description="选中的建议索引列表")
|
||||
|
||||
# 自定义修改指令
|
||||
custom_instructions: Optional[str] = Field(None, description="用户自定义的修改要求")
|
||||
|
||||
# 保留配置
|
||||
preserve_elements: Optional[PreserveElementsConfig] = Field(None, description="保留元素配置")
|
||||
|
||||
# 生成参数
|
||||
style_id: Optional[int] = Field(None, description="写作风格ID")
|
||||
target_word_count: int = Field(3000, description="目标字数", ge=500, le=10000)
|
||||
focus_areas: List[str] = Field(default_factory=list, description="重点优化方向")
|
||||
|
||||
# 版本管理
|
||||
save_as_version: bool = Field(True, description="是否保存为新版本")
|
||||
version_note: Optional[str] = Field(None, description="版本说明", max_length=500)
|
||||
auto_apply: bool = Field(False, description="是否自动应用(替换当前内容)")
|
||||
|
||||
|
||||
class RegenerationTaskResponse(BaseModel):
|
||||
"""重新生成任务响应"""
|
||||
task_id: str
|
||||
chapter_id: str
|
||||
status: str
|
||||
message: str
|
||||
estimated_time_seconds: int = 120
|
||||
|
||||
|
||||
class RegenerationTaskStatus(BaseModel):
|
||||
"""重新生成任务状态"""
|
||||
task_id: str
|
||||
chapter_id: str
|
||||
status: str
|
||||
progress: int
|
||||
error_message: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
# 结果信息
|
||||
original_word_count: Optional[int] = None
|
||||
regenerated_word_count: Optional[int] = None
|
||||
version_number: Optional[int] = None
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
"""章节重新生成服务"""
|
||||
from typing import Dict, Any, AsyncGenerator, Optional, List
|
||||
from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.memory import PlotAnalysis
|
||||
from app.schemas.regeneration import ChapterRegenerateRequest, PreserveElementsConfig
|
||||
from app.logger import get_logger
|
||||
import difflib
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChapterRegenerator:
|
||||
"""章节重新生成服务"""
|
||||
|
||||
def __init__(self, ai_service: AIService):
|
||||
self.ai_service = ai_service
|
||||
logger.info("✅ ChapterRegenerator初始化成功")
|
||||
|
||||
async def regenerate_with_feedback(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
analysis: Optional[PlotAnalysis],
|
||||
regenerate_request: ChapterRegenerateRequest,
|
||||
project_context: Dict[str, Any]
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
根据反馈重新生成章节(流式)
|
||||
|
||||
Args:
|
||||
chapter: 原始章节对象
|
||||
analysis: 分析结果(可选)
|
||||
regenerate_request: 重新生成请求参数
|
||||
project_context: 项目上下文(项目信息、角色、大纲等)
|
||||
|
||||
Yields:
|
||||
包含类型和数据的字典: {'type': 'progress'/'chunk', 'data': ...}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重新生成章节: 第{chapter.chapter_number}章")
|
||||
|
||||
# 1. 构建修改指令
|
||||
yield {'type': 'progress', 'progress': 5, 'message': '正在构建修改指令...'}
|
||||
modification_instructions = self._build_modification_instructions(
|
||||
analysis=analysis,
|
||||
regenerate_request=regenerate_request
|
||||
)
|
||||
|
||||
logger.info(f"📝 修改指令构建完成,长度: {len(modification_instructions)}字符")
|
||||
|
||||
# 2. 构建完整提示词
|
||||
yield {'type': 'progress', 'progress': 10, 'message': '正在构建生成提示词...'}
|
||||
full_prompt = self._build_regeneration_prompt(
|
||||
chapter=chapter,
|
||||
modification_instructions=modification_instructions,
|
||||
project_context=project_context,
|
||||
regenerate_request=regenerate_request
|
||||
)
|
||||
|
||||
logger.info(f"🎯 提示词构建完成,开始AI生成")
|
||||
yield {'type': 'progress', 'progress': 15, 'message': '开始AI生成内容...'}
|
||||
|
||||
# 3. 流式生成新内容,同时跟踪进度
|
||||
target_word_count = regenerate_request.target_word_count
|
||||
accumulated_length = 0
|
||||
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=full_prompt,
|
||||
temperature=0.7
|
||||
):
|
||||
# 发送内容块
|
||||
yield {'type': 'chunk', 'content': chunk}
|
||||
|
||||
# 更新累积字数并计算进度(15%-95%)
|
||||
accumulated_length += len(chunk)
|
||||
# 进度从15%开始,到95%结束,为后处理预留5%
|
||||
generation_progress = min(15 + (accumulated_length / target_word_count) * 80, 95)
|
||||
yield {'type': 'progress', 'progress': int(generation_progress), 'word_count': accumulated_length}
|
||||
|
||||
logger.info(f"✅ 章节重新生成完成,共生成 {accumulated_length} 字")
|
||||
yield {'type': 'progress', 'progress': 100, 'message': '生成完成'}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重新生成失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _build_modification_instructions(
|
||||
self,
|
||||
analysis: Optional[PlotAnalysis],
|
||||
regenerate_request: ChapterRegenerateRequest
|
||||
) -> str:
|
||||
"""构建修改指令"""
|
||||
|
||||
instructions = []
|
||||
|
||||
# 标题
|
||||
instructions.append("# 章节修改指令\n")
|
||||
|
||||
# 1. 来自分析的建议
|
||||
if (analysis and
|
||||
regenerate_request.selected_suggestion_indices and
|
||||
analysis.suggestions):
|
||||
|
||||
instructions.append("## 📋 需要改进的问题(来自AI分析):\n")
|
||||
for idx in regenerate_request.selected_suggestion_indices:
|
||||
if 0 <= idx < len(analysis.suggestions):
|
||||
suggestion = analysis.suggestions[idx]
|
||||
instructions.append(f"{idx + 1}. {suggestion}")
|
||||
instructions.append("")
|
||||
|
||||
# 2. 用户自定义指令
|
||||
if regenerate_request.custom_instructions:
|
||||
instructions.append("## ✍️ 用户自定义修改要求:\n")
|
||||
instructions.append(regenerate_request.custom_instructions)
|
||||
instructions.append("")
|
||||
|
||||
# 3. 重点优化方向
|
||||
if regenerate_request.focus_areas:
|
||||
instructions.append("## 🎯 重点优化方向:\n")
|
||||
focus_map = {
|
||||
"pacing": "节奏把控 - 调整叙事速度,避免拖沓或过快",
|
||||
"emotion": "情感渲染 - 深化人物情感表达,增强感染力",
|
||||
"description": "场景描写 - 丰富环境细节,增强画面感",
|
||||
"dialogue": "对话质量 - 让对话更自然真实,推动剧情",
|
||||
"conflict": "冲突强度 - 强化矛盾冲突,提升戏剧张力"
|
||||
}
|
||||
|
||||
for area in regenerate_request.focus_areas:
|
||||
if area in focus_map:
|
||||
instructions.append(f"- {focus_map[area]}")
|
||||
instructions.append("")
|
||||
|
||||
# 4. 保留要求
|
||||
if regenerate_request.preserve_elements:
|
||||
preserve = regenerate_request.preserve_elements
|
||||
instructions.append("## 🔒 必须保留的元素:\n")
|
||||
|
||||
if preserve.preserve_structure:
|
||||
instructions.append("- 保持原章节的整体结构和情节框架")
|
||||
|
||||
if preserve.preserve_dialogues:
|
||||
instructions.append("- 必须保留以下关键对话:")
|
||||
for dialogue in preserve.preserve_dialogues:
|
||||
instructions.append(f" * {dialogue}")
|
||||
|
||||
if preserve.preserve_plot_points:
|
||||
instructions.append("- 必须保留以下关键情节点:")
|
||||
for plot in preserve.preserve_plot_points:
|
||||
instructions.append(f" * {plot}")
|
||||
|
||||
if preserve.preserve_character_traits:
|
||||
instructions.append("- 保持所有角色的性格特征和行为模式一致")
|
||||
|
||||
instructions.append("")
|
||||
|
||||
return "\n".join(instructions)
|
||||
|
||||
def _build_regeneration_prompt(
|
||||
self,
|
||||
chapter: Chapter,
|
||||
modification_instructions: str,
|
||||
project_context: Dict[str, Any],
|
||||
regenerate_request: ChapterRegenerateRequest
|
||||
) -> str:
|
||||
"""构建完整的重新生成提示词"""
|
||||
|
||||
prompt_parts = []
|
||||
|
||||
# 系统角色
|
||||
prompt_parts.append("""你是一位经验丰富的专业小说编辑和作家。现在需要根据反馈意见重新创作一个章节。
|
||||
|
||||
你的任务是:
|
||||
1. 仔细理解原章节的内容和意图
|
||||
2. 认真分析所有的修改要求
|
||||
3. 在保持故事连贯性的前提下,创作一个改进后的新版本
|
||||
4. 确保新版本在艺术性和可读性上都有明显提升
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 原始章节信息
|
||||
prompt_parts.append(f"""## 📖 原始章节信息
|
||||
|
||||
**章节**:第{chapter.chapter_number}章
|
||||
**标题**:{chapter.title}
|
||||
**字数**:{chapter.word_count}字
|
||||
|
||||
**原始内容**:
|
||||
{chapter.content}
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 修改指令
|
||||
prompt_parts.append(modification_instructions)
|
||||
prompt_parts.append("\n---\n")
|
||||
|
||||
# 项目背景信息
|
||||
prompt_parts.append(f"""## 🌍 项目背景信息
|
||||
|
||||
**小说标题**:{project_context.get('project_title', '未知')}
|
||||
**题材**:{project_context.get('genre', '未设定')}
|
||||
**主题**:{project_context.get('theme', '未设定')}
|
||||
**叙事视角**:{project_context.get('narrative_perspective', '第三人称')}
|
||||
**世界观设定**:
|
||||
- 时代背景:{project_context.get('time_period', '未设定')}
|
||||
- 地理位置:{project_context.get('location', '未设定')}
|
||||
- 氛围基调:{project_context.get('atmosphere', '未设定')}
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 角色信息
|
||||
if project_context.get('characters_info'):
|
||||
prompt_parts.append(f"""## 👥 角色信息
|
||||
|
||||
{project_context['characters_info']}
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 章节大纲
|
||||
if project_context.get('chapter_outline'):
|
||||
prompt_parts.append(f"""## 📝 本章大纲
|
||||
|
||||
{project_context['chapter_outline']}
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 前置章节上下文
|
||||
if project_context.get('previous_context'):
|
||||
prompt_parts.append(f"""## 📚 前置章节上下文
|
||||
|
||||
{project_context['previous_context']}
|
||||
|
||||
---
|
||||
""")
|
||||
|
||||
# 创作要求
|
||||
prompt_parts.append(f"""## ✨ 创作要求
|
||||
|
||||
1. **解决问题**:针对上述修改指令中提到的所有问题进行改进
|
||||
2. **保持连贯**:确保与前后章节的情节、人物、风格保持一致
|
||||
3. **提升质量**:在节奏、情感、描写等方面明显优于原版
|
||||
4. **保留精华**:保持原章节中优秀的部分和关键情节
|
||||
5. **字数控制**:目标字数约{regenerate_request.target_word_count}字(可适当浮动±20%)
|
||||
|
||||
---
|
||||
|
||||
## 🎬 开始创作
|
||||
|
||||
请现在开始创作改进后的新版本章节内容。
|
||||
|
||||
**重要提示**:
|
||||
- 直接输出章节正文内容,从故事内容开始写
|
||||
- **不要**输出章节标题(如"第X章"、"第X章:XXX"等)
|
||||
- **不要**输出任何额外的说明、注释或元数据
|
||||
- 只需要纯粹的故事正文内容
|
||||
|
||||
现在开始:
|
||||
""")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
def calculate_content_diff(
|
||||
self,
|
||||
original_content: str,
|
||||
new_content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算两个版本的差异
|
||||
|
||||
Returns:
|
||||
差异统计信息
|
||||
"""
|
||||
# 基本统计
|
||||
diff_stats = {
|
||||
'original_length': len(original_content),
|
||||
'new_length': len(new_content),
|
||||
'length_change': len(new_content) - len(original_content),
|
||||
'length_change_percent': round((len(new_content) - len(original_content)) / len(original_content) * 100, 2) if len(original_content) > 0 else 0
|
||||
}
|
||||
|
||||
# 计算相似度
|
||||
similarity = difflib.SequenceMatcher(None, original_content, new_content).ratio()
|
||||
diff_stats['similarity'] = round(similarity * 100, 2)
|
||||
diff_stats['difference'] = round((1 - similarity) * 100, 2)
|
||||
|
||||
# 段落统计
|
||||
original_paragraphs = [p for p in original_content.split('\n\n') if p.strip()]
|
||||
new_paragraphs = [p for p in new_content.split('\n\n') if p.strip()]
|
||||
diff_stats['original_paragraph_count'] = len(original_paragraphs)
|
||||
diff_stats['new_paragraph_count'] = len(new_paragraphs)
|
||||
|
||||
return diff_stats
|
||||
|
||||
|
||||
# 全局实例
|
||||
_regenerator_instance = None
|
||||
|
||||
def get_chapter_regenerator(ai_service: AIService) -> ChapterRegenerator:
|
||||
"""获取章节重新生成器实例"""
|
||||
global _regenerator_instance
|
||||
if _regenerator_instance is None:
|
||||
_regenerator_instance = ChapterRegenerator(ai_service)
|
||||
return _regenerator_instance
|
||||
+145
-213
@@ -1,106 +1,49 @@
|
||||
"""
|
||||
用户管理模块 - 支持 LinuxDO OAuth2
|
||||
用户管理模块 - 使用数据库存储
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from pydantic import BaseModel
|
||||
from app.config import settings, DATA_DIR
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""用户模型"""
|
||||
user_id: str # 格式: linuxdo_{linuxdo_id}
|
||||
"""用户数据传输对象"""
|
||||
user_id: str
|
||||
username: str
|
||||
display_name: str
|
||||
avatar_url: Optional[str] = None
|
||||
trust_level: int = 0 # 仅用于显示
|
||||
is_admin: bool = False # 手动设置的管理员权限
|
||||
linuxdo_id: str # LinuxDO 用户 ID
|
||||
trust_level: int = 0
|
||||
is_admin: bool = False
|
||||
linuxdo_id: str
|
||||
created_at: str
|
||||
last_login: str
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""用户管理器 - 线程安全版本"""
|
||||
|
||||
USERS_FILE = str(DATA_DIR / "users.json")
|
||||
ADMINS_FILE = str(DATA_DIR / "admins.json")
|
||||
"""用户管理器 - 使用数据库存储(PostgreSQL共享库)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化用户管理器"""
|
||||
# DATA_DIR 已在 config.py 中创建,无需重复创建
|
||||
# 添加文件锁保护并发读写
|
||||
self._users_lock = asyncio.Lock()
|
||||
self._admins_lock = asyncio.Lock()
|
||||
self._ensure_files_exist()
|
||||
pass
|
||||
|
||||
def _ensure_files_exist(self):
|
||||
"""确保必要的文件存在"""
|
||||
if not os.path.exists(self.USERS_FILE):
|
||||
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({}, f, ensure_ascii=False, indent=2)
|
||||
async def _get_session(self) -> AsyncSession:
|
||||
"""获取数据库会话 - 使用共享的PostgreSQL引擎"""
|
||||
from app.database import get_engine
|
||||
|
||||
if not os.path.exists(self.ADMINS_FILE):
|
||||
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({"admins": []}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _load_users_unsafe(self) -> Dict[str, dict]:
|
||||
"""加载用户数据(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.USERS_FILE, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(f"加载用户数据失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_users_unsafe(self, users: Dict[str, dict]):
|
||||
"""保存用户数据(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.USERS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(users, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存用户数据失败: {e}")
|
||||
|
||||
async def _load_users(self) -> Dict[str, dict]:
|
||||
"""加载用户数据(加锁)"""
|
||||
async with self._users_lock:
|
||||
return self._load_users_unsafe()
|
||||
|
||||
async def _save_users(self, users: Dict[str, dict]):
|
||||
"""保存用户数据(加锁)"""
|
||||
async with self._users_lock:
|
||||
self._save_users_unsafe(users)
|
||||
|
||||
def _load_admin_list_unsafe(self) -> List[str]:
|
||||
"""加载管理员列表(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.ADMINS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("admins", [])
|
||||
except Exception as e:
|
||||
print(f"加载管理员列表失败: {e}")
|
||||
return []
|
||||
|
||||
def _save_admin_list_unsafe(self, admin_list: List[str]):
|
||||
"""保存管理员列表(不加锁,内部使用)"""
|
||||
try:
|
||||
with open(self.ADMINS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump({"admins": admin_list}, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"保存管理员列表失败: {e}")
|
||||
|
||||
async def _load_admin_list(self) -> List[str]:
|
||||
"""加载管理员列表(加锁)"""
|
||||
async with self._admins_lock:
|
||||
return self._load_admin_list_unsafe()
|
||||
|
||||
async def _save_admin_list(self, admin_list: List[str]):
|
||||
"""保存管理员列表(加锁)"""
|
||||
async with self._admins_lock:
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
# 使用共享的PostgreSQL引擎(user_id使用特殊标识)
|
||||
engine = await get_engine("_global_users_")
|
||||
|
||||
session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
return session_maker()
|
||||
|
||||
async def create_or_update_from_linuxdo(
|
||||
self,
|
||||
@@ -111,106 +54,97 @@ class UserManager:
|
||||
trust_level: int
|
||||
) -> User:
|
||||
"""
|
||||
从 LinuxDO 用户信息创建或更新用户(线程安全)
|
||||
从 LinuxDO 用户信息创建或更新用户
|
||||
|
||||
Args:
|
||||
linuxdo_id: LinuxDO 用户 ID(本地用户时为 local_xxx 格式)
|
||||
username: 用户名
|
||||
display_name: 显示名称
|
||||
avatar_url: 头像 URL
|
||||
trust_level: 信任等级 (仅用于显示)
|
||||
trust_level: 信任等级
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
"""
|
||||
# 如果已经是 local_ 开头,直接使用;否则添加 linuxdo_ 前缀
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
# 生成 user_id
|
||||
if linuxdo_id.startswith("local_"):
|
||||
user_id = linuxdo_id
|
||||
else:
|
||||
user_id = f"linuxdo_{linuxdo_id}"
|
||||
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
async with await self._get_session() as session:
|
||||
# 查询用户是否存在
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
# 检查是否为初始管理员或本地用户
|
||||
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
|
||||
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
|
||||
is_local_user = user_id.startswith("local_")
|
||||
is_admin = is_initial_admin or is_local_user
|
||||
|
||||
if user:
|
||||
# 更新现有用户
|
||||
user.username = username
|
||||
user.display_name = display_name
|
||||
user.avatar_url = avatar_url
|
||||
user.trust_level = trust_level
|
||||
user.last_login = datetime.now()
|
||||
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
# 检查是否为初始管理员
|
||||
initial_admin_id = settings.INITIAL_ADMIN_LINUXDO_ID
|
||||
is_initial_admin = (initial_admin_id and linuxdo_id == initial_admin_id)
|
||||
|
||||
# 检查是否为本地用户(所有 local_ 开头的用户默认为管理员)
|
||||
is_local_user = user_id.startswith("local_")
|
||||
|
||||
if user_id in users:
|
||||
# 更新现有用户
|
||||
user_data = users[user_id]
|
||||
user_data["username"] = username
|
||||
user_data["display_name"] = display_name
|
||||
user_data["avatar_url"] = avatar_url
|
||||
user_data["trust_level"] = trust_level
|
||||
user_data["last_login"] = now
|
||||
|
||||
# 如果是初始管理员或本地用户且还不在管理员列表中,添加进去
|
||||
if (is_initial_admin or is_local_user) and user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
user_data["is_admin"] = True
|
||||
else:
|
||||
# 从管理员列表同步 is_admin 状态
|
||||
user_data["is_admin"] = user_id in admin_list
|
||||
else:
|
||||
# 创建新用户(本地用户默认为管理员)
|
||||
is_admin = is_initial_admin or is_local_user
|
||||
if is_admin and user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
|
||||
user_data = {
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"display_name": display_name,
|
||||
"avatar_url": avatar_url,
|
||||
"trust_level": trust_level,
|
||||
"is_admin": is_admin,
|
||||
"linuxdo_id": linuxdo_id,
|
||||
"created_at": now,
|
||||
"last_login": now
|
||||
}
|
||||
users[user_id] = user_data
|
||||
|
||||
self._save_users_unsafe(users)
|
||||
return User(**user_data)
|
||||
# 更新管理员状态
|
||||
if is_admin and not user.is_admin:
|
||||
user.is_admin = True
|
||||
else:
|
||||
# 创建新用户
|
||||
user = UserModel(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url,
|
||||
trust_level=trust_level,
|
||||
is_admin=is_admin,
|
||||
linuxdo_id=linuxdo_id,
|
||||
created_at=datetime.now(),
|
||||
last_login=datetime.now()
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return User(**user.to_dict())
|
||||
|
||||
async def get_user(self, user_id: str) -> Optional[User]:
|
||||
"""获取用户(线程安全)"""
|
||||
users = await self._load_users()
|
||||
user_data = users.get(user_id)
|
||||
if user_data:
|
||||
# 同步管理员状态
|
||||
admin_list = await self._load_admin_list()
|
||||
user_data["is_admin"] = user_id in admin_list
|
||||
return User(**user_data)
|
||||
return None
|
||||
"""获取用户"""
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
return User(**user.to_dict())
|
||||
return None
|
||||
|
||||
async def get_all_users(self) -> List[User]:
|
||||
"""获取所有用户(线程安全)"""
|
||||
users = await self._load_users()
|
||||
admin_list = await self._load_admin_list()
|
||||
"""获取所有用户"""
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
user_list = []
|
||||
for user_data in users.values():
|
||||
# 同步管理员状态
|
||||
user_data["is_admin"] = user_data["user_id"] in admin_list
|
||||
user_list.append(User(**user_data))
|
||||
|
||||
return user_list
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(select(UserModel))
|
||||
users = result.scalars().all()
|
||||
|
||||
return [User(**user.to_dict()) for user in users]
|
||||
|
||||
async def set_admin(self, user_id: str, is_admin: bool) -> bool:
|
||||
"""
|
||||
设置用户的管理员权限(线程安全)
|
||||
设置用户的管理员权限
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
@@ -219,38 +153,35 @@ class UserManager:
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
if user_id not in users:
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return False
|
||||
|
||||
if not is_admin:
|
||||
# 撤销管理员权限时,确保至少保留一个管理员
|
||||
admin_result = await session.execute(
|
||||
select(UserModel).where(UserModel.is_admin == True)
|
||||
)
|
||||
admin_count = len(admin_result.scalars().all())
|
||||
|
||||
if admin_count <= 1:
|
||||
return False
|
||||
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
|
||||
if is_admin:
|
||||
# 授予管理员权限
|
||||
if user_id not in admin_list:
|
||||
admin_list.append(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
else:
|
||||
# 撤销管理员权限
|
||||
if user_id in admin_list:
|
||||
# 确保至少保留一个管理员
|
||||
if len(admin_list) <= 1:
|
||||
return False
|
||||
admin_list.remove(user_id)
|
||||
self._save_admin_list_unsafe(admin_list)
|
||||
|
||||
# 更新用户数据中的 is_admin 字段
|
||||
users[user_id]["is_admin"] = is_admin
|
||||
self._save_users_unsafe(users)
|
||||
|
||||
return True
|
||||
|
||||
user.is_admin = is_admin
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def delete_user(self, user_id: str) -> bool:
|
||||
"""
|
||||
删除用户(线程安全)
|
||||
删除用户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
@@ -258,36 +189,37 @@ class UserManager:
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
# 使用锁保护整个读-改-写操作
|
||||
async with self._users_lock:
|
||||
async with self._admins_lock:
|
||||
users = self._load_users_unsafe()
|
||||
if user_id not in users:
|
||||
return False
|
||||
|
||||
# 不能删除管理员
|
||||
admin_list = self._load_admin_list_unsafe()
|
||||
if user_id in admin_list:
|
||||
return False
|
||||
|
||||
# 删除用户数据
|
||||
del users[user_id]
|
||||
self._save_users_unsafe(users)
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
# 删除用户数据库文件(在锁外执行,避免阻塞)
|
||||
db_file = str(DATA_DIR / f"ai_story_user_{user_id}.db")
|
||||
if os.path.exists(db_file):
|
||||
try:
|
||||
os.remove(db_file)
|
||||
except Exception as e:
|
||||
print(f"删除用户数据库文件失败: {e}")
|
||||
|
||||
return True
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return False
|
||||
|
||||
# 不能删除管理员
|
||||
if user.is_admin:
|
||||
return False
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def is_admin(self, user_id: str) -> bool:
|
||||
"""检查用户是否为管理员(线程安全)"""
|
||||
admin_list = await self._load_admin_list()
|
||||
return user_id in admin_list
|
||||
"""检查用户是否为管理员"""
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
return user.is_admin if user else False
|
||||
|
||||
|
||||
# 全局用户管理器实例
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
用户密码管理模块 - 使用数据库存储
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class UserPasswordManager:
|
||||
"""用户密码管理器 - 使用数据库存储(PostgreSQL共享库)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化密码管理器"""
|
||||
pass
|
||||
|
||||
async def _get_session(self) -> AsyncSession:
|
||||
"""获取数据库会话 - 使用共享的PostgreSQL引擎"""
|
||||
from app.database import get_engine
|
||||
|
||||
# 使用共享的PostgreSQL引擎(user_id使用特殊标识)
|
||||
engine = await get_engine("_global_users_")
|
||||
|
||||
session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
return session_maker()
|
||||
|
||||
def _hash_password(self, password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
|
||||
async def set_password(self, user_id: str, username: str, password: Optional[str] = None) -> str:
|
||||
"""
|
||||
设置用户密码
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
username: 用户名
|
||||
password: 密码,如果为None则使用默认密码(username+@666)
|
||||
|
||||
Returns:
|
||||
实际使用的密码(明文,仅用于首次设置时返回给用户)
|
||||
"""
|
||||
from app.models.user import UserPassword as UserPasswordModel
|
||||
|
||||
# 如果没有提供密码,使用默认密码
|
||||
actual_password = password if password else f"{username}@666"
|
||||
|
||||
async with await self._get_session() as session:
|
||||
# 查询密码记录是否存在
|
||||
result = await session.execute(
|
||||
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
|
||||
if pwd_record:
|
||||
# 更新现有密码
|
||||
pwd_record.username = username
|
||||
pwd_record.password_hash = self._hash_password(actual_password)
|
||||
pwd_record.has_custom_password = password is not None
|
||||
pwd_record.updated_at = datetime.now()
|
||||
else:
|
||||
# 创建新密码记录
|
||||
pwd_record = UserPasswordModel(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
password_hash=self._hash_password(actual_password),
|
||||
has_custom_password=password is not None,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
session.add(pwd_record)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return actual_password
|
||||
|
||||
async def verify_password(self, user_id: str, password: str) -> bool:
|
||||
"""
|
||||
验证用户密码
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
password: 待验证的密码
|
||||
|
||||
Returns:
|
||||
是否验证通过
|
||||
"""
|
||||
from app.models.user import UserPassword as UserPasswordModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
|
||||
if not pwd_record:
|
||||
return False
|
||||
|
||||
password_hash = self._hash_password(password)
|
||||
return pwd_record.password_hash == password_hash
|
||||
|
||||
async def has_password(self, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否已设置密码
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否已设置密码
|
||||
"""
|
||||
from app.models.user import UserPassword as UserPasswordModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
|
||||
return pwd_record is not None
|
||||
|
||||
async def has_custom_password(self, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否设置了自定义密码(非默认密码)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否使用自定义密码
|
||||
"""
|
||||
from app.models.user import UserPassword as UserPasswordModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
|
||||
if not pwd_record:
|
||||
return False
|
||||
|
||||
return pwd_record.has_custom_password
|
||||
|
||||
async def get_username(self, user_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取用户名
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户名,如果不存在返回None
|
||||
"""
|
||||
from app.models.user import UserPassword as UserPasswordModel
|
||||
|
||||
async with await self._get_session() as session:
|
||||
result = await session.execute(
|
||||
select(UserPasswordModel).where(UserPasswordModel.user_id == user_id)
|
||||
)
|
||||
pwd_record = result.scalar_one_or_none()
|
||||
|
||||
if not pwd_record:
|
||||
return None
|
||||
|
||||
return pwd_record.username
|
||||
|
||||
|
||||
# 全局密码管理器实例
|
||||
password_manager = UserPasswordManager()
|
||||
@@ -158,8 +158,18 @@ def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingRespon
|
||||
Returns:
|
||||
StreamingResponse对象
|
||||
"""
|
||||
async def wrapper():
|
||||
"""包装生成器以捕获StreamingResponse初始化时的GeneratorExit"""
|
||||
try:
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
except GeneratorExit:
|
||||
# StreamingResponse在初始化时会进行类型检查,导致GeneratorExit
|
||||
# 这是正常行为,不需要记录警告
|
||||
pass
|
||||
|
||||
return StreamingResponse(
|
||||
generator,
|
||||
wrapper(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
-- 创建章节重新生成任务表
|
||||
-- 用于支持根据AI分析建议重新生成章节内容的功能
|
||||
|
||||
-- 创建重新生成任务表
|
||||
CREATE TABLE IF NOT EXISTS regeneration_tasks (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
chapter_id VARCHAR(36) NOT NULL,
|
||||
analysis_id VARCHAR(36),
|
||||
user_id VARCHAR(100) NOT NULL,
|
||||
project_id VARCHAR(36) NOT NULL,
|
||||
|
||||
-- 修改指令
|
||||
modification_instructions TEXT NOT NULL,
|
||||
original_suggestions JSON,
|
||||
selected_suggestion_indices JSON,
|
||||
custom_instructions TEXT,
|
||||
|
||||
-- 生成配置
|
||||
style_id INTEGER,
|
||||
target_word_count INTEGER DEFAULT 3000,
|
||||
focus_areas JSON,
|
||||
preserve_elements JSON,
|
||||
|
||||
-- 任务状态
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
progress INTEGER DEFAULT 0,
|
||||
error_message TEXT,
|
||||
|
||||
-- 内容数据
|
||||
original_content TEXT,
|
||||
original_word_count INTEGER,
|
||||
regenerated_content TEXT,
|
||||
regenerated_word_count INTEGER,
|
||||
|
||||
-- 版本信息
|
||||
version_number INTEGER DEFAULT 1,
|
||||
version_note TEXT,
|
||||
|
||||
-- 时间戳
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at TIMESTAMP,
|
||||
completed_at TIMESTAMP,
|
||||
|
||||
-- 外键约束
|
||||
CONSTRAINT fk_regeneration_chapter FOREIGN KEY (chapter_id) REFERENCES chapters(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_regeneration_project FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_regeneration_analysis FOREIGN KEY (analysis_id) REFERENCES analysis_tasks(id) ON DELETE SET NULL,
|
||||
CONSTRAINT fk_regeneration_style FOREIGN KEY (style_id) REFERENCES writing_styles(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
-- 创建索引以提升查询性能
|
||||
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_chapter ON regeneration_tasks(chapter_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_project ON regeneration_tasks(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_user ON regeneration_tasks(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_status ON regeneration_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_regeneration_tasks_created ON regeneration_tasks(created_at DESC);
|
||||
|
||||
-- 添加注释
|
||||
COMMENT ON TABLE regeneration_tasks IS '章节重新生成任务表,记录每次根据AI建议重新生成章节的任务';
|
||||
|
||||
COMMENT ON COLUMN regeneration_tasks.modification_instructions IS '合并后的完整修改指令';
|
||||
COMMENT ON COLUMN regeneration_tasks.original_suggestions IS '原始AI分析建议列表';
|
||||
COMMENT ON COLUMN regeneration_tasks.selected_suggestion_indices IS '用户选择的建议索引';
|
||||
COMMENT ON COLUMN regeneration_tasks.preserve_elements IS '需要保留的元素配置(JSON)';
|
||||
COMMENT ON COLUMN regeneration_tasks.focus_areas IS '重点优化方向列表(JSON)';
|
||||
|
||||
-- 修复外键约束(合并自 fix_all_missing_columns.sql)
|
||||
-- 删除可能存在问题的外键约束
|
||||
ALTER TABLE regeneration_tasks
|
||||
DROP CONSTRAINT IF EXISTS fk_regeneration_analysis;
|
||||
|
||||
-- 完成提示
|
||||
SELECT '✅ 重新生成任务表创建完成,外键约束已修复' AS status;
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
用户数据迁移脚本 - 从JSON文件迁移到数据库
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from app.user_manager import user_manager
|
||||
from app.user_password import password_manager
|
||||
from app.config import DATA_DIR
|
||||
|
||||
|
||||
async def migrate_users():
|
||||
"""迁移用户数据"""
|
||||
users_file = DATA_DIR / "users.json"
|
||||
|
||||
if not users_file.exists():
|
||||
print("❌ 用户数据文件不存在,跳过迁移")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(users_file, "r", encoding="utf-8") as f:
|
||||
users_data = json.load(f)
|
||||
|
||||
if not users_data:
|
||||
print("ℹ️ 用户数据为空,跳过迁移")
|
||||
return 0
|
||||
|
||||
migrated_count = 0
|
||||
for user_id, user_info in users_data.items():
|
||||
try:
|
||||
# 迁移用户基本信息
|
||||
await user_manager.create_or_update_from_linuxdo(
|
||||
linuxdo_id=user_info["linuxdo_id"],
|
||||
username=user_info["username"],
|
||||
display_name=user_info["display_name"],
|
||||
avatar_url=user_info.get("avatar_url"),
|
||||
trust_level=user_info.get("trust_level", 0)
|
||||
)
|
||||
|
||||
# 如果用户是管理员,设置管理员权限
|
||||
if user_info.get("is_admin", False):
|
||||
await user_manager.set_admin(user_id, True)
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移用户 {user_id} 失败: {e}")
|
||||
|
||||
print(f"\n✅ 用户数据迁移完成: {migrated_count}/{len(users_data)} 个用户")
|
||||
|
||||
# 备份原文件
|
||||
backup_file = DATA_DIR / "users.json.backup"
|
||||
os.rename(users_file, backup_file)
|
||||
print(f"📦 原文件已备份到: {backup_file}")
|
||||
|
||||
return migrated_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移用户数据失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def migrate_passwords():
|
||||
"""迁移密码数据"""
|
||||
passwords_file = DATA_DIR / "user_passwords.json"
|
||||
|
||||
if not passwords_file.exists():
|
||||
print("❌ 密码数据文件不存在,跳过迁移")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(passwords_file, "r", encoding="utf-8") as f:
|
||||
passwords_data = json.load(f)
|
||||
|
||||
if not passwords_data:
|
||||
print("ℹ️ 密码数据为空,跳过迁移")
|
||||
return 0
|
||||
|
||||
migrated_count = 0
|
||||
for user_id, pwd_info in passwords_data.items():
|
||||
try:
|
||||
# 直接插入密码记录(已经是哈希值)
|
||||
from app.models.user import UserPassword
|
||||
from app.user_password import password_manager as pm
|
||||
|
||||
async with await pm._get_session() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
# 检查是否已存在
|
||||
result = await session.execute(
|
||||
select(UserPassword).where(UserPassword.user_id == user_id)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
print(f"ℹ️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
|
||||
continue
|
||||
|
||||
# 创建密码记录
|
||||
from datetime import datetime
|
||||
pwd_record = UserPassword(
|
||||
user_id=user_id,
|
||||
username=pwd_info["username"],
|
||||
password_hash=pwd_info["password_hash"],
|
||||
has_custom_password=pwd_info.get("has_custom_password", False),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
session.add(pwd_record)
|
||||
await session.commit()
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移密码 {user_id} 失败: {e}")
|
||||
|
||||
print(f"\n✅ 密码数据迁移完成: {migrated_count}/{len(passwords_data)} 个密码")
|
||||
|
||||
# 备份原文件
|
||||
backup_file = DATA_DIR / "user_passwords.json.backup"
|
||||
os.rename(passwords_file, backup_file)
|
||||
print(f"📦 原文件已备份到: {backup_file}")
|
||||
|
||||
return migrated_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移密码数据失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def migrate_admins():
|
||||
"""迁移管理员列表"""
|
||||
admins_file = DATA_DIR / "admins.json"
|
||||
|
||||
if not admins_file.exists():
|
||||
print("❌ 管理员数据文件不存在,跳过迁移")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(admins_file, "r", encoding="utf-8") as f:
|
||||
admins_data = json.load(f)
|
||||
|
||||
admin_list = admins_data.get("admins", [])
|
||||
|
||||
if not admin_list:
|
||||
print("ℹ️ 管理员列表为空,跳过迁移")
|
||||
return 0
|
||||
|
||||
migrated_count = 0
|
||||
for user_id in admin_list:
|
||||
try:
|
||||
# 设置管理员权限
|
||||
success = await user_manager.set_admin(user_id, True)
|
||||
if success:
|
||||
migrated_count += 1
|
||||
print(f"✅ 设置管理员: {user_id}")
|
||||
else:
|
||||
print(f"⚠️ 用户不存在或已是管理员: {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 设置管理员 {user_id} 失败: {e}")
|
||||
|
||||
print(f"\n✅ 管理员数据迁移完成: {migrated_count}/{len(admin_list)} 个管理员")
|
||||
|
||||
# 备份原文件
|
||||
backup_file = DATA_DIR / "admins.json.backup"
|
||||
os.rename(admins_file, backup_file)
|
||||
print(f"📦 原文件已备份到: {backup_file}")
|
||||
|
||||
return migrated_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移管理员数据失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
print("=" * 60)
|
||||
print("用户数据迁移工具 - JSON 到数据库")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# 迁移用户
|
||||
print("📋 步骤 1/3: 迁移用户数据")
|
||||
print("-" * 60)
|
||||
user_count = await migrate_users()
|
||||
print()
|
||||
|
||||
# 迁移密码
|
||||
print("📋 步骤 2/3: 迁移密码数据")
|
||||
print("-" * 60)
|
||||
pwd_count = await migrate_passwords()
|
||||
print()
|
||||
|
||||
# 迁移管理员
|
||||
print("📋 步骤 3/3: 迁移管理员数据")
|
||||
print("-" * 60)
|
||||
admin_count = await migrate_admins()
|
||||
print()
|
||||
|
||||
# 总结
|
||||
print("=" * 60)
|
||||
print("迁移完成")
|
||||
print("=" * 60)
|
||||
print(f"✅ 用户: {user_count}")
|
||||
print(f"✅ 密码: {pwd_count}")
|
||||
print(f"✅ 管理员: {admin_count}")
|
||||
print()
|
||||
print("💡 提示: 原文件已备份为 .backup 后缀")
|
||||
print("💡 如需回滚,请删除数据库文件并恢复 .backup 文件")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
用户数据迁移脚本 - 从JSON文件迁移到PostgreSQL数据库
|
||||
|
||||
使用方法:
|
||||
python migrate_users_to_postgres.py
|
||||
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from app.config import settings, DATA_DIR
|
||||
|
||||
|
||||
async def create_tables(engine):
|
||||
"""创建用户相关表"""
|
||||
from app.database import Base
|
||||
from app.models.user import User, UserPassword
|
||||
|
||||
print("📋 创建数据库表...")
|
||||
async with engine.begin() as conn:
|
||||
# 只创建用户相关的表
|
||||
await conn.run_sync(User.metadata.create_all)
|
||||
await conn.run_sync(UserPassword.metadata.create_all)
|
||||
print("✅ 表创建成功")
|
||||
|
||||
|
||||
async def migrate_users(session):
|
||||
"""迁移用户数据"""
|
||||
from app.models.user import User as UserModel
|
||||
|
||||
users_file = DATA_DIR / "users.json"
|
||||
|
||||
if not users_file.exists():
|
||||
print("ℹ️ 用户数据文件不存在,跳过迁移")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(users_file, "r", encoding="utf-8") as f:
|
||||
users_data = json.load(f)
|
||||
|
||||
if not users_data:
|
||||
print("ℹ️ 用户数据为空,跳过迁移")
|
||||
return 0
|
||||
|
||||
migrated_count = 0
|
||||
for user_id, user_info in users_data.items():
|
||||
try:
|
||||
# 检查用户是否已存在
|
||||
result = await session.execute(
|
||||
select(UserModel).where(UserModel.user_id == user_id)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
print(f"ℹ️ 用户已存在,跳过: {user_info['username']} ({user_id})")
|
||||
continue
|
||||
|
||||
# 创建用户记录
|
||||
user = UserModel(
|
||||
user_id=user_id,
|
||||
username=user_info["username"],
|
||||
display_name=user_info["display_name"],
|
||||
avatar_url=user_info.get("avatar_url"),
|
||||
trust_level=user_info.get("trust_level", 0),
|
||||
is_admin=user_info.get("is_admin", False),
|
||||
linuxdo_id=user_info["linuxdo_id"],
|
||||
created_at=datetime.fromisoformat(user_info.get("created_at", datetime.now().isoformat())),
|
||||
last_login=datetime.fromisoformat(user_info.get("last_login", datetime.now().isoformat()))
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✅ 迁移用户: {user_info['username']} ({user_id})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移用户 {user_id} 失败: {e}")
|
||||
|
||||
await session.commit()
|
||||
print(f"\n✅ 用户数据迁移完成: {migrated_count}/{len(users_data)} 个用户")
|
||||
|
||||
return migrated_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移用户数据失败: {e}")
|
||||
await session.rollback()
|
||||
return 0
|
||||
|
||||
|
||||
async def migrate_passwords(session):
|
||||
"""迁移密码数据"""
|
||||
from app.models.user import UserPassword
|
||||
|
||||
passwords_file = DATA_DIR / "user_passwords.json"
|
||||
|
||||
if not passwords_file.exists():
|
||||
print("ℹ️ 密码数据文件不存在,跳过迁移")
|
||||
return 0
|
||||
|
||||
try:
|
||||
with open(passwords_file, "r", encoding="utf-8") as f:
|
||||
passwords_data = json.load(f)
|
||||
|
||||
if not passwords_data:
|
||||
print("ℹ️ 密码数据为空,跳过迁移")
|
||||
return 0
|
||||
|
||||
migrated_count = 0
|
||||
for user_id, pwd_info in passwords_data.items():
|
||||
try:
|
||||
# 检查密码是否已存在
|
||||
result = await session.execute(
|
||||
select(UserPassword).where(UserPassword.user_id == user_id)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
print(f"ℹ️ 密码已存在,跳过: {pwd_info['username']} ({user_id})")
|
||||
continue
|
||||
|
||||
# 创建密码记录
|
||||
pwd_record = UserPassword(
|
||||
user_id=user_id,
|
||||
username=pwd_info["username"],
|
||||
password_hash=pwd_info["password_hash"],
|
||||
has_custom_password=pwd_info.get("has_custom_password", False),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now()
|
||||
)
|
||||
session.add(pwd_record)
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✅ 迁移密码: {pwd_info['username']} ({user_id})")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移密码 {user_id} 失败: {e}")
|
||||
|
||||
await session.commit()
|
||||
print(f"\n✅ 密码数据迁移完成: {migrated_count}/{len(passwords_data)} 个密码")
|
||||
|
||||
return migrated_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 迁移密码数据失败: {e}")
|
||||
await session.rollback()
|
||||
return 0
|
||||
|
||||
|
||||
async def backup_json_files():
|
||||
"""备份原始JSON文件"""
|
||||
files_to_backup = ["users.json", "user_passwords.json", "admins.json"]
|
||||
|
||||
print("\n📦 备份原始文件...")
|
||||
for filename in files_to_backup:
|
||||
source = DATA_DIR / filename
|
||||
if source.exists():
|
||||
backup = DATA_DIR / f"{filename}.backup.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
import shutil
|
||||
shutil.copy2(source, backup)
|
||||
print(f"✅ 备份: {filename} -> {backup.name}")
|
||||
|
||||
|
||||
async def main(db_url=None):
|
||||
"""主函数
|
||||
|
||||
Args:
|
||||
db_url: 可选的数据库URL,如果不提供则使用配置文件中的
|
||||
"""
|
||||
print("=" * 70)
|
||||
print("用户数据迁移工具 - JSON 到 PostgreSQL")
|
||||
print("=" * 70)
|
||||
print()
|
||||
|
||||
# 确定使用的数据库URL
|
||||
target_db_url = db_url if db_url else settings.database_url
|
||||
|
||||
# 检查数据库配置
|
||||
if "postgresql" not in target_db_url:
|
||||
print("❌ 错误: 未指定 PostgreSQL 数据库")
|
||||
if not db_url:
|
||||
print(f" 当前配置: {settings.database_url}")
|
||||
print(" 请使用 --db-url 参数指定PostgreSQL数据库,或在 .env 中配置 DATABASE_URL")
|
||||
else:
|
||||
print(f" 提供的URL: {target_db_url}")
|
||||
print()
|
||||
print("示例:")
|
||||
print(" python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname")
|
||||
return
|
||||
|
||||
# 隐藏密码部分显示
|
||||
display_url = target_db_url
|
||||
if '@' in display_url:
|
||||
parts = display_url.split('@')
|
||||
if ':' in parts[0]:
|
||||
user_part = parts[0].split(':')[0]
|
||||
display_url = f"{user_part}:****@{parts[1]}"
|
||||
|
||||
print(f"📊 目标数据库: {display_url}")
|
||||
print()
|
||||
|
||||
try:
|
||||
# 创建数据库引擎
|
||||
engine = create_async_engine(
|
||||
target_db_url,
|
||||
echo=False,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# 创建表
|
||||
await create_tables(engine)
|
||||
print()
|
||||
|
||||
# 创建会话
|
||||
async_session = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
# 迁移用户
|
||||
print("📋 步骤 1/2: 迁移用户数据")
|
||||
print("-" * 70)
|
||||
async with async_session() as session:
|
||||
user_count = await migrate_users(session)
|
||||
print()
|
||||
|
||||
# 迁移密码
|
||||
print("📋 步骤 2/2: 迁移密码数据")
|
||||
print("-" * 70)
|
||||
async with async_session() as session:
|
||||
pwd_count = await migrate_passwords(session)
|
||||
print()
|
||||
|
||||
# 备份原文件
|
||||
await backup_json_files()
|
||||
print()
|
||||
|
||||
# 总结
|
||||
print("=" * 70)
|
||||
print("迁移完成")
|
||||
print("=" * 70)
|
||||
print(f"✅ 用户: {user_count}")
|
||||
print(f"✅ 密码: {pwd_count}")
|
||||
print()
|
||||
print("💡 提示:")
|
||||
print(" - 原文件已备份(带时间戳)")
|
||||
print(" - 可以安全删除 users.json 和 user_passwords.json")
|
||||
print(" - 如需回滚,请从备份文件恢复")
|
||||
print()
|
||||
|
||||
# 关闭引擎
|
||||
await engine.dispose()
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 迁移过程出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 解析命令行参数
|
||||
parser = argparse.ArgumentParser(
|
||||
description="迁移用户数据从JSON到PostgreSQL数据库",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
示例:
|
||||
# 使用 .env 配置的数据库
|
||||
python migrate_users_to_postgres.py
|
||||
|
||||
# 指定数据库URL
|
||||
python migrate_users_to_postgres.py --db-url postgresql+asyncpg://user:pass@localhost/dbname
|
||||
|
||||
# 使用环境变量
|
||||
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/db python migrate_users_to_postgres.py
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--db-url",
|
||||
type=str,
|
||||
help="PostgreSQL数据库连接URL (格式: postgresql+asyncpg://user:password@host:port/database)",
|
||||
default=None
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 运行迁移
|
||||
asyncio.run(main(db_url=args.db_url))
|
||||
Reference in New Issue
Block a user