Files
MuMuAINovel/backend/app/services/cover_providers/grok_cover_provider.py
T
2026-03-16 11:34:07 +08:00

278 lines
10 KiB
Python

"""Grok 封面图片 Provider"""
from __future__ import annotations
import base64
import math
import struct
from typing import Any
from urllib.parse import urlsplit

import httpx

from app.logger import get_logger
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
class GrokCoverProvider(BaseCoverProvider):
    """Cover generation provider backed by the xAI Grok Images API.

    A single image is requested from ``{base_url}/images/generations``. The
    image bytes are taken from the inline base64 payload when present,
    otherwise they are downloaded from the URL contained in the response.
    """

    # Timeout (seconds) used for both the generation request and the
    # follow-up image download.
    _REQUEST_TIMEOUT_SECONDS = 180.0

    # (mime type, file extension) reported for an inline base64 image whose
    # bytes carry no recognizable signature.
    _DEFAULT_IMAGE_FORMAT = ("image/jpeg", "jpg")

    def __init__(self, api_key: str, base_url: str):
        """Store the credentials and the normalized API base URL.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; an empty value falls back to
                ``https://api.x.ai/v1``. Trailing slashes are removed.
        """
        self.api_key = api_key
        self.base_url = (base_url or "https://api.x.ai/v1").rstrip("/")

    async def generate_cover(
        self,
        *,
        prompt: str,
        model: str,
        width: int,
        height: int,
    ) -> CoverGenerationResult:
        """Generate one cover image and return the public result mapping.

        Args:
            prompt: Text description of the cover.
            model: Model identifier forwarded to the API.
            width: Requested width in pixels.
            height: Requested height in pixels.

        Raises:
            httpx.HTTPStatusError: The API or the image host answered with an
                error status.
            ValueError: The response contained no usable image.
        """
        result = await self._request_cover(
            prompt=prompt,
            model=model,
            width=width,
            height=height,
        )
        return self._to_public_result(result)

    async def _request_cover(
        self,
        *,
        prompt: str,
        model: str,
        width: int,
        height: int,
    ) -> dict[str, Any]:
        """Call the generation endpoint and extract the first image.

        Returns:
            A dict with the keys ``content``, ``mime_type``,
            ``file_extension``, ``revised_prompt``, ``provider``, ``model``,
            ``image_width`` and ``image_height``.

        Raises:
            httpx.HTTPStatusError: Non-success status from the API or from the
                image download.
            ValueError: The response has no image entry, or the first entry
                has neither ``b64_json`` nor ``url``.
        """
        url = f"{self.base_url}/images/generations"
        payload: dict[str, Any] = {
            "model": model,
            "prompt": self._adapt_prompt(prompt=prompt, width=width, height=height),
            "n": 1,
            "response_format": "b64_json",
            "aspect_ratio": self._get_aspect_ratio(width=width, height=height),
            "resolution": self._get_resolution(width=width, height=height),
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        logger.debug(
            "Grok 封面生成请求开始: url=%s model=%s width=%s height=%s prompt_len=%s prompt_preview=%s",
            url,
            model,
            width,
            height,
            len(prompt or ""),
            (prompt or "")[:300].replace("\n", " "),
        )

        try:
            async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT_SECONDS) as client:
                response = await client.post(url, headers=headers, json=payload)
            logger.debug(
                "Grok 封面生成响应: status=%s content_type=%s headers=%s body_preview=%s",
                response.status_code,
                response.headers.get("content-type"),
                {
                    "x-request-id": response.headers.get("x-request-id"),
                    "cf-ray": response.headers.get("cf-ray"),
                    "openai-processing-ms": response.headers.get("openai-processing-ms"),
                },
                response.text[:1000],
            )
            response.raise_for_status()
            data = response.json()
        except httpx.HTTPStatusError as exc:
            failed = exc.response
            logger.error(
                "Grok 封面生成 HTTP 错误: status=%s response=%s",
                failed.status_code if failed is not None else None,
                failed.text[:2000] if failed is not None else None,
            )
            raise
        except Exception:
            logger.error("Grok 封面生成请求异常", exc_info=True)
            raise

        # A JSON body that is not an object (or whose "data" is not a list) is
        # treated like an empty result instead of failing with AttributeError.
        raw_images = data.get("data") if isinstance(data, dict) else None
        images = raw_images if isinstance(raw_images, list) else []
        logger.debug(
            "Grok 封面生成解析结果: has_data=%s image_count=%s keys=%s",
            bool(data),
            len(images),
            list(data.keys()) if isinstance(data, dict) else type(data).__name__,
        )
        if not images:
            logger.error("Grok 未返回图片结果: data=%s", data)
            raise ValueError("Grok 未返回图片结果")

        image_item = images[0]
        if not isinstance(image_item, dict):
            logger.error("Grok 返回内容中既没有 b64_json,也没有 url: %s", data)
            raise ValueError("Grok 未返回可用的图片数据")

        revised_prompt = image_item.get("revised_prompt")
        b64_json = image_item.get("b64_json")
        image_url = image_item.get("url")
        logger.debug(
            "Grok 首张图片结果: keys=%s has_b64=%s has_url=%s revised_prompt_preview=%s",
            list(image_item.keys()),
            bool(b64_json),
            bool(image_url),
            (revised_prompt or "")[:300],
        )

        if b64_json:
            decoded_content = self._decode_base64_image(b64_json)
            # Trust the bytes over an assumption: the payload is only reported
            # as JPEG when no known signature is found.
            mime_type, file_extension = (
                self._sniff_image_format(decoded_content) or self._DEFAULT_IMAGE_FORMAT
            )
            image_width, image_height = self._detect_image_size(decoded_content)
            logger.debug(
                "Grok 返回 base64 图片: bytes=%s mime=%s size=%sx%s",
                len(decoded_content),
                mime_type,
                image_width,
                image_height,
            )
            return self._build_result(
                content=decoded_content,
                mime_type=mime_type,
                file_extension=file_extension,
                revised_prompt=revised_prompt,
                model=model,
                image_width=image_width,
                image_height=image_height,
            )

        if image_url:
            image_response = await self._download_image(image_url)
            downloaded = image_response.content
            header_type = image_response.headers.get("content-type", "image/jpeg")
            # Prefer the format found in the bytes; fall back to the response
            # header and URL suffix when the signature is unknown.
            mime_type, file_extension = self._sniff_image_format(downloaded) or (
                header_type,
                self._guess_extension(content_type=header_type, image_url=image_url),
            )
            image_width, image_height = self._detect_image_size(downloaded)
            logger.debug(
                "Grok 图片下载完成: bytes=%s extension=%s size=%sx%s",
                len(downloaded),
                file_extension,
                image_width,
                image_height,
            )
            return self._build_result(
                content=downloaded,
                mime_type=mime_type,
                file_extension=file_extension,
                revised_prompt=revised_prompt,
                model=model,
                image_width=image_width,
                image_height=image_height,
            )

        logger.error("Grok 返回内容中既没有 b64_json,也没有 url: %s", data)
        raise ValueError("Grok 未返回可用的图片数据")

    async def _download_image(self, image_url: str) -> httpx.Response:
        """Fetch the generated image from ``image_url``.

        The request is sent without the API ``Authorization`` header.

        Raises:
            httpx.HTTPStatusError: The download answered with an error status.
        """
        logger.debug("Grok 返回图片 URL,开始下载: %s", image_url)
        try:
            async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT_SECONDS) as client:
                image_response = await client.get(image_url)
            logger.debug(
                "Grok 图片下载响应: status=%s content_type=%s content_length=%s",
                image_response.status_code,
                image_response.headers.get("content-type"),
                image_response.headers.get("content-length"),
            )
            image_response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            failed = exc.response
            logger.error(
                "Grok 图片下载 HTTP 错误: status=%s response=%s",
                failed.status_code if failed is not None else None,
                failed.text[:2000] if failed is not None else None,
            )
            raise
        except Exception:
            logger.error("Grok 图片下载异常", exc_info=True)
            raise
        return image_response

    @staticmethod
    def _build_result(
        *,
        content: bytes,
        mime_type: str,
        file_extension: str,
        revised_prompt: Any,
        model: str,
        image_width: int,
        image_height: int,
    ) -> dict[str, Any]:
        """Assemble the internal result dict shared by both image sources."""
        return {
            "content": content,
            "mime_type": mime_type,
            "file_extension": file_extension,
            "revised_prompt": revised_prompt,
            "provider": "grok",
            "model": model,
            "image_width": image_width,
            "image_height": image_height,
        }

    @staticmethod
    def _to_public_result(result: dict[str, Any]) -> CoverGenerationResult:
        """Project the internal result onto the public result keys.

        The detected image dimensions are intentionally not part of the
        returned mapping; ``revised_prompt`` is optional, all other keys are
        required.
        """
        return {
            "content": result["content"],
            "mime_type": result["mime_type"],
            "file_extension": result["file_extension"],
            "revised_prompt": result.get("revised_prompt"),
            "provider": result["provider"],
            "model": result["model"],
        }

    @staticmethod
    def _sniff_image_format(content: bytes) -> tuple[str, str] | None:
        """Identify PNG, JPEG or WebP data from its leading signature bytes.

        Returns:
            ``(mime_type, file_extension)`` for a recognized signature,
            otherwise ``None``.
        """
        if content[:8] == b"\x89PNG\r\n\x1a\n":
            return "image/png", "png"
        if content[:3] == b"\xff\xd8\xff":
            return "image/jpeg", "jpg"
        if content[:4] == b"RIFF" and content[8:12] == b"WEBP":
            return "image/webp", "webp"
        return None

    @staticmethod
    def _detect_image_size(content: bytes) -> tuple[int, int]:
        """Read ``(width, height)`` from PNG or JPEG bytes.

        Returns ``(0, 0)`` for any other format and for data that is too
        short or malformed to contain the dimensions.
        """
        # PNG: width and height are big-endian 32-bit values at offset 16.
        if len(content) >= 24 and content[:8] == b"\x89PNG\r\n\x1a\n":
            width, height = struct.unpack_from(">II", content, 16)
            return int(width), int(height)

        # JPEG: walk the marker segments until a start-of-frame is found.
        if content[:2] == b"\xff\xd8":
            total = len(content)
            index = 2
            while index + 1 < total:
                if content[index] != 0xFF:
                    index += 1
                    continue
                marker = content[index + 1]
                if marker == 0xFF:
                    # Fill byte: the actual marker starts at the next 0xFF.
                    index += 1
                    continue
                index += 2
                # Markers without a length field: stuffed 0x00, TEM, RST0-7,
                # SOI and EOI.
                if marker in (0x00, 0x01) or 0xD0 <= marker <= 0xD9:
                    continue
                if index + 2 > total:
                    break
                (segment_length,) = struct.unpack_from(">H", content, index)
                if segment_length < 2 or index + segment_length > total:
                    break
                # SOF0-SOF15 except DHT (C4), JPG (C8) and DAC (CC).
                if 0xC0 <= marker <= 0xCF and marker not in (0xC4, 0xC8, 0xCC):
                    if index + 7 <= total:
                        # Layout: length(2) precision(1) height(2) width(2).
                        height, width = struct.unpack_from(">HH", content, index + 3)
                        return int(width), int(height)
                    break
                index += segment_length
        return 0, 0

    @staticmethod
    def _decode_base64_image(value: str) -> bytes:
        """Decode a base64 image string, accepting an optional data-URI prefix.

        Raises:
            binascii.Error: The payload is not valid base64 (a ``ValueError``
                subclass).
        """
        if value.startswith("data:") and "," in value:
            # "data:image/png;base64,<payload>" -> "<payload>"
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    @staticmethod
    def _adapt_prompt(*, prompt: str, width: int, height: int) -> str:
        """Collapse whitespace in the prompt and append a composition hint.

        The hint names the orientation that matches the requested size. When
        either dimension is not positive no size is cited and a vertical
        composition is requested, in line with the ``2:3`` fallback of
        ``_get_aspect_ratio``.
        """
        cleaned_prompt = " ".join((prompt or "").split())
        if width <= 0 or height <= 0:
            hint = "Use a vertical composition."
        else:
            if height > width:
                orientation = "vertical"
            elif width > height:
                orientation = "horizontal"
            else:
                orientation = "square"
            hint = f"Use a {width}x{height} {orientation} composition."
        return f"{cleaned_prompt} {hint}".strip()

    @staticmethod
    def _get_aspect_ratio(*, width: int, height: int) -> str:
        """Return the requested size as a reduced ``W:H`` ratio.

        Non-positive dimensions fall back to ``2:3``. The ratio is reduced by
        the greatest common divisor (``1024x1536`` -> ``2:3``, ``768x1024`` ->
        ``3:4``).
        NOTE(review): presumably the API only accepts a fixed list of ratios;
        sizes that do not reduce to one of them are passed through as-is.
        """
        if width <= 0 or height <= 0:
            return "2:3"
        divisor = math.gcd(width, height)
        return f"{width // divisor}:{height // divisor}"

    @staticmethod
    def _get_resolution(*, width: int, height: int) -> str:
        """Return ``"2k"`` when the longest edge is at least 1536, else ``"1k"``."""
        longest_edge = max(width, height)
        if longest_edge >= 1536:
            return "2k"
        return "1k"

    @staticmethod
    def _guess_extension(*, content_type: str, image_url: str) -> str:
        """Guess a file extension from the content type and the URL path.

        PNG is checked before WebP; anything else is reported as ``jpg``.
        Only the path component of the URL is inspected so that a query
        string or fragment does not hide the suffix.
        """
        lowered_content_type = (content_type or "").lower()
        try:
            url_path = urlsplit(image_url or "").path.lower()
        except ValueError:
            # Unparseable URL: fall back to matching on the raw string.
            url_path = (image_url or "").lower()
        if "png" in lowered_content_type or url_path.endswith(".png"):
            return "png"
        if "webp" in lowered_content_type or url_path.endswith(".webp"):
            return "webp"
        return "jpg"