feature: 新增小说封面图片生成功能
This commit is contained in:
@@ -0,0 +1,277 @@
|
||||
"""Grok 封面图片 Provider"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GrokCoverProvider(BaseCoverProvider):
|
||||
"""基于 xAI Grok Images API 的封面生成实现"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://api.x.ai/v1").rstrip("/")
|
||||
|
||||
async def generate_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> CoverGenerationResult:
|
||||
result = await self._request_cover(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
return self._to_public_result(result)
|
||||
|
||||
async def _request_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> dict[str, Any]:
|
||||
url = f"{self.base_url}/images/generations"
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"prompt": self._adapt_prompt(prompt=prompt, width=width, height=height),
|
||||
"n": 1,
|
||||
"response_format": "b64_json",
|
||||
"aspect_ratio": self._get_aspect_ratio(width=width, height=height),
|
||||
"resolution": self._get_resolution(width=width, height=height),
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"Grok 封面生成请求开始: url=%s model=%s width=%s height=%s prompt_len=%s prompt_preview=%s",
|
||||
url,
|
||||
model,
|
||||
width,
|
||||
height,
|
||||
len(prompt or ""),
|
||||
(prompt or "")[:300].replace("\n", " "),
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
|
||||
logger.debug(
|
||||
"Grok 封面生成响应: status=%s content_type=%s headers=%s body_preview=%s",
|
||||
response.status_code,
|
||||
response.headers.get("content-type"),
|
||||
{
|
||||
"x-request-id": response.headers.get("x-request-id"),
|
||||
"cf-ray": response.headers.get("cf-ray"),
|
||||
"openai-processing-ms": response.headers.get("openai-processing-ms"),
|
||||
},
|
||||
response.text[:1000],
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
"Grok 封面生成 HTTP 错误: status=%s response=%s",
|
||||
exc.response.status_code if exc.response else None,
|
||||
exc.response.text[:2000] if exc.response is not None else None,
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
logger.error("Grok 封面生成请求异常", exc_info=True)
|
||||
raise
|
||||
|
||||
images = data.get("data") or []
|
||||
logger.debug(
|
||||
"Grok 封面生成解析结果: has_data=%s image_count=%s keys=%s",
|
||||
bool(data),
|
||||
len(images),
|
||||
list(data.keys()) if isinstance(data, dict) else type(data).__name__,
|
||||
)
|
||||
|
||||
if not images:
|
||||
logger.error("Grok 未返回图片结果: data=%s", data)
|
||||
raise ValueError("Grok 未返回图片结果")
|
||||
|
||||
image_item = images[0]
|
||||
revised_prompt = image_item.get("revised_prompt")
|
||||
logger.debug(
|
||||
"Grok 首张图片结果: keys=%s has_b64=%s has_url=%s revised_prompt_preview=%s",
|
||||
list(image_item.keys()),
|
||||
bool(image_item.get("b64_json")),
|
||||
bool(image_item.get("url")),
|
||||
(revised_prompt or "")[:300],
|
||||
)
|
||||
|
||||
b64_json = image_item.get("b64_json")
|
||||
if b64_json:
|
||||
decoded_content = self._decode_base64_image(b64_json)
|
||||
image_width, image_height = self._detect_image_size(decoded_content)
|
||||
logger.debug(
|
||||
"Grok 返回 base64 图片: bytes=%s mime=image/jpeg size=%sx%s",
|
||||
len(decoded_content),
|
||||
image_width,
|
||||
image_height,
|
||||
)
|
||||
return {
|
||||
"content": decoded_content,
|
||||
"mime_type": "image/jpeg",
|
||||
"file_extension": "jpg",
|
||||
"revised_prompt": revised_prompt,
|
||||
"provider": "grok",
|
||||
"model": model,
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
}
|
||||
|
||||
image_url = image_item.get("url")
|
||||
if image_url:
|
||||
logger.debug("Grok 返回图片 URL,开始下载: %s", image_url)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
image_response = await client.get(image_url)
|
||||
|
||||
logger.debug(
|
||||
"Grok 图片下载响应: status=%s content_type=%s content_length=%s",
|
||||
image_response.status_code,
|
||||
image_response.headers.get("content-type"),
|
||||
image_response.headers.get("content-length"),
|
||||
)
|
||||
image_response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
"Grok 图片下载 HTTP 错误: status=%s response=%s",
|
||||
exc.response.status_code if exc.response else None,
|
||||
exc.response.text[:2000] if exc.response is not None else None,
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
logger.error("Grok 图片下载异常", exc_info=True)
|
||||
raise
|
||||
|
||||
content_type = image_response.headers.get("content-type", "image/jpeg")
|
||||
file_extension = self._guess_extension(content_type=content_type, image_url=image_url)
|
||||
image_width, image_height = self._detect_image_size(image_response.content)
|
||||
logger.debug(
|
||||
"Grok 图片下载完成: bytes=%s extension=%s size=%sx%s",
|
||||
len(image_response.content),
|
||||
file_extension,
|
||||
image_width,
|
||||
image_height,
|
||||
)
|
||||
return {
|
||||
"content": image_response.content,
|
||||
"mime_type": content_type,
|
||||
"file_extension": file_extension,
|
||||
"revised_prompt": revised_prompt,
|
||||
"provider": "grok",
|
||||
"model": model,
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
}
|
||||
|
||||
logger.error("Grok 返回内容中既没有 b64_json,也没有 url: %s", data)
|
||||
raise ValueError("Grok 未返回可用的图片数据")
|
||||
|
||||
@staticmethod
|
||||
def _to_public_result(result: dict[str, Any]) -> CoverGenerationResult:
|
||||
return {
|
||||
"content": result["content"],
|
||||
"mime_type": result["mime_type"],
|
||||
"file_extension": result["file_extension"],
|
||||
"revised_prompt": result.get("revised_prompt"),
|
||||
"provider": result["provider"],
|
||||
"model": result["model"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _detect_image_size(content: bytes) -> tuple[int, int]:
|
||||
if len(content) >= 24 and content[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
width, height = struct.unpack(">II", content[16:24])
|
||||
return int(width), int(height)
|
||||
|
||||
if len(content) >= 2 and content[:2] == b"\xff\xd8":
|
||||
index = 2
|
||||
content_length = len(content)
|
||||
while index < content_length - 1:
|
||||
if content[index] != 0xFF:
|
||||
index += 1
|
||||
continue
|
||||
marker = content[index + 1]
|
||||
index += 2
|
||||
if marker in (0xD8, 0xD9):
|
||||
continue
|
||||
if index + 2 > content_length:
|
||||
break
|
||||
segment_length = struct.unpack(">H", content[index:index + 2])[0]
|
||||
if segment_length < 2 or index + segment_length > content_length:
|
||||
break
|
||||
if marker in {
|
||||
0xC0, 0xC1, 0xC2, 0xC3,
|
||||
0xC5, 0xC6, 0xC7,
|
||||
0xC9, 0xCA, 0xCB,
|
||||
0xCD, 0xCE, 0xCF,
|
||||
}:
|
||||
if index + 7 <= content_length:
|
||||
height, width = struct.unpack(">HH", content[index + 3:index + 7])
|
||||
return int(width), int(height)
|
||||
break
|
||||
index += segment_length
|
||||
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: str) -> bytes:
|
||||
if value.startswith("data:") and "," in value:
|
||||
value = value.split(",", 1)[1]
|
||||
return base64.b64decode(value)
|
||||
|
||||
@staticmethod
|
||||
def _adapt_prompt(*, prompt: str, width: int, height: int) -> str:
|
||||
cleaned_prompt = " ".join((prompt or "").split())
|
||||
return (
|
||||
f"{cleaned_prompt} "
|
||||
f"Use a {width}x{height} vertical composition."
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_aspect_ratio(*, width: int, height: int) -> str:
|
||||
if width <= 0 or height <= 0:
|
||||
return "2:3"
|
||||
if width * 3 == height * 2:
|
||||
return "2:3"
|
||||
return f"{width}:{height}"
|
||||
|
||||
@staticmethod
|
||||
def _get_resolution(*, width: int, height: int) -> str:
|
||||
longest_edge = max(width, height)
|
||||
if longest_edge >= 1536:
|
||||
return "2k"
|
||||
return "1k"
|
||||
|
||||
@staticmethod
|
||||
def _guess_extension(*, content_type: str, image_url: str) -> str:
|
||||
lowered_content_type = (content_type or "").lower()
|
||||
lowered_url = (image_url or "").lower()
|
||||
if "png" in lowered_content_type or lowered_url.endswith(".png"):
|
||||
return "png"
|
||||
if "webp" in lowered_content_type or lowered_url.endswith(".webp"):
|
||||
return "webp"
|
||||
return "jpg"
|
||||
Reference in New Issue
Block a user