feature: 新增小说封面图片生成功能
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
"""项目封面生成与下载 API"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.services.cover_generation_service import cover_generation_service
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["项目封面"])
|
||||
|
||||
|
||||
class CoverGenerateRequest(BaseModel):
|
||||
overwrite: bool = Field(default=True, description="是否覆盖已有封面")
|
||||
|
||||
|
||||
class CoverGenerateResponse(BaseModel):
|
||||
project_id: str
|
||||
cover_status: str
|
||||
cover_image_url: str | None = None
|
||||
cover_prompt: str | None = None
|
||||
provider: str | None = None
|
||||
model: str | None = None
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/{project_id}/cover/generate", response_model=CoverGenerateResponse, summary="生成项目封面")
|
||||
async def generate_project_cover(
|
||||
project_id: str,
|
||||
payload: CoverGenerateRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await cover_generation_service.generate_cover(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
overwrite=payload.overwrite,
|
||||
)
|
||||
return CoverGenerateResponse(**result)
|
||||
|
||||
|
||||
@router.get("/{project_id}/cover/download", summary="下载项目封面")
|
||||
async def download_project_cover(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
project, file_path = await cover_generation_service.get_cover_download_path(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
)
|
||||
suffix = file_path.suffix or ".png"
|
||||
filename = f"{project.title}-cover{suffix}"
|
||||
return FileResponse(path=file_path, filename=filename, media_type="application/octet-stream")
|
||||
@@ -14,6 +14,7 @@ import time
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.settings import Settings
|
||||
from app.services.cover_generation_service import cover_generation_service
|
||||
from app.schemas.settings import (
|
||||
SettingsCreate, SettingsUpdate, SettingsResponse,
|
||||
APIKeyPreset, APIKeyPresetConfig, PresetCreateRequest,
|
||||
@@ -29,6 +30,13 @@ logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/settings", tags=["设置管理"])
|
||||
|
||||
|
||||
class CoverSettingsTestRequest(BaseModel):
|
||||
cover_api_provider: str
|
||||
cover_api_key: str
|
||||
cover_api_base_url: Optional[str] = None
|
||||
cover_image_model: str
|
||||
|
||||
|
||||
def read_env_defaults() -> Dict[str, Any]:
|
||||
"""从.env文件读取默认配置(仅读取,不修改)"""
|
||||
return {
|
||||
@@ -142,6 +150,25 @@ async def get_settings(
|
||||
return settings
|
||||
|
||||
|
||||
@router.post("/cover/test")
|
||||
async def test_cover_settings(
|
||||
data: CoverSettingsTestRequest,
|
||||
user: User = Depends(require_login),
|
||||
):
|
||||
result = await cover_generation_service.test_cover_settings(
|
||||
provider=data.cover_api_provider,
|
||||
api_key=data.cover_api_key,
|
||||
api_base_url=data.cover_api_base_url,
|
||||
model=data.cover_image_model,
|
||||
)
|
||||
return {
|
||||
"success": result.success,
|
||||
"message": result.message,
|
||||
"provider": result.provider,
|
||||
"model": result.model,
|
||||
}
|
||||
|
||||
|
||||
@router.post("", response_model=SettingsResponse)
|
||||
async def save_settings(
|
||||
data: SettingsCreate,
|
||||
|
||||
+7
-1
@@ -130,7 +130,8 @@ from app.api import (
|
||||
wizard_stream, relationships, organizations,
|
||||
auth, users, settings, writing_styles, memories,
|
||||
mcp_plugins, admin, inspiration, prompt_templates,
|
||||
changelog, careers, foreshadows, prompt_workshop, book_import
|
||||
changelog, careers, foreshadows, prompt_workshop, book_import,
|
||||
project_covers
|
||||
)
|
||||
|
||||
app.include_router(auth.router, prefix="/api")
|
||||
@@ -139,6 +140,7 @@ app.include_router(settings.router, prefix="/api")
|
||||
app.include_router(admin.router, prefix="/api")
|
||||
|
||||
app.include_router(projects.router, prefix="/api")
|
||||
app.include_router(project_covers.router, prefix="/api")
|
||||
app.include_router(wizard_stream.router, prefix="/api")
|
||||
app.include_router(inspiration.router, prefix="/api")
|
||||
app.include_router(outlines.router, prefix="/api")
|
||||
@@ -157,8 +159,12 @@ app.include_router(prompt_workshop.router, prefix="/api") # 提示词工坊API
|
||||
app.include_router(book_import.router, prefix="/api") # 拆书导入API
|
||||
|
||||
static_dir = Path(__file__).parent.parent / "static"
|
||||
generated_assets_root_dir = Path(__file__).parent.parent / "storage"
|
||||
generated_covers_dir = generated_assets_root_dir / "generated_covers"
|
||||
generated_covers_dir.mkdir(parents=True, exist_ok=True)
|
||||
if static_dir.exists():
|
||||
app.mount("/assets", StaticFiles(directory=str(static_dir / "assets")), name="assets")
|
||||
app.mount("/generated-assets/covers", StaticFiles(directory=str(generated_covers_dir)), name="generated-covers")
|
||||
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
|
||||
@@ -32,6 +32,13 @@ class Project(Base):
|
||||
chapter_count = Column(Integer, comment="章节数量")
|
||||
narrative_perspective = Column(String(50), comment="叙事视角:first_person/third_person/omniscient")
|
||||
character_count = Column(Integer, default=5, comment="角色数量")
|
||||
|
||||
# 封面字段
|
||||
cover_image_url = Column(String(1000), comment="封面图片访问地址")
|
||||
cover_prompt = Column(Text, comment="最近一次生成封面使用的提示词")
|
||||
cover_status = Column(String(20), default="none", nullable=False, comment="封面状态: none/generating/ready/failed")
|
||||
cover_error = Column(Text, comment="最近一次封面生成失败原因")
|
||||
cover_updated_at = Column(DateTime, comment="最近一次封面生成成功时间")
|
||||
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
@@ -41,7 +48,11 @@ class Project(Base):
|
||||
"outline_mode IN ('one-to-one', 'one-to-many')",
|
||||
name='check_outline_mode'
|
||||
),
|
||||
CheckConstraint(
|
||||
"cover_status IN ('none', 'generating', 'ready', 'failed')",
|
||||
name='check_cover_status'
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Project(id={self.id}, title={self.title})>"
|
||||
return f"<Project(id={self.id}, title={self.title})>"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""设置数据模型"""
|
||||
from sqlalchemy import Column, String, Text, Float, Integer, DateTime, Index
|
||||
from sqlalchemy import Column, String, Text, Float, Integer, DateTime, Boolean, Index
|
||||
from sqlalchemy.sql import func
|
||||
from app.database import Base
|
||||
import uuid
|
||||
@@ -18,6 +18,14 @@ class Settings(Base):
|
||||
temperature = Column(Float, default=0.7, comment="温度参数")
|
||||
max_tokens = Column(Integer, default=2000, comment="最大token数")
|
||||
system_prompt = Column(Text, comment="系统级别提示词,每次AI调用都会使用")
|
||||
|
||||
# 封面图片生成配置
|
||||
cover_api_provider = Column(String(50), comment="封面图片API提供商")
|
||||
cover_api_key = Column(String(500), comment="封面图片API密钥")
|
||||
cover_api_base_url = Column(String(500), comment="封面图片自定义API地址")
|
||||
cover_image_model = Column(String(100), comment="封面图片模型名称")
|
||||
cover_enabled = Column(Boolean, default=False, server_default="0", nullable=False, comment="是否启用封面图片生成")
|
||||
|
||||
preferences = Column(Text, comment="其他偏好设置(JSON)")
|
||||
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
|
||||
@@ -27,4 +35,4 @@ class Settings(Base):
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Settings(id={self.id}, user_id={self.user_id}, api_provider={self.api_provider})>"
|
||||
return f"<Settings(id={self.id}, user_id={self.user_id}, api_provider={self.api_provider})>"
|
||||
|
||||
@@ -55,6 +55,11 @@ class ProjectResponse(ProjectBase):
|
||||
chapter_count: Optional[int] = None
|
||||
narrative_perspective: Optional[str] = None
|
||||
character_count: Optional[int] = None
|
||||
cover_image_url: Optional[str] = None
|
||||
cover_prompt: Optional[str] = None
|
||||
cover_status: Optional[str] = None
|
||||
cover_error: Optional[str] = None
|
||||
cover_updated_at: Optional[datetime] = None
|
||||
outline_mode: str # 显式声明以确保响应中包含
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@@ -15,6 +15,11 @@ class SettingsBase(BaseModel):
|
||||
temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="温度参数")
|
||||
max_tokens: Optional[int] = Field(default=2000, ge=1, description="最大token数")
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统级别提示词,每次AI调用都会使用")
|
||||
cover_api_provider: Optional[str] = Field(default=None, description="封面图片API提供商")
|
||||
cover_api_key: Optional[str] = Field(default=None, description="封面图片API密钥")
|
||||
cover_api_base_url: Optional[str] = Field(default=None, description="封面图片自定义API地址")
|
||||
cover_image_model: Optional[str] = Field(default=None, description="封面图片模型名称")
|
||||
cover_enabled: Optional[bool] = Field(default=False, description="是否启用封面图片生成")
|
||||
preferences: Optional[str] = Field(default=None, description="其他偏好设置(JSON)")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
"""小说封面生成服务"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import PROJECT_ROOT
|
||||
from app.logger import get_logger
|
||||
from app.models.project import Project
|
||||
from app.models.settings import Settings
|
||||
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
|
||||
from app.services.cover_providers.gemini_cover_provider import GeminiCoverProvider
|
||||
from app.services.cover_providers.grok_cover_provider import GrokCoverProvider
|
||||
from app.services.prompt_service import PromptService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
COVER_WIDTH = 1024
|
||||
COVER_HEIGHT = 1536
|
||||
GENERATED_COVER_STORAGE_DIR = PROJECT_ROOT / "storage" / "generated_covers"
|
||||
GENERATED_COVER_PUBLIC_PREFIX = "/generated-assets/covers"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoverTestResult:
|
||||
success: bool
|
||||
message: str
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
|
||||
class CoverGenerationService:
|
||||
"""封面生成服务"""
|
||||
|
||||
async def generate_cover(
|
||||
self,
|
||||
*,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
overwrite: bool = True,
|
||||
) -> dict:
|
||||
project = await self._get_project(db=db, user_id=user_id, project_id=project_id)
|
||||
settings = await self._get_settings(db=db, user_id=user_id)
|
||||
self._validate_cover_settings(settings)
|
||||
|
||||
if project.cover_status == "generating":
|
||||
raise HTTPException(status_code=409, detail="封面正在生成中,请勿重复提交")
|
||||
if project.cover_status == "ready" and project.cover_image_url and not overwrite:
|
||||
raise HTTPException(status_code=400, detail="当前项目已存在封面,如需覆盖请传入 overwrite=true")
|
||||
|
||||
prompt = await PromptService.build_novel_cover_prompt(
|
||||
project,
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
)
|
||||
project.cover_status = "generating"
|
||||
project.cover_error = None
|
||||
project.cover_prompt = prompt
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
|
||||
try:
|
||||
provider = self._build_provider(settings)
|
||||
result = await provider.generate_cover(
|
||||
prompt=prompt,
|
||||
model=settings.cover_image_model or "",
|
||||
width=COVER_WIDTH,
|
||||
height=COVER_HEIGHT,
|
||||
)
|
||||
image_url = self._save_cover_file(
|
||||
user_id=user_id,
|
||||
project_id=project.id,
|
||||
content=result["content"],
|
||||
file_extension=result["file_extension"],
|
||||
)
|
||||
|
||||
project.cover_image_url = image_url
|
||||
project.cover_status = "ready"
|
||||
project.cover_error = None
|
||||
project.cover_updated_at = datetime.utcnow()
|
||||
project.cover_prompt = result.get("revised_prompt") or prompt
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
|
||||
return {
|
||||
"project_id": project.id,
|
||||
"cover_status": project.cover_status,
|
||||
"cover_image_url": project.cover_image_url,
|
||||
"cover_prompt": project.cover_prompt,
|
||||
"provider": result["provider"],
|
||||
"model": result["model"],
|
||||
"message": "封面生成成功",
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error("封面生成失败: project_id=%s error=%s", project.id, exc, exc_info=True)
|
||||
project.cover_status = "failed"
|
||||
project.cover_error = str(exc)
|
||||
await db.commit()
|
||||
raise HTTPException(status_code=500, detail=f"封面生成失败: {exc}") from exc
|
||||
|
||||
async def test_cover_settings(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base_url: Optional[str],
|
||||
model: str,
|
||||
) -> CoverTestResult:
|
||||
if not provider or not api_key or not model:
|
||||
raise HTTPException(status_code=400, detail="封面图片配置不完整,请填写 provider、api_key 和 model")
|
||||
|
||||
provider_instance = self._build_provider_from_values(
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
)
|
||||
test_prompt = (
|
||||
"Create a clean fantasy novel cover illustration, vertical book cover, "
|
||||
"standard 2:3 ratio, atmospheric lighting, no text, no watermark."
|
||||
)
|
||||
await provider_instance.generate_cover(
|
||||
prompt=test_prompt,
|
||||
model=model,
|
||||
width=COVER_WIDTH,
|
||||
height=COVER_HEIGHT,
|
||||
)
|
||||
return CoverTestResult(
|
||||
success=True,
|
||||
message="封面图片接口测试成功",
|
||||
provider=provider,
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def get_cover_download_path(
|
||||
self,
|
||||
*,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
) -> tuple[Project, Path]:
|
||||
project = await self._get_project(db=db, user_id=user_id, project_id=project_id)
|
||||
if project.cover_status != "ready" or not project.cover_image_url:
|
||||
raise HTTPException(status_code=404, detail="当前项目尚未生成可下载的封面")
|
||||
|
||||
absolute_path = self._resolve_cover_path(project.cover_image_url)
|
||||
if not absolute_path.exists():
|
||||
raise HTTPException(status_code=404, detail="封面文件不存在,请重新生成")
|
||||
return project, absolute_path
|
||||
|
||||
async def clear_cover_metadata(self, *, db: AsyncSession, project: Project) -> None:
|
||||
project.cover_image_url = None
|
||||
project.cover_prompt = None
|
||||
project.cover_status = "none"
|
||||
project.cover_error = None
|
||||
project.cover_updated_at = None
|
||||
await db.commit()
|
||||
|
||||
async def _get_project(self, *, db: AsyncSession, user_id: str, project_id: str) -> Project:
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id, Project.user_id == user_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
return project
|
||||
|
||||
async def _get_settings(self, *, db: AsyncSession, user_id: str) -> Settings:
|
||||
result = await db.execute(select(Settings).where(Settings.user_id == user_id))
|
||||
settings = result.scalar_one_or_none()
|
||||
if not settings:
|
||||
raise HTTPException(status_code=400, detail="请先在设置页完成封面图片配置")
|
||||
return settings
|
||||
|
||||
def _validate_cover_settings(self, settings: Settings) -> None:
|
||||
if not settings.cover_enabled:
|
||||
raise HTTPException(status_code=400, detail="封面图片功能未启用,请先在设置页开启")
|
||||
if not settings.cover_api_provider or not settings.cover_api_key or not settings.cover_image_model:
|
||||
raise HTTPException(status_code=400, detail="封面图片配置不完整,请前往设置页补全")
|
||||
|
||||
def _build_provider(self, settings: Settings) -> BaseCoverProvider:
|
||||
return self._build_provider_from_values(
|
||||
provider=settings.cover_api_provider or "",
|
||||
api_key=settings.cover_api_key or "",
|
||||
api_base_url=settings.cover_api_base_url,
|
||||
)
|
||||
|
||||
def _build_provider_from_values(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base_url: Optional[str],
|
||||
) -> BaseCoverProvider:
|
||||
provider_value = (provider or "").lower().strip()
|
||||
normalized_base_url = (api_base_url or "").rstrip("/")
|
||||
if provider_value == "gemini":
|
||||
return GeminiCoverProvider(api_key=api_key, base_url=normalized_base_url)
|
||||
if provider_value == "grok":
|
||||
return GrokCoverProvider(api_key=api_key, base_url=normalized_base_url)
|
||||
if provider_value == "mumu":
|
||||
if normalized_base_url.endswith("/v1beta"):
|
||||
return GeminiCoverProvider(api_key=api_key, base_url=normalized_base_url)
|
||||
return GrokCoverProvider(api_key=api_key, base_url=normalized_base_url or "https://api.mumuverse.space/v1")
|
||||
raise HTTPException(status_code=400, detail="当前版本仅支持 Gemini、Grok 或 MuMuのAPI 作为封面图片 Provider")
|
||||
|
||||
def _save_cover_file(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
content: bytes,
|
||||
file_extension: str,
|
||||
) -> str:
|
||||
user_dir = GENERATED_COVER_STORAGE_DIR / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d%H%M%S")
|
||||
safe_extension = (file_extension or "png").lstrip(".")
|
||||
filename = f"{project_id}_{timestamp}.{safe_extension}"
|
||||
file_path = user_dir / filename
|
||||
file_path.write_bytes(content)
|
||||
logger.info("封面文件已保存: project_id=%s path=%s", project_id, file_path)
|
||||
return f"{GENERATED_COVER_PUBLIC_PREFIX}/{quote(user_id)}/{quote(filename)}"
|
||||
|
||||
def _resolve_cover_path(self, cover_image_url: Optional[str]) -> Path:
|
||||
if not cover_image_url:
|
||||
raise HTTPException(status_code=404, detail="当前项目尚未生成可下载的封面")
|
||||
|
||||
if cover_image_url.startswith(f"{GENERATED_COVER_PUBLIC_PREFIX}/"):
|
||||
relative_path = cover_image_url.replace(f"{GENERATED_COVER_PUBLIC_PREFIX}/", "", 1)
|
||||
return GENERATED_COVER_STORAGE_DIR / relative_path
|
||||
|
||||
if cover_image_url.startswith("/assets/generated_covers/"):
|
||||
relative_path = cover_image_url.replace("/assets/generated_covers/", "", 1)
|
||||
return GENERATED_COVER_STORAGE_DIR / relative_path
|
||||
|
||||
raise HTTPException(status_code=404, detail="封面文件路径无效,请重新生成")
|
||||
|
||||
|
||||
cover_generation_service = CoverGenerationService()
|
||||
@@ -0,0 +1,32 @@
|
||||
"""封面图片 Provider 抽象基类"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
|
||||
class CoverGenerationResult(TypedDict):
|
||||
"""封面生成结果"""
|
||||
|
||||
content: bytes
|
||||
mime_type: str
|
||||
file_extension: str
|
||||
revised_prompt: Optional[str]
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
class BaseCoverProvider(ABC):
|
||||
"""封面图片 Provider 抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> CoverGenerationResult:
|
||||
"""生成封面图片"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Gemini 封面图片 Provider"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GeminiCoverProvider(BaseCoverProvider):
|
||||
"""基于 Gemini API 的封面生成实现"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
|
||||
async def generate_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> CoverGenerationResult:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
payload: dict[str, Any] = {
|
||||
"contents": [{
|
||||
"role": "user",
|
||||
"parts": [{
|
||||
"text": (
|
||||
f"{prompt}\n\n"
|
||||
f"Generate a final cover image at {width}x{height} pixels. "
|
||||
"Return one final cover image."
|
||||
)
|
||||
}]
|
||||
}],
|
||||
"generationConfig": {
|
||||
"temperature": 0.4,
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates") or []
|
||||
if not candidates:
|
||||
raise ValueError("Gemini 未返回候选结果")
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
inline_data = part.get("inlineData")
|
||||
if not inline_data:
|
||||
continue
|
||||
|
||||
mime_type = inline_data.get("mimeType", "image/png")
|
||||
image_data = inline_data.get("data")
|
||||
if not image_data:
|
||||
continue
|
||||
|
||||
file_extension = "png" if "png" in mime_type else "jpg"
|
||||
return {
|
||||
"content": base64.b64decode(image_data),
|
||||
"mime_type": mime_type,
|
||||
"file_extension": file_extension,
|
||||
"revised_prompt": None,
|
||||
"provider": "gemini",
|
||||
"model": model,
|
||||
}
|
||||
|
||||
logger.error("Gemini 返回内容中未找到 inlineData 图像数据")
|
||||
raise ValueError("Gemini 未返回图片数据")
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Grok 封面图片 Provider"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.cover_providers.base_cover_provider import BaseCoverProvider, CoverGenerationResult
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GrokCoverProvider(BaseCoverProvider):
|
||||
"""基于 xAI Grok Images API 的封面生成实现"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://api.x.ai/v1").rstrip("/")
|
||||
|
||||
async def generate_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> CoverGenerationResult:
|
||||
result = await self._request_cover(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
return self._to_public_result(result)
|
||||
|
||||
async def _request_cover(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> dict[str, Any]:
|
||||
url = f"{self.base_url}/images/generations"
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"prompt": self._adapt_prompt(prompt=prompt, width=width, height=height),
|
||||
"n": 1,
|
||||
"response_format": "b64_json",
|
||||
"aspect_ratio": self._get_aspect_ratio(width=width, height=height),
|
||||
"resolution": self._get_resolution(width=width, height=height),
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"Grok 封面生成请求开始: url=%s model=%s width=%s height=%s prompt_len=%s prompt_preview=%s",
|
||||
url,
|
||||
model,
|
||||
width,
|
||||
height,
|
||||
len(prompt or ""),
|
||||
(prompt or "")[:300].replace("\n", " "),
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
|
||||
logger.debug(
|
||||
"Grok 封面生成响应: status=%s content_type=%s headers=%s body_preview=%s",
|
||||
response.status_code,
|
||||
response.headers.get("content-type"),
|
||||
{
|
||||
"x-request-id": response.headers.get("x-request-id"),
|
||||
"cf-ray": response.headers.get("cf-ray"),
|
||||
"openai-processing-ms": response.headers.get("openai-processing-ms"),
|
||||
},
|
||||
response.text[:1000],
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
"Grok 封面生成 HTTP 错误: status=%s response=%s",
|
||||
exc.response.status_code if exc.response else None,
|
||||
exc.response.text[:2000] if exc.response is not None else None,
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
logger.error("Grok 封面生成请求异常", exc_info=True)
|
||||
raise
|
||||
|
||||
images = data.get("data") or []
|
||||
logger.debug(
|
||||
"Grok 封面生成解析结果: has_data=%s image_count=%s keys=%s",
|
||||
bool(data),
|
||||
len(images),
|
||||
list(data.keys()) if isinstance(data, dict) else type(data).__name__,
|
||||
)
|
||||
|
||||
if not images:
|
||||
logger.error("Grok 未返回图片结果: data=%s", data)
|
||||
raise ValueError("Grok 未返回图片结果")
|
||||
|
||||
image_item = images[0]
|
||||
revised_prompt = image_item.get("revised_prompt")
|
||||
logger.debug(
|
||||
"Grok 首张图片结果: keys=%s has_b64=%s has_url=%s revised_prompt_preview=%s",
|
||||
list(image_item.keys()),
|
||||
bool(image_item.get("b64_json")),
|
||||
bool(image_item.get("url")),
|
||||
(revised_prompt or "")[:300],
|
||||
)
|
||||
|
||||
b64_json = image_item.get("b64_json")
|
||||
if b64_json:
|
||||
decoded_content = self._decode_base64_image(b64_json)
|
||||
image_width, image_height = self._detect_image_size(decoded_content)
|
||||
logger.debug(
|
||||
"Grok 返回 base64 图片: bytes=%s mime=image/jpeg size=%sx%s",
|
||||
len(decoded_content),
|
||||
image_width,
|
||||
image_height,
|
||||
)
|
||||
return {
|
||||
"content": decoded_content,
|
||||
"mime_type": "image/jpeg",
|
||||
"file_extension": "jpg",
|
||||
"revised_prompt": revised_prompt,
|
||||
"provider": "grok",
|
||||
"model": model,
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
}
|
||||
|
||||
image_url = image_item.get("url")
|
||||
if image_url:
|
||||
logger.debug("Grok 返回图片 URL,开始下载: %s", image_url)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=180.0) as client:
|
||||
image_response = await client.get(image_url)
|
||||
|
||||
logger.debug(
|
||||
"Grok 图片下载响应: status=%s content_type=%s content_length=%s",
|
||||
image_response.status_code,
|
||||
image_response.headers.get("content-type"),
|
||||
image_response.headers.get("content-length"),
|
||||
)
|
||||
image_response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
logger.error(
|
||||
"Grok 图片下载 HTTP 错误: status=%s response=%s",
|
||||
exc.response.status_code if exc.response else None,
|
||||
exc.response.text[:2000] if exc.response is not None else None,
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
logger.error("Grok 图片下载异常", exc_info=True)
|
||||
raise
|
||||
|
||||
content_type = image_response.headers.get("content-type", "image/jpeg")
|
||||
file_extension = self._guess_extension(content_type=content_type, image_url=image_url)
|
||||
image_width, image_height = self._detect_image_size(image_response.content)
|
||||
logger.debug(
|
||||
"Grok 图片下载完成: bytes=%s extension=%s size=%sx%s",
|
||||
len(image_response.content),
|
||||
file_extension,
|
||||
image_width,
|
||||
image_height,
|
||||
)
|
||||
return {
|
||||
"content": image_response.content,
|
||||
"mime_type": content_type,
|
||||
"file_extension": file_extension,
|
||||
"revised_prompt": revised_prompt,
|
||||
"provider": "grok",
|
||||
"model": model,
|
||||
"image_width": image_width,
|
||||
"image_height": image_height,
|
||||
}
|
||||
|
||||
logger.error("Grok 返回内容中既没有 b64_json,也没有 url: %s", data)
|
||||
raise ValueError("Grok 未返回可用的图片数据")
|
||||
|
||||
@staticmethod
|
||||
def _to_public_result(result: dict[str, Any]) -> CoverGenerationResult:
|
||||
return {
|
||||
"content": result["content"],
|
||||
"mime_type": result["mime_type"],
|
||||
"file_extension": result["file_extension"],
|
||||
"revised_prompt": result.get("revised_prompt"),
|
||||
"provider": result["provider"],
|
||||
"model": result["model"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _detect_image_size(content: bytes) -> tuple[int, int]:
|
||||
if len(content) >= 24 and content[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
width, height = struct.unpack(">II", content[16:24])
|
||||
return int(width), int(height)
|
||||
|
||||
if len(content) >= 2 and content[:2] == b"\xff\xd8":
|
||||
index = 2
|
||||
content_length = len(content)
|
||||
while index < content_length - 1:
|
||||
if content[index] != 0xFF:
|
||||
index += 1
|
||||
continue
|
||||
marker = content[index + 1]
|
||||
index += 2
|
||||
if marker in (0xD8, 0xD9):
|
||||
continue
|
||||
if index + 2 > content_length:
|
||||
break
|
||||
segment_length = struct.unpack(">H", content[index:index + 2])[0]
|
||||
if segment_length < 2 or index + segment_length > content_length:
|
||||
break
|
||||
if marker in {
|
||||
0xC0, 0xC1, 0xC2, 0xC3,
|
||||
0xC5, 0xC6, 0xC7,
|
||||
0xC9, 0xCA, 0xCB,
|
||||
0xCD, 0xCE, 0xCF,
|
||||
}:
|
||||
if index + 7 <= content_length:
|
||||
height, width = struct.unpack(">HH", content[index + 3:index + 7])
|
||||
return int(width), int(height)
|
||||
break
|
||||
index += segment_length
|
||||
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
def _decode_base64_image(value: str) -> bytes:
|
||||
if value.startswith("data:") and "," in value:
|
||||
value = value.split(",", 1)[1]
|
||||
return base64.b64decode(value)
|
||||
|
||||
@staticmethod
|
||||
def _adapt_prompt(*, prompt: str, width: int, height: int) -> str:
|
||||
cleaned_prompt = " ".join((prompt or "").split())
|
||||
return (
|
||||
f"{cleaned_prompt} "
|
||||
f"Use a {width}x{height} vertical composition."
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_aspect_ratio(*, width: int, height: int) -> str:
|
||||
if width <= 0 or height <= 0:
|
||||
return "2:3"
|
||||
if width * 3 == height * 2:
|
||||
return "2:3"
|
||||
return f"{width}:{height}"
|
||||
|
||||
@staticmethod
|
||||
def _get_resolution(*, width: int, height: int) -> str:
|
||||
longest_edge = max(width, height)
|
||||
if longest_edge >= 1536:
|
||||
return "2k"
|
||||
return "1k"
|
||||
|
||||
@staticmethod
|
||||
def _guess_extension(*, content_type: str, image_url: str) -> str:
|
||||
lowered_content_type = (content_type or "").lower()
|
||||
lowered_url = (image_url or "").lower()
|
||||
if "png" in lowered_content_type or lowered_url.endswith(".png"):
|
||||
return "png"
|
||||
if "webp" in lowered_content_type or lowered_url.endswith(".webp"):
|
||||
return "webp"
|
||||
return "jpg"
|
||||
@@ -24,6 +24,53 @@ class WritingStyleManager:
|
||||
|
||||
class PromptService:
|
||||
"""提示词模板管理"""
|
||||
|
||||
NOVEL_COVER_PROMPT_TEMPLATE = """创作一幅高质量小说封面插图,适用于竖版书籍封面。
|
||||
|
||||
小说标题是:“{title}”。
|
||||
类型为 {genre}。核心主题是 {theme}。故事摘要如下:{description}
|
||||
|
||||
画面应具有电影感、精致、富有氛围和情感表现力,并具备清晰的视觉焦点和强烈的象征性意象。请优先展现符合小说类型的视觉叙事和情绪,而不是死板地描绘具体场景。
|
||||
|
||||
这必须看起来像一幅专业的网络小说或实体出版物风格的封面。
|
||||
|
||||
硬性要求:
|
||||
- 必须在画面醒目位置包含小说标题文字:“{title}”,文字排版需极具艺术感,并与小说的 {genre} 类型风格完美融合。
|
||||
- 适用于标准小说封面的竖版构图(2:3 比例)。
|
||||
- 画面中只能出现标题文字,绝不能出现作者名字、副标题或其他无关的随机字母。
|
||||
- 无标志 (Logo)。
|
||||
- 无水印。
|
||||
- 无边框。
|
||||
- 无 UI 元素。
|
||||
- 无样机展示效果 (Mockup)。
|
||||
|
||||
最终图像必须是一张完整、专业的书籍封面艺术作品,背景插画与标题排版需相得益彰。"""
|
||||
|
||||
@classmethod
|
||||
async def build_novel_cover_prompt(
|
||||
cls,
|
||||
project: Any,
|
||||
user_id: str = None,
|
||||
db = None,
|
||||
) -> str:
|
||||
"""基于项目基础信息构建小说封面提示词,支持用户自定义模板"""
|
||||
title = (getattr(project, "title", "") or "未命名小说").strip()
|
||||
genre = (getattr(project, "genre", "") or "未指定类型").strip()
|
||||
theme = (getattr(project, "theme", "") or "未指定主题").strip()
|
||||
description = (getattr(project, "description", "") or "无额外简介").strip()
|
||||
|
||||
compact_description = description[:300]
|
||||
template = await cls.get_template_with_fallback(
|
||||
"NOVEL_COVER_PROMPT_TEMPLATE",
|
||||
user_id=user_id,
|
||||
db=db,
|
||||
)
|
||||
return template.format(
|
||||
title=title,
|
||||
genre=genre,
|
||||
theme=theme,
|
||||
description=compact_description,
|
||||
)
|
||||
|
||||
# ========== V2版本提示词模板(RTCO框架)==========
|
||||
|
||||
@@ -2813,6 +2860,12 @@ class PromptService:
|
||||
|
||||
# 定义所有模板及其元信息
|
||||
template_definitions = {
|
||||
"NOVEL_COVER_PROMPT_TEMPLATE": {
|
||||
"name": "小说封面生成",
|
||||
"category": "封面生成",
|
||||
"description": "根据项目基础信息生成小说封面绘制提示词,适用于竖版书籍封面",
|
||||
"parameters": ["title", "genre", "theme", "description"]
|
||||
},
|
||||
"WORLD_BUILDING": {
|
||||
"name": "世界构建",
|
||||
"category": "世界构建",
|
||||
|
||||
Reference in New Issue
Block a user