262 lines
9.9 KiB
Python
262 lines
9.9 KiB
Python
"""拆书导入 API"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from typing import AsyncGenerator
|
||
|
||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.database import get_db
|
||
from app.logger import get_logger
|
||
from app.schemas.book_import import (
|
||
BookImportApplyRequest,
|
||
BookImportApplyResponse,
|
||
BookImportPreviewResponse,
|
||
BookImportRetryRequest,
|
||
BookImportTaskCreateResponse,
|
||
BookImportTaskStatusResponse,
|
||
)
|
||
from app.services.book_import_service import book_import_service
|
||
from app.utils.sse_response import SSEResponse, create_sse_response
|
||
|
||
# Router for the book-import endpoints; every route below is prefixed with /book-import.
router = APIRouter(prefix="/book-import", tags=["拆书导入"])
logger = get_logger(__name__)

# Maximum accepted size of an uploaded TXT file, in bytes (50 MiB).
MAX_TXT_SIZE = 50 * 1024 ** 2
|
||
|
||
|
||
@router.post("/tasks", response_model=BookImportTaskCreateResponse, summary="创建拆书任务(上传TXT)")
async def create_book_import_task(
    request: Request,
    file: UploadFile = File(..., description="TXT 文件"),
    project_id: str | None = Form(default=None, description="兼容参数:当前版本固定新建项目,不支持传入"),
    create_new_project: bool = Form(default=True, description="兼容参数:当前版本仅支持 true"),
    import_mode: str = Form(default="append", description="导入模式:append/overwrite"),
):
    """Create a book import task from an uploaded ``.txt`` file.

    Only imports into a brand-new project are supported. ``project_id`` and
    ``create_new_project`` are still accepted for compatibility, but they must
    be left empty / true.

    Args:
        request: Incoming request; ``request.state.user_id`` identifies the caller.
        file: Uploaded file; its name must end in ``.txt`` (case-insensitive).
        project_id: Must be empty; any value is rejected with 400.
        create_new_project: Must be true; ``False`` is rejected with 400.
        import_mode: Either ``"append"`` or ``"overwrite"``.

    Returns:
        The value returned by ``book_import_service.create_task``, serialized
        as ``BookImportTaskCreateResponse``.

    Raises:
        HTTPException: 401 if there is no ``user_id`` on the request state.
            400 for a non-``.txt`` file, an unsupported ``import_mode``, or an
            attempt to target an existing project. 413 if the file exceeds
            ``MAX_TXT_SIZE`` bytes.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    if not file.filename or not file.filename.lower().endswith(".txt"):
        raise HTTPException(status_code=400, detail="仅支持 .txt 文件")

    if import_mode not in {"append", "overwrite"}:
        raise HTTPException(status_code=400, detail="import_mode 仅支持 append 或 overwrite")

    if project_id:
        raise HTTPException(status_code=400, detail="当前仅支持新建项目导入,不支持指定 project_id")
    if not create_new_project:
        raise HTTPException(status_code=400, detail="当前仅支持新建项目导入")

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without loading an arbitrarily large file into memory.
    content = await file.read(MAX_TXT_SIZE + 1)
    if len(content) > MAX_TXT_SIZE:
        raise HTTPException(status_code=413, detail="文件大小超过 50MB 限制")

    task = await book_import_service.create_task(
        user_id=user_id,
        filename=file.filename,
        file_content=content,
        # Compatibility parameters are validated above; always create a new project.
        project_id=None,
        create_new_project=True,
        import_mode=import_mode,
    )
    return task
|
||
|
||
|
||
@router.get("/tasks/{task_id}", response_model=BookImportTaskStatusResponse, summary="查询拆书任务状态")
async def get_book_import_task_status(task_id: str, request: Request):
    """Return the status of a book import task.

    The caller's ``user_id`` (from ``request.state``) is forwarded to the
    service together with ``task_id``.

    Raises:
        HTTPException: 401 when the request carries no ``user_id``.
    """
    current_user = getattr(request.state, "user_id", None)
    if current_user:
        return await book_import_service.get_task_status(task_id=task_id, user_id=current_user)
    raise HTTPException(status_code=401, detail="未登录")
|
||
|
||
|
||
@router.get("/tasks/{task_id}/preview", response_model=BookImportPreviewResponse, summary="获取拆书预览")
async def get_book_import_preview(task_id: str, request: Request):
    """Fetch the preview produced for a book import task.

    Raises:
        HTTPException: 401 when ``request.state`` has no ``user_id``.
    """
    if not (uid := getattr(request.state, "user_id", None)):
        raise HTTPException(status_code=401, detail="未登录")

    preview = await book_import_service.get_preview(task_id=task_id, user_id=uid)
    return preview
