Files
MuMuAINovel/backend/app/api/book_import.py
T

282 lines
11 KiB
Python
Raw Normal View History

"""拆书导入 API"""
from __future__ import annotations
import asyncio
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.logger import get_logger
from app.schemas.book_import import (
BookImportApplyRequest,
BookImportApplyResponse,
BookImportPreviewResponse,
BookImportRetryRequest,
BookImportTaskCreateResponse,
BookImportTaskCreateRequest,
BookImportTaskStatusResponse,
)
from app.services.book_import_service import book_import_service
from app.utils.sse_response import SSEResponse, create_sse_response
# Router that groups every endpoint of this module under the "/book-import" URL prefix.
router = APIRouter(prefix="/book-import", tags=["拆书导入"])
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Upper bound, in bytes, for an uploaded TXT file; compared against the length of the uploaded content.
MAX_TXT_SIZE = 50 * 1024 * 1024 # 50MB
@router.post("/tasks", response_model=BookImportTaskCreateResponse, summary="创建拆书任务(上传TXT)")
async def create_book_import_task(
    request: Request,
    file: UploadFile = File(..., description="TXT 文件"),
    project_id: str | None = Form(default=None, description="兼容参数:当前版本固定新建项目,不支持传入"),
    create_new_project: bool = Form(default=True, description="兼容参数:当前版本仅支持 true"),
    import_mode: str = Form(default="append", description="导入模式:append/overwrite"),
    extract_mode: str = Form(default="tail", description="解析范围:tail=截取末章,full=整本"),
    tail_chapter_count: int = Form(default=10, description="当 extract_mode=tail 时,截取末尾章节数,需为5的倍数;超过50按整本拆处理"),
):
    """Create a book-import task from an uploaded TXT file.

    Args:
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        file: Uploaded file; its name must end with ``.txt`` (case-insensitive).
        project_id: Compatibility parameter; any truthy value is rejected.
        create_new_project: Compatibility parameter; only ``True`` is accepted.
        import_mode: Either ``"append"`` or ``"overwrite"``.
        extract_mode: Either ``"tail"`` or ``"full"``.
        tail_chapter_count: Must be at least 5 and a multiple of 5. A value
            above 50 switches ``extract_mode`` to ``"full"``.

    Returns:
        Whatever ``book_import_service.create_task`` returns, serialized
        through ``BookImportTaskCreateResponse``.

    Raises:
        HTTPException: 401 when no user id is attached to the request,
            400 for any invalid form parameter, 413 when the file is larger
            than ``MAX_TXT_SIZE``.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    if not file.filename or not file.filename.lower().endswith(".txt"):
        raise HTTPException(status_code=400, detail="仅支持 .txt 文件")
    if import_mode not in {"append", "overwrite"}:
        raise HTTPException(status_code=400, detail="import_mode 仅支持 append 或 overwrite")
    if extract_mode not in {"tail", "full"}:
        raise HTTPException(status_code=400, detail="extract_mode 仅支持 tail 或 full")
    if tail_chapter_count < 5:
        raise HTTPException(status_code=400, detail="tail_chapter_count 不能小于 5")
    if tail_chapter_count % 5 != 0:
        raise HTTPException(status_code=400, detail="tail_chapter_count 必须是 5 的倍数")
    if tail_chapter_count > 50:
        # More than 50 tail chapters is handled as a whole-book extraction.
        extract_mode = "full"
    if project_id:
        raise HTTPException(status_code=400, detail="当前仅支持新建项目导入,不支持指定 project_id")
    if not create_new_project:
        raise HTTPException(status_code=400, detail="当前仅支持新建项目导入")

    # Route the extraction options through the request schema before use.
    create_payload = BookImportTaskCreateRequest(
        extract_mode=extract_mode,
        tail_chapter_count=tail_chapter_count,
    )

    # Read at most one byte beyond the limit: enough to detect an oversized
    # upload without loading an arbitrarily large file into memory.
    content = await file.read(MAX_TXT_SIZE + 1)
    if len(content) > MAX_TXT_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 限制")

    # The project arguments are pinned: only "create a new project" is supported.
    task = await book_import_service.create_task(
        user_id=user_id,
        filename=file.filename,
        file_content=content,
        project_id=None,
        create_new_project=True,
        import_mode=import_mode,
        extract_mode=create_payload.extract_mode,
        tail_chapter_count=create_payload.tail_chapter_count,
    )
    return task
@router.get("/tasks/{task_id}", response_model=BookImportTaskStatusResponse, summary="查询拆书任务状态")
async def get_book_import_task_status(task_id: str, request: Request):
    """Return the status of a book-import task for the calling user.

    Responds with 401 when the request carries no user id; otherwise the
    lookup is delegated to ``book_import_service.get_task_status``.
    """
    caller_id = getattr(request.state, "user_id", None)
    if caller_id:
        return await book_import_service.get_task_status(task_id=task_id, user_id=caller_id)
    raise HTTPException(status_code=401, detail="未登录")
@router.get("/tasks/{task_id}/preview", response_model=BookImportPreviewResponse, summary="获取拆书预览")
async def get_book_import_preview(task_id: str, request: Request):
    """Return the preview of a book-import task for the calling user.

    Responds with 401 when the request carries no user id; otherwise the
    work is delegated to ``book_import_service.get_preview``.
    """
    caller_id = getattr(request.state, "user_id", None)
    if not caller_id:
        raise HTTPException(status_code=401, detail="未登录")

    lookup = {"task_id": task_id, "user_id": caller_id}
    return await book_import_service.get_preview(**lookup)
@router.post("/tasks/{task_id}/apply", response_model=BookImportApplyResponse, summary="确认并导入")
async def apply_book_import(
    task_id: str,
    payload: BookImportApplyRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Confirm a book-import task and apply it.

    Responds with 401 when the request carries no user id; otherwise
    returns the result of ``book_import_service.apply_import``.
    """
    caller_id = getattr(request.state, "user_id", None)
    if caller_id:
        outcome = await book_import_service.apply_import(
            task_id=task_id,
            user_id=caller_id,
            payload=payload,
            db=db,
        )
        return outcome
    raise HTTPException(status_code=401, detail="未登录")
