Files
MuMuAINovel/backend/app/utils/sse_response.py
T

179 lines
4.6 KiB
Python

"""Server-Sent Events (SSE) 响应工具类"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional
from fastapi.responses import StreamingResponse
from app.logger import get_logger
logger = get_logger(__name__)
class SSEResponse:
"""SSE响应构建器"""
@staticmethod
def format_sse(data: Dict[str, Any], event: Optional[str] = None) -> str:
"""
格式化SSE消息
Args:
data: 要发送的数据字典
event: 事件类型(可选)
Returns:
格式化后的SSE消息字符串
"""
message = ""
if event:
message += f"event: {event}\n"
message += f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
return message
@staticmethod
async def send_progress(
message: str,
progress: int,
status: str = "processing"
) -> str:
"""
发送进度消息
Args:
message: 进度消息
progress: 进度百分比(0-100)
status: 状态(processing/success/error)
"""
return SSEResponse.format_sse({
"type": "progress",
"message": message,
"progress": progress,
"status": status
})
@staticmethod
async def send_chunk(content: str) -> str:
"""
发送内容块(用于流式输出AI生成内容)
Args:
content: 内容块
"""
return SSEResponse.format_sse({
"type": "chunk",
"content": content
})
@staticmethod
async def send_result(data: Dict[str, Any]) -> str:
"""
发送最终结果
Args:
data: 结果数据
"""
return SSEResponse.format_sse({
"type": "result",
"data": data
})
@staticmethod
async def send_error(error: str, code: int = 500) -> str:
"""
发送错误消息
Args:
error: 错误描述
code: 错误码
"""
return SSEResponse.format_sse({
"type": "error",
"error": error,
"code": code
})
@staticmethod
async def send_done() -> str:
"""发送完成消息"""
return SSEResponse.format_sse({
"type": "done"
})
@staticmethod
async def send_heartbeat() -> str:
"""发送心跳消息(保持连接活跃)"""
return ": heartbeat\n\n"
async def create_sse_generator(
async_gen: AsyncGenerator[str, None],
show_progress: bool = True
) -> AsyncGenerator[str, None]:
"""
创建SSE生成器包装器
Args:
async_gen: 异步生成器
show_progress: 是否显示进度
Yields:
格式化的SSE消息
"""
try:
if show_progress:
yield await SSEResponse.send_progress("开始生成...", 0)
# 累积内容用于进度计算
accumulated_content = ""
chunk_count = 0
async for chunk in async_gen:
chunk_count += 1
accumulated_content += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 每10个块发送一次心跳
if chunk_count % 10 == 0:
yield await SSEResponse.send_heartbeat()
if show_progress:
yield await SSEResponse.send_progress("生成完成", 100, "success")
# 发送完成信号
yield await SSEResponse.send_done()
except Exception as e:
logger.error(f"SSE生成器错误: {str(e)}")
yield await SSEResponse.send_error(str(e))
def create_sse_response(generator: AsyncGenerator[str, None]) -> StreamingResponse:
"""
创建SSE StreamingResponse
Args:
generator: SSE消息生成器
Returns:
StreamingResponse对象
"""
async def wrapper():
"""包装生成器以捕获StreamingResponse初始化时的GeneratorExit"""
try:
async for chunk in generator:
yield chunk
except GeneratorExit:
# StreamingResponse在初始化时会进行类型检查,导致GeneratorExit
# 这是正常行为,不需要记录警告
pass
return StreamingResponse(
wrapper(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # 禁用nginx缓冲
}
)