|
||
|
||
|
||
@router.post("/tasks/{task_id}/apply", response_model=BookImportApplyResponse, summary="确认并导入")
async def apply_book_import(
    task_id: str,
    payload: BookImportApplyRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Confirm and apply a book import task (non-streaming variant).

    Delegates to ``book_import_service.apply_import`` with the caller's
    ``user_id``, the request payload and the request-scoped DB session.

    Raises:
        HTTPException: 401 when ``request.state`` carries no ``user_id``.
    """
    caller_id = getattr(request.state, "user_id", None)
    if caller_id:
        service_kwargs = dict(task_id=task_id, user_id=caller_id, payload=payload, db=db)
        return await book_import_service.apply_import(**service_kwargs)
    raise HTTPException(status_code=401, detail="未登录")
|
||
|
||
|
||
@router.delete("/tasks/{task_id}", summary="取消拆书任务")
async def cancel_book_import_task(task_id: str, request: Request):
    """Cancel a book import task for the current user.

    The service's return value is returned as the response body.

    Raises:
        HTTPException: 401 when the request is unauthenticated.
    """
    state_user = getattr(request.state, "user_id", None)
    if not state_user:
        raise HTTPException(detail="未登录", status_code=401)

    outcome = await book_import_service.cancel_task(user_id=state_user, task_id=task_id)
    return outcome
|
||
|
||
|
||
@router.post("/tasks/{task_id}/apply-stream", summary="确认并导入(SSE流式进度)")
async def apply_book_import_stream(
    task_id: str,
    payload: BookImportApplyRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Apply a book import and stream progress to the client over SSE.

    After the base import, the service generates world-building, careers and
    characters step by step, and progress updates are pushed in real time.
    An ``asyncio.Queue`` carries messages from the service's progress
    callback (producer) to the SSE generator (consumer). ``None`` in the
    queue marks the end of the stream.

    On success the client sees these events, in order:

    1. an initial progress event;
    2. the service's progress events;
    3. a result event;
    4. a final 100% "success" progress event;
    5. the done event.

    On failure an error event is sent instead.

    If the stream ends early (for example, the client disconnects), the
    background import task is cancelled.

    Raises:
        HTTPException: 401 if ``request.state`` has no ``user_id``. This is
            raised before any streaming starts.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Producer/consumer channel between the import task and the SSE generator.
    progress_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _progress_callback(message: str, progress: int, status: str = "processing") -> None:
        """Format a progress update as an SSE message and enqueue it."""
        sse_msg = SSEResponse.format_sse({
            "type": "progress",
            "message": message,
            "progress": progress,
            "status": status,
        })
        await progress_queue.put(sse_msg)

    async def _run_import() -> None:
        """Run the import, enqueue the result or error events, then the end sentinel."""
        # NOTE(review): `db` comes from a yield dependency. Some FastAPI
        # versions run that dependency's exit code before a streaming
        # response body is sent. Confirm the session is still open while this
        # task uses it.
        try:
            result = await book_import_service.apply_import_stream(
                task_id=task_id,
                user_id=user_id,
                payload=payload,
                db=db,
                progress_callback=_progress_callback,
            )

            await progress_queue.put(await SSEResponse.send_result({
                "success": result.success,
                "project_id": result.project_id,
                "statistics": result.statistics,
            }))
            await progress_queue.put(await SSEResponse.send_progress("导入完成!", 100, "success"))
            await progress_queue.put(await SSEResponse.send_done())
        except HTTPException as exc:
            await progress_queue.put(await SSEResponse.send_error(exc.detail, exc.status_code))
        except Exception as exc:
            logger.error(f"拆书SSE导入失败: {exc}", exc_info=True)
            await progress_queue.put(await SSEResponse.send_error(str(exc), 500))
        finally:
            # Always unblock the consumer, including when this task is cancelled.
            await progress_queue.put(None)

    async def _streaming_generator() -> AsyncGenerator[str, None]:
        """Yield queued SSE messages until the end sentinel arrives."""
        yield await SSEResponse.send_progress("开始导入拆书数据...", 0, "processing")

        # Start the import only once the client is actually consuming the stream.
        import_task = asyncio.create_task(_run_import())

        try:
            while (msg := await progress_queue.get()) is not None:
                yield msg
        except Exception as exc:
            logger.error(f"SSE生成器异常: {exc}", exc_info=True)
            yield await SSEResponse.send_error(str(exc), 500)
        finally:
            # A client disconnect shows up here in one of two ways:
            # - GeneratorExit, if the generator is closed while suspended at `yield`;
            # - asyncio.CancelledError, if the response task is cancelled while
            #   awaiting the queue.
            # Neither is an `Exception` subclass, so only `finally` reliably
            # stops the background import in both cases.
            if not import_task.done():
                import_task.cancel()

    return create_sse_response(_streaming_generator())
|
||
|
||
|
||
@router.post("/tasks/{task_id}/retry-stream", summary="重试失败的生成步骤(SSE流式进度)")
async def retry_failed_steps_stream(
    task_id: str,
    payload: BookImportRetryRequest,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Re-run failed AI generation steps of a previous import, streaming progress over SSE.

    Only the world-building / career / character steps that failed earlier
    are retried; ``payload.steps`` is passed to the service as
    ``steps_to_retry``. Progress flows through an ``asyncio.Queue`` from the
    service callback to the SSE generator, and ``None`` marks the end of the
    stream.

    After the result event, a final 100% progress event is sent:

    - status "warning" if the result's ``still_failed`` entry is non-empty;
    - status "success" otherwise.

    The done event follows. On failure an error event is sent instead.

    If the stream ends early (for example, the client disconnects), the
    background retry task is cancelled.

    Raises:
        HTTPException: 401 if ``request.state`` has no ``user_id``. This is
            raised before any streaming starts.
    """
    user_id = getattr(request.state, "user_id", None)
    if not user_id:
        raise HTTPException(status_code=401, detail="未登录")

    # Producer/consumer channel between the retry task and the SSE generator.
    progress_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def _progress_callback(message: str, progress: int, status: str = "processing") -> None:
        """Format a progress update as an SSE message and enqueue it."""
        sse_msg = SSEResponse.format_sse({
            "type": "progress",
            "message": message,
            "progress": progress,
            "status": status,
        })
        await progress_queue.put(sse_msg)

    async def _run_retry() -> None:
        """Run the retry, enqueue the result or error events, then the end sentinel."""
        # NOTE(review): `db` comes from a yield dependency. Some FastAPI
        # versions run that dependency's exit code before a streaming
        # response body is sent. Confirm the session is still open while this
        # task uses it.
        try:
            result = await book_import_service.retry_failed_steps_stream(
                task_id=task_id,
                user_id=user_id,
                steps_to_retry=payload.steps,
                db=db,
                progress_callback=_progress_callback,
            )

            await progress_queue.put(await SSEResponse.send_result(result))

            if result.get("still_failed"):
                await progress_queue.put(await SSEResponse.send_progress(
                    f"重试完成,仍有 {len(result['still_failed'])} 个步骤失败",
                    100,
                    "warning",
                ))
            else:
                await progress_queue.put(await SSEResponse.send_progress("所有步骤重试成功!", 100, "success"))

            await progress_queue.put(await SSEResponse.send_done())
        except HTTPException as exc:
            await progress_queue.put(await SSEResponse.send_error(exc.detail, exc.status_code))
        except Exception as exc:
            logger.error(f"拆书SSE重试失败: {exc}", exc_info=True)
            await progress_queue.put(await SSEResponse.send_error(str(exc), 500))
        finally:
            # Always unblock the consumer, including when this task is cancelled.
            await progress_queue.put(None)

    async def _streaming_generator() -> AsyncGenerator[str, None]:
        """Yield queued SSE messages until the end sentinel arrives."""
        yield await SSEResponse.send_progress("开始重试失败的生成步骤...", 0, "processing")

        # Start the retry only once the client is actually consuming the stream.
        retry_task = asyncio.create_task(_run_retry())

        try:
            while (msg := await progress_queue.get()) is not None:
                yield msg
        except Exception as exc:
            logger.error(f"SSE重试生成器异常: {exc}", exc_info=True)
            yield await SSEResponse.send_error(str(exc), 500)
        finally:
            # A client disconnect shows up here in one of two ways:
            # - GeneratorExit, if the generator is closed while suspended at `yield`;
            # - asyncio.CancelledError, if the response task is cancelled while
            #   awaiting the queue.
            # Neither is an `Exception` subclass, so only `finally` reliably
            # stops the background retry in both cases.
            if not retry_task.done():
                retry_task.cancel()

    return create_sse_response(_streaming_generator())