@router.delete("/tasks/{task_id}", summary="取消拆书任务")
async def cancel_book_import_task(task_id: str, request: Request):
    """Cancel a book-import task on behalf of the calling user.

    Responds with 401 when the request carries no user id; otherwise
    returns the result of ``book_import_service.cancel_task``.
    """
    caller_id = getattr(request.state, "user_id", None)
    if not caller_id:
        raise HTTPException(status_code=401, detail="未登录")

    cancellation = await book_import_service.cancel_task(task_id=task_id, user_id=caller_id)
    return cancellation
@router.post("/tasks/{task_id}/apply-stream", summary="确认并导入(SSE流式进度)")
async def apply_book_import_stream(
    task_id: str,
    payload: BookImportApplyRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Confirm and apply a book import, streaming progress over SSE.

    After the basic import, the service generates the worldview, careers and
    characters step by step while progress is pushed to the client in real
    time. An ``asyncio.Queue`` carries the messages from the service to the
    SSE generator.

    Raises:
        HTTPException: 401 when the request carries no user id. Failures
            inside the import itself are reported as SSE error events
            rather than raised.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Hand-off between the background import task and the SSE generator.
    # ``None`` is the end-of-stream sentinel.
    progress_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _progress_callback(message: str, progress: int, status: str = "processing") -> None:
        """Format a progress event and enqueue it for the SSE generator."""
        sse_msg = SSEResponse.format_sse({
            "type": "progress",
            "message": message,
            "progress": progress,
            "status": status,
        })
        await progress_queue.put(sse_msg)

    async def _run_import() -> None:
        """Run the import in a background task, reporting through the queue."""
        try:
            result = await book_import_service.apply_import_stream(
                task_id=task_id,
                user_id=user_id,
                payload=payload,
                db=db,
                progress_callback=_progress_callback,
            )
            # Result first, then the final progress event, then the done marker.
            await progress_queue.put(await SSEResponse.send_result({
                "success": result.success,
                "project_id": result.project_id,
                "statistics": result.statistics,
            }))
            await progress_queue.put(await SSEResponse.send_progress("导入完成!", 100, "success"))
            await progress_queue.put(await SSEResponse.send_done())
        except HTTPException as exc:
            await progress_queue.put(await SSEResponse.send_error(exc.detail, exc.status_code))
        except Exception as exc:
            logger.error(f"拆书SSE导入失败: {exc}", exc_info=True)
            await progress_queue.put(await SSEResponse.send_error(str(exc), 500))
        finally:
            # Always release the consumer, including when this task is cancelled.
            await progress_queue.put(None)

    async def _streaming_generator() -> AsyncGenerator[str, None]:
        """Yield SSE messages until the background task signals completion."""
        yield await SSEResponse.send_progress("开始导入拆书数据...", 0, "processing")
        import_task = asyncio.create_task(_run_import())
        try:
            while True:
                msg = await progress_queue.get()
                if msg is None:
                    break
                yield msg
        except Exception as exc:
            logger.error(f"SSE生成器异常: {exc}", exc_info=True)
            yield await SSEResponse.send_error(str(exc), 500)
        finally:
            # A client disconnect can surface here as GeneratorExit (generator
            # closed) or as asyncio.CancelledError (streaming task cancelled).
            # Both are BaseException subclasses, so the cleanup lives in
            # ``finally`` to stop the background import on every exit path.
            if not import_task.done():
                import_task.cancel()

    return create_sse_response(_streaming_generator())
@router.post("/tasks/{task_id}/retry-stream", summary="重试失败的生成步骤(SSE流式进度)")
async def retry_failed_steps_stream(
    task_id: str,
    payload: BookImportRetryRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Retry failed AI generation steps, streaming progress over SSE.

    Only the generation steps (worldview / careers / characters) that failed
    during a previous import are retried; the steps come from
    ``payload.steps``.

    Raises:
        HTTPException: 401 when the request carries no user id. Failures
            inside the retry itself are reported as SSE error events
            rather than raised.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Hand-off between the background retry task and the SSE generator.
    # ``None`` is the end-of-stream sentinel.
    progress_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _progress_callback(message: str, progress: int, status: str = "processing") -> None:
        """Format a progress event and enqueue it for the SSE generator."""
        sse_msg = SSEResponse.format_sse({
            "type": "progress",
            "message": message,
            "progress": progress,
            "status": status,
        })
        await progress_queue.put(sse_msg)

    async def _run_retry() -> None:
        """Run the retry in a background task, reporting through the queue."""
        try:
            result = await book_import_service.retry_failed_steps_stream(
                task_id=task_id,
                user_id=user_id,
                steps_to_retry=payload.steps,
                db=db,
                progress_callback=_progress_callback,
            )
            await progress_queue.put(await SSEResponse.send_result(result))
            # The final progress event depends on whether any step still failed.
            if result.get("still_failed"):
                await progress_queue.put(await SSEResponse.send_progress(
                    f"重试完成,仍有 {len(result['still_failed'])} 个步骤失败",
                    100,
                    "warning",
                ))
            else:
                await progress_queue.put(await SSEResponse.send_progress("所有步骤重试成功!", 100, "success"))
            await progress_queue.put(await SSEResponse.send_done())
        except HTTPException as exc:
            await progress_queue.put(await SSEResponse.send_error(exc.detail, exc.status_code))
        except Exception as exc:
            logger.error(f"拆书SSE重试失败: {exc}", exc_info=True)
            await progress_queue.put(await SSEResponse.send_error(str(exc), 500))
        finally:
            # Always release the consumer, including when this task is cancelled.
            await progress_queue.put(None)

    async def _streaming_generator() -> AsyncGenerator[str, None]:
        """Yield SSE messages until the background task signals completion."""
        yield await SSEResponse.send_progress("开始重试失败的生成步骤...", 0, "processing")
        retry_task = asyncio.create_task(_run_retry())
        try:
            while True:
                msg = await progress_queue.get()
                if msg is None:
                    break
                yield msg
        except Exception as exc:
            logger.error(f"SSE重试生成器异常: {exc}", exc_info=True)
            yield await SSEResponse.send_error(str(exc), 500)
        finally:
            # A client disconnect can surface here as GeneratorExit (generator
            # closed) or as asyncio.CancelledError (streaming task cancelled).
            # Both are BaseException subclasses, so the cleanup lives in
            # ``finally`` to stop the background retry on every exit path.
            if not retry_task.done():
                retry_task.cancel()

    return create_sse_response(_streaming_generator())