"""Grok cover image provider."""

from __future__ import annotations

import base64
import math
import struct
from typing import Any

import httpx

from app.logger import get_logger
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class GrokCoverProvider(BaseCoverProvider):
    """Cover image provider backed by the xAI Grok Images API.

    The provider posts a generation request to ``{base_url}/images/generations``
    and accepts either an inline base64 payload or a downloadable URL in the
    response.
    """

    # Endpoint root used when the caller passes an empty ``base_url``.
    _DEFAULT_BASE_URL = "https://api.x.ai/v1"
    # Timeout (seconds) applied to both the generation call and the download.
    _TIMEOUT_SECONDS = 180.0

    _PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
    _JPEG_SOI = b"\xff\xd8"
    # JPEG start-of-frame markers; 0xC4 (DHT), 0xC8 (JPG) and 0xCC (DAC) share
    # the numeric range but do not carry frame dimensions.
    _JPEG_SOF_MARKERS = frozenset(
        {
            0xC0, 0xC1, 0xC2, 0xC3,
            0xC5, 0xC6, 0xC7,
            0xC9, 0xCA, 0xCB,
            0xCD, 0xCE, 0xCF,
        }
    )

    def __init__(self, api_key: str, base_url: str):
        """Store the credentials and the normalized API root.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; falls back to the public xAI endpoint when
                empty. Trailing slashes are removed.
        """
        self.api_key = api_key
        self.base_url = (base_url or self._DEFAULT_BASE_URL).rstrip("/")

    async def generate_cover(
        self,
        *,
        prompt: str,
        model: str,
        width: int,
        height: int,
    ) -> CoverGenerationResult:
        """Generate one cover image and return the public result mapping.

        Raises:
            httpx.HTTPStatusError: The API or the image host answered with an
                error status.
            ValueError: The response contained no usable image data.
        """
        result = await self._request_cover(
            prompt=prompt,
            model=model,
            width=width,
            height=height,
        )
        return self._to_public_result(result)

    async def _request_cover(
        self,
        *,
        prompt: str,
        model: str,
        width: int,
        height: int,
    ) -> dict[str, Any]:
        """Request an image and return it with provider metadata.

        Returns:
            A dict with ``content``, ``mime_type``, ``file_extension``,
            ``revised_prompt``, ``provider``, ``model``, ``image_width`` and
            ``image_height``.

        Raises:
            httpx.HTTPStatusError: A request returned an error status.
            ValueError: The response has no image entry, or the entry has
                neither ``b64_json`` nor ``url``.
        """
        data = await self._post_generation_request(
            prompt=prompt,
            model=model,
            width=width,
            height=height,
        )

        # Tolerate unexpected JSON shapes (non-dict body, non-list "data") and
        # report them through the same "no image" error path.
        images = data.get("data") if isinstance(data, dict) else None
        if not isinstance(images, list):
            images = []
        logger.debug(
            "Grok 封面生成解析结果: has_data=%s image_count=%s keys=%s",
            bool(data),
            len(images),
            list(data.keys()) if isinstance(data, dict) else type(data).__name__,
        )

        if not images:
            logger.error("Grok 未返回图片结果: data=%s", data)
            raise ValueError("Grok 未返回图片结果")

        # A non-dict entry carries no usable fields; treat it as empty so it
        # ends in the "no usable image data" error below.
        image_item = images[0] if isinstance(images[0], dict) else {}
        revised_prompt = image_item.get("revised_prompt")
        logger.debug(
            "Grok 首张图片结果: keys=%s has_b64=%s has_url=%s revised_prompt_preview=%s",
            list(image_item.keys()),
            bool(image_item.get("b64_json")),
            bool(image_item.get("url")),
            (revised_prompt or "")[:300],
        )

        b64_json = image_item.get("b64_json")
        if b64_json:
            content = self._decode_base64_image(b64_json)
            # Identify the real format from the bytes; fall back to JPEG when
            # the signature is not recognized.
            mime_type, file_extension = self._sniff_image_format(content) or ("image/jpeg", "jpg")
            image_width, image_height = self._detect_image_size(content)
            logger.debug(
                "Grok 返回 base64 图片: bytes=%s mime=%s size=%sx%s",
                len(content),
                mime_type,
                image_width,
                image_height,
            )
            return {
                "content": content,
                "mime_type": mime_type,
                "file_extension": file_extension,
                "revised_prompt": revised_prompt,
                "provider": "grok",
                "model": model,
                "image_width": image_width,
                "image_height": image_height,
            }

        image_url = image_item.get("url")
        if image_url:
            image_response = await self._download_image(image_url)
            content = image_response.content

            # Drop header parameters ("; charset=...") and only trust the
            # header when it names an image type; otherwise sniff the bytes.
            header_type = (image_response.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
            if header_type.startswith("image/"):
                mime_type = header_type
            else:
                sniffed = self._sniff_image_format(content)
                mime_type = sniffed[0] if sniffed else "image/jpeg"

            file_extension = self._guess_extension(content_type=mime_type, image_url=image_url)
            image_width, image_height = self._detect_image_size(content)
            logger.debug(
                "Grok 图片下载完成: bytes=%s extension=%s size=%sx%s",
                len(content),
                file_extension,
                image_width,
                image_height,
            )
            return {
                "content": content,
                "mime_type": mime_type,
                "file_extension": file_extension,
                "revised_prompt": revised_prompt,
                "provider": "grok",
                "model": model,
                "image_width": image_width,
                "image_height": image_height,
            }

        logger.error("Grok 返回内容中既没有 b64_json,也没有 url: %s", data)
        raise ValueError("Grok 未返回可用的图片数据")

    async def _post_generation_request(
        self,
        *,
        prompt: str,
        model: str,
        width: int,
        height: int,
    ) -> Any:
        """POST the generation request and return the parsed JSON body."""
        url = f"{self.base_url}/images/generations"
        payload: dict[str, Any] = {
            "model": model,
            "prompt": self._adapt_prompt(prompt=prompt, width=width, height=height),
            "n": 1,
            "response_format": "b64_json",
            "aspect_ratio": self._get_aspect_ratio(width=width, height=height),
            "resolution": self._get_resolution(width=width, height=height),
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        logger.debug(
            "Grok 封面生成请求开始: url=%s model=%s width=%s height=%s prompt_len=%s prompt_preview=%s",
            url,
            model,
            width,
            height,
            len(prompt or ""),
            (prompt or "")[:300].replace("\n", " "),
        )

        try:
            async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
                response = await client.post(url, headers=headers, json=payload)

            logger.debug(
                "Grok 封面生成响应: status=%s content_type=%s headers=%s body_preview=%s",
                response.status_code,
                response.headers.get("content-type"),
                {
                    "x-request-id": response.headers.get("x-request-id"),
                    "cf-ray": response.headers.get("cf-ray"),
                    "openai-processing-ms": response.headers.get("openai-processing-ms"),
                },
                response.text[:1000],
            )

            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as exc:
            # HTTPStatusError always carries the failing response.
            logger.error(
                "Grok 封面生成 HTTP 错误: status=%s response=%s",
                exc.response.status_code,
                exc.response.text[:2000],
            )
            raise
        except Exception:
            logger.error("Grok 封面生成请求异常", exc_info=True)
            raise

    async def _download_image(self, image_url: str) -> httpx.Response:
        """Download the generated image and return the successful response.

        A fresh client without the API ``Authorization`` header is used, so the
        bearer token is not sent to the image host.
        """
        logger.debug("Grok 返回图片 URL,开始下载: %s", image_url)
        try:
            async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
                image_response = await client.get(image_url)

            logger.debug(
                "Grok 图片下载响应: status=%s content_type=%s content_length=%s",
                image_response.status_code,
                image_response.headers.get("content-type"),
                image_response.headers.get("content-length"),
            )
            image_response.raise_for_status()
            return image_response
        except httpx.HTTPStatusError as exc:
            logger.error(
                "Grok 图片下载 HTTP 错误: status=%s response=%s",
                exc.response.status_code,
                exc.response.text[:2000],
            )
            raise
        except Exception:
            logger.error("Grok 图片下载异常", exc_info=True)
            raise

    @staticmethod
    def _to_public_result(result: dict[str, Any]) -> CoverGenerationResult:
        """Project the internal result onto the public result fields.

        The internal ``image_width`` / ``image_height`` entries are dropped.
        """
        return {
            "content": result["content"],
            "mime_type": result["mime_type"],
            "file_extension": result["file_extension"],
            "revised_prompt": result.get("revised_prompt"),
            "provider": result["provider"],
            "model": result["model"],
        }

    @staticmethod
    def _sniff_image_format(content: bytes) -> tuple[str, str] | None:
        """Return ``(mime_type, extension)`` from magic bytes, or ``None``."""
        if content.startswith(GrokCoverProvider._PNG_SIGNATURE):
            return "image/png", "png"
        if content.startswith(GrokCoverProvider._JPEG_SOI):
            return "image/jpeg", "jpg"
        # WebP is a RIFF container whose form type at offset 8 is "WEBP".
        if content[:4] == b"RIFF" and content[8:12] == b"WEBP":
            return "image/webp", "webp"
        return None

    @staticmethod
    def _detect_image_size(content: bytes) -> tuple[int, int]:
        """Read ``(width, height)`` from a PNG or JPEG header.

        Returns ``(0, 0)`` when the format is not recognized or the header is
        truncated or malformed.
        """
        if len(content) >= 24 and content[:8] == GrokCoverProvider._PNG_SIGNATURE:
            # Width and height are the first two big-endian uint32 fields of
            # the leading chunk's data, at byte offsets 16 and 20.
            width, height = struct.unpack_from(">II", content, 16)
            return int(width), int(height)

        if content[:2] == GrokCoverProvider._JPEG_SOI:
            total = len(content)
            index = 2
            while index + 1 < total:
                if content[index] != 0xFF:
                    index += 1
                    continue
                marker = content[index + 1]
                if marker == 0xFF:
                    # Fill byte: any number of 0xFF may precede the marker
                    # code, so re-examine starting from the next byte.
                    index += 1
                    continue
                index += 2
                # Standalone markers carry no length field: stuffed 0x00,
                # TEM (0x01), RST0-RST7 (0xD0-0xD7), SOI (0xD8), EOI (0xD9).
                if marker in (0x00, 0x01) or 0xD0 <= marker <= 0xD9:
                    continue
                if index + 2 > total:
                    break
                (segment_length,) = struct.unpack_from(">H", content, index)
                if segment_length < 2 or index + segment_length > total:
                    break
                if marker in GrokCoverProvider._JPEG_SOF_MARKERS:
                    # Frame header layout: length(2) precision(1) height(2) width(2).
                    if index + 7 <= total:
                        height, width = struct.unpack_from(">HH", content, index + 3)
                        return int(width), int(height)
                    break
                index += segment_length

        return 0, 0

    @staticmethod
    def _decode_base64_image(value: str) -> bytes:
        """Decode a base64 image, accepting data URLs and missing padding.

        Raises:
            binascii.Error: The payload is not decodable base64 (a subclass of
                ``ValueError``).
        """
        if value.startswith("data:") and "," in value:
            value = value.split(",", 1)[1]
        # Remove embedded whitespace first so the padding length is computed
        # from the real data characters, then restore any stripped "=".
        compact = "".join(value.split())
        compact += "=" * (-len(compact) % 4)
        return base64.b64decode(compact)

    @staticmethod
    def _adapt_prompt(*, prompt: str, width: int, height: int) -> str:
        """Collapse whitespace in the prompt and append a composition hint.

        The hint names the orientation implied by the requested size. For
        non-positive sizes it only asks for a vertical composition, matching
        the ``2:3`` fallback of ``_get_aspect_ratio``.
        """
        cleaned_prompt = " ".join((prompt or "").split())
        if width <= 0 or height <= 0:
            return f"{cleaned_prompt} Use a vertical composition.".strip()
        if height > width:
            orientation = "vertical"
        elif width > height:
            orientation = "horizontal"
        else:
            orientation = "square"
        return f"{cleaned_prompt} Use a {width}x{height} {orientation} composition.".strip()

    @staticmethod
    def _get_aspect_ratio(*, width: int, height: int) -> str:
        """Return the size as a reduced ``W:H`` ratio (``2:3`` if invalid).

        Reducing by the greatest common divisor turns e.g. 1920x1080 into
        ``16:9`` instead of the raw ``1920:1080``.
        """
        if width <= 0 or height <= 0:
            return "2:3"
        divisor = math.gcd(width, height)
        return f"{width // divisor}:{height // divisor}"

    @staticmethod
    def _get_resolution(*, width: int, height: int) -> str:
        """Return ``"2k"`` when the longest edge is at least 1536, else ``"1k"``."""
        longest_edge = max(width, height)
        if longest_edge >= 1536:
            return "2k"
        return "1k"

    @staticmethod
    def _guess_extension(*, content_type: str, image_url: str) -> str:
        """Pick a file extension from the content type, then the URL path.

        The content type wins when it names a known format. Otherwise the URL
        suffix is used, ignoring any query string or fragment. Defaults to
        ``jpg``.
        """
        lowered_content_type = (content_type or "").lower()
        if "png" in lowered_content_type:
            return "png"
        if "webp" in lowered_content_type:
            return "webp"
        if "jpeg" in lowered_content_type or "jpg" in lowered_content_type:
            return "jpg"

        # Signed URLs usually end in "?token=..."; compare the path only.
        url_path = (image_url or "").lower().split("?", 1)[0].split("#", 1)[0]
        if url_path.endswith(".png"):
            return "png"
        if url_path.endswith(".webp"):
            return "webp"
        return "jpg"
|