1.优化AI请求替换OpenAI SDK调用,使用httpx和自定义头请求,避免触发部分公益站的cloudflare
2.修复deepseek模型调用问题,舍弃思考过程AI响应内容,只获取结果内容 3.新增会话过期机制,更新后添加到.env中 4.支持用户在生成章节内容时设置字数
This commit is contained in:
+115
-5
@@ -6,12 +6,20 @@ from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.services.oauth_service import LinuxDOOAuthService
|
||||
from app.user_manager import user_manager
|
||||
from app.database import init_db
|
||||
from app.logger import get_logger
|
||||
from app.config import settings
|
||||
|
||||
# 中国时区 UTC+8
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
def get_china_now():
|
||||
"""获取中国当前时间"""
|
||||
return datetime.now(CHINA_TZ)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
@@ -84,15 +92,31 @@ async def local_login(request: LocalLoginRequest, response: Response):
|
||||
except Exception as e:
|
||||
logger.error(f"本地用户 {user.user_id} 数据库初始化失败: {e}")
|
||||
|
||||
# 设置 Cookie(7天有效)
|
||||
# 设置 Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 设置过期时间戳 Cookie(用于前端判断)
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"✅ [登录] 用户 {user.user_id} 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
|
||||
|
||||
response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return LocalLoginResponse(
|
||||
success=True,
|
||||
message="登录成功",
|
||||
@@ -180,15 +204,31 @@ async def _handle_callback(
|
||||
logger.info(f"OAuth回调成功,重定向到前端: {redirect_url}")
|
||||
redirect_response = RedirectResponse(url=redirect_url)
|
||||
|
||||
# 设置 httponly Cookie(7天有效)
|
||||
# 设置 httponly Cookie(2小时有效)
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
redirect_response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=7 * 24 * 60 * 60, # 7天
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 设置过期时间戳 Cookie(用于前端判断)
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"✅ [OAuth登录] 用户 {user.user_id} 登录成功,会话有效期 {settings.SESSION_EXPIRE_MINUTES} 分钟")
|
||||
|
||||
redirect_response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False, # 前端需要读取
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
return redirect_response
|
||||
|
||||
|
||||
@@ -214,10 +254,80 @@ async def callback_alias(
|
||||
return await _handle_callback(code, state, error, response)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_session(request: Request, response: Response):
|
||||
"""刷新会话 - 延长登录状态"""
|
||||
# 检查是否已登录
|
||||
if not hasattr(request.state, "user") or not request.state.user:
|
||||
raise HTTPException(status_code=401, detail="未登录,无法刷新会话")
|
||||
|
||||
user = request.state.user
|
||||
|
||||
# 检查当前会话是否即将过期(剩余时间少于阈值)
|
||||
session_expire_at = request.cookies.get("session_expire_at")
|
||||
if session_expire_at:
|
||||
try:
|
||||
expire_timestamp = int(session_expire_at)
|
||||
current_timestamp = int(get_china_now().timestamp())
|
||||
remaining_minutes = (expire_timestamp - current_timestamp) / 60
|
||||
|
||||
# 如果剩余时间大于刷新阈值,不需要刷新
|
||||
if remaining_minutes > settings.SESSION_REFRESH_THRESHOLD_MINUTES:
|
||||
logger.info(f"⏱️ [刷新会话] 用户 {user.user_id} 会话仍有效,剩余 {int(remaining_minutes)} 分钟")
|
||||
return {
|
||||
"message": "会话仍然有效,无需刷新",
|
||||
"remaining_minutes": int(remaining_minutes),
|
||||
"expire_at": expire_timestamp
|
||||
}
|
||||
except (ValueError, TypeError):
|
||||
pass # Cookie 格式错误,继续刷新
|
||||
|
||||
# 刷新 Cookie
|
||||
max_age = settings.SESSION_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="user_id",
|
||||
value=user.user_id,
|
||||
max_age=max_age,
|
||||
httponly=True,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
# 更新过期时间戳
|
||||
china_now = get_china_now()
|
||||
expire_time = china_now + timedelta(minutes=settings.SESSION_EXPIRE_MINUTES)
|
||||
expire_at = int(expire_time.timestamp())
|
||||
|
||||
logger.info(f"[刷新会话] 用户: {user.user_id}")
|
||||
logger.info(f"[刷新会话] 中国当前时间: {china_now.strftime('%Y-%m-%d %H:%M:%S')} (UTC+8)")
|
||||
logger.info(f"[刷新会话] 中国过期时间: {expire_time.strftime('%Y-%m-%d %H:%M:%S')} (UTC+8)")
|
||||
logger.info(f"[刷新会话] 过期时间戳 (秒): {expire_at}")
|
||||
logger.info(f"[刷新会话] Cookie max_age (秒): {max_age}")
|
||||
|
||||
response.set_cookie(
|
||||
key="session_expire_at",
|
||||
value=str(expire_at),
|
||||
max_age=max_age,
|
||||
httponly=False,
|
||||
samesite="lax"
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user.user_id} 刷新会话成功")
|
||||
return {
|
||||
"message": "会话刷新成功",
|
||||
"expire_at": expire_at,
|
||||
"remaining_minutes": settings.SESSION_EXPIRE_MINUTES
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
async def logout(request: Request, response: Response):
|
||||
"""退出登录"""
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if user_id:
|
||||
logger.info(f"🚪 [退出] 用户 {user_id} 退出登录")
|
||||
|
||||
response.delete_cookie("user_id")
|
||||
response.delete_cookie("session_expire_at")
|
||||
return {"message": "退出登录成功"}
|
||||
|
||||
|
||||
|
||||
@@ -261,11 +261,13 @@ async def generate_chapter_content_stream(
|
||||
|
||||
请求体参数:
|
||||
- style_id: 可选,指定使用的写作风格ID。不提供则不使用任何风格
|
||||
- target_word_count: 可选,目标字数,默认3000字,范围500-10000字
|
||||
|
||||
注意:此函数不使用依赖注入的db,而是在生成器内部创建独立的数据库会话
|
||||
以避免流式响应期间的连接泄漏问题
|
||||
"""
|
||||
style_id = generate_request.style_id
|
||||
target_word_count = generate_request.target_word_count or 3000
|
||||
# 预先验证章节存在性(使用临时会话)
|
||||
async for temp_db in get_db(request):
|
||||
try:
|
||||
@@ -415,7 +417,8 @@ async def generate_chapter_content_stream(
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count
|
||||
)
|
||||
else:
|
||||
prompt = prompt_service.get_chapter_generation_prompt(
|
||||
@@ -432,7 +435,8 @@ async def generate_chapter_content_stream(
|
||||
chapter_number=current_chapter.chapter_number,
|
||||
chapter_title=current_chapter.title,
|
||||
chapter_outline=outline.content if outline else current_chapter.summary or '暂无大纲',
|
||||
style_content=style_content
|
||||
style_content=style_content,
|
||||
target_word_count=target_word_count
|
||||
)
|
||||
|
||||
logger.info(f"开始AI流式创作章节 {chapter_id}")
|
||||
|
||||
+166
-1
@@ -6,6 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Dict, Any, List
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
import httpx
|
||||
|
||||
from app.database import get_db
|
||||
@@ -296,4 +297,168 @@ async def get_available_models(
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取模型列表失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ApiTestRequest(BaseModel):
|
||||
"""API 测试请求模型"""
|
||||
api_key: str
|
||||
api_base_url: str
|
||||
provider: str
|
||||
model_name: str
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_api_connection(data: ApiTestRequest):
|
||||
"""
|
||||
测试 API 连接和配置是否正确
|
||||
|
||||
Args:
|
||||
data: 包含 API 配置的请求数据
|
||||
|
||||
Returns:
|
||||
测试结果包含状态、响应时间和详细信息
|
||||
"""
|
||||
api_key = data.api_key
|
||||
api_base_url = data.api_base_url
|
||||
provider = data.provider
|
||||
model_name = data.model_name
|
||||
import time
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 创建临时 AI 服务实例
|
||||
test_service = AIService(
|
||||
api_provider=provider,
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
default_model=model_name,
|
||||
default_temperature=0.7,
|
||||
default_max_tokens=100
|
||||
)
|
||||
|
||||
# 发送简单的测试请求
|
||||
test_prompt = "请用一句话回复:测试成功"
|
||||
|
||||
logger.info(f"🧪 开始测试 API 连接")
|
||||
logger.info(f" - 提供商: {provider}")
|
||||
logger.info(f" - 模型: {model_name}")
|
||||
logger.info(f" - Base URL: {api_base_url}")
|
||||
|
||||
response = await test_service.generate_text(
|
||||
prompt=test_prompt,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
temperature=0.7,
|
||||
max_tokens=8000
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2) # 转换为毫秒
|
||||
|
||||
logger.info(f"✅ API 测试成功")
|
||||
logger.info(f" - 响应时间: {response_time}ms")
|
||||
logger.info(f" - 响应内容: {response[:100] if response else 'N/A'}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "API 连接测试成功",
|
||||
"response_time_ms": response_time,
|
||||
"provider": provider,
|
||||
"model": model_name,
|
||||
"response_preview": response[:100] if response and len(response) > 100 else response,
|
||||
"details": {
|
||||
"api_available": True,
|
||||
"model_accessible": True,
|
||||
"response_valid": bool(response)
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
# 配置错误
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ API 配置错误: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "API 配置错误",
|
||||
"error": error_msg,
|
||||
"error_type": "ConfigurationError",
|
||||
"suggestions": [
|
||||
"请检查 API Key 是否正确",
|
||||
"请确认 API Base URL 格式正确",
|
||||
"请验证所选提供商是否匹配"
|
||||
]
|
||||
}
|
||||
|
||||
except TimeoutError as e:
|
||||
# 超时错误
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ API 请求超时: {error_msg}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "API 请求超时",
|
||||
"error": error_msg,
|
||||
"error_type": "TimeoutError",
|
||||
"suggestions": [
|
||||
"请检查网络连接",
|
||||
"请确认 API Base URL 是否可访问",
|
||||
"如果使用代理,请检查代理设置"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
logger.error(f"❌ API 测试失败: {error_msg}")
|
||||
logger.error(f" - 错误类型: {error_type}")
|
||||
|
||||
# 分析错误原因并提供建议
|
||||
suggestions = []
|
||||
if "blocked" in error_msg.lower():
|
||||
suggestions = [
|
||||
"请求被 API 提供商阻止",
|
||||
"可能原因:API Key 被限制或地区限制",
|
||||
"建议:检查 API Key 状态和账户余额",
|
||||
"建议:尝试更换 API Base URL 或使用代理"
|
||||
]
|
||||
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
suggestions = [
|
||||
"API Key 认证失败",
|
||||
"建议:检查 API Key 是否正确",
|
||||
"建议:确认 API Key 是否过期"
|
||||
]
|
||||
elif "not found" in error_msg.lower() or "404" in error_msg:
|
||||
suggestions = [
|
||||
"API 端点不存在或模型不可用",
|
||||
"建议:检查 API Base URL 是否正确",
|
||||
"建议:确认模型名称是否正确"
|
||||
]
|
||||
elif "rate limit" in error_msg.lower() or "429" in error_msg:
|
||||
suggestions = [
|
||||
"API 请求频率超限",
|
||||
"建议:稍后重试",
|
||||
"建议:升级 API 套餐"
|
||||
]
|
||||
elif "insufficient" in error_msg.lower() or "quota" in error_msg.lower():
|
||||
suggestions = [
|
||||
"API 配额不足",
|
||||
"建议:检查账户余额",
|
||||
"建议:充值或升级套餐"
|
||||
]
|
||||
else:
|
||||
suggestions = [
|
||||
"请检查所有配置参数是否正确",
|
||||
"请确认网络连接正常",
|
||||
"请查看详细错误信息"
|
||||
]
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": "API 测试失败",
|
||||
"error": error_msg,
|
||||
"error_type": error_type,
|
||||
"suggestions": suggestions
|
||||
}
|
||||
@@ -260,6 +260,7 @@ async def characters_generator(
|
||||
# 重试逻辑
|
||||
retry_count = 0
|
||||
batch_success = False
|
||||
batch_error_message = ""
|
||||
|
||||
while retry_count < MAX_RETRIES and not batch_success:
|
||||
try:
|
||||
@@ -326,37 +327,24 @@ async def characters_generator(
|
||||
if not isinstance(characters_data, list):
|
||||
characters_data = [characters_data]
|
||||
|
||||
# 验证生成数量是否精确
|
||||
# 严格验证生成数量是否精确匹配
|
||||
if len(characters_data) != current_batch_size:
|
||||
logger.warning(f"批次{batch_idx+1}生成数量不匹配: 期望{current_batch_size}, 实际{len(characters_data)}")
|
||||
error_msg = f"批次{batch_idx+1}生成数量不正确: 期望{current_batch_size}个, 实际{len(characters_data)}个"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 如果数量不足,重试
|
||||
if len(characters_data) < current_batch_size:
|
||||
if retry_count < MAX_RETRIES - 1:
|
||||
retry_count += 1
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 生成数量不足(期望{current_batch_size},实际{len(characters_data)}),准备重试...",
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍不足,记录但继续使用
|
||||
logger.warning(f"批次{batch_idx+1}多次重试后仍数量不足,使用当前结果")
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ 批次{batch_idx+1}生成{len(characters_data)}个(期望{current_batch_size}),继续处理",
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
# 如果数量过多,只取需要的数量并发出警告
|
||||
else:
|
||||
logger.warning(f"批次{batch_idx+1}生成过多角色({len(characters_data)}>{current_batch_size}),将只取前{current_batch_size}个")
|
||||
# 如果还有重试机会,继续重试
|
||||
if retry_count < MAX_RETRIES - 1:
|
||||
retry_count += 1
|
||||
yield await SSEResponse.send_progress(
|
||||
f"⚠️ AI生成过多,截取前{current_batch_size}个角色",
|
||||
f"⚠️ {error_msg},准备重试...",
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
characters_data = characters_data[:current_batch_size]
|
||||
continue
|
||||
else:
|
||||
# 最后一次重试仍失败,直接返回错误
|
||||
yield await SSEResponse.send_error(error_msg)
|
||||
return
|
||||
|
||||
all_characters.extend(characters_data)
|
||||
batch_success = True
|
||||
@@ -364,6 +352,7 @@ async def characters_generator(
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"批次{batch_idx+1}解析失败(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
||||
batch_error_message = f"JSON解析失败: {str(e)}"
|
||||
retry_count += 1
|
||||
if retry_count < MAX_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
@@ -371,14 +360,9 @@ async def characters_generator(
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
else:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"批次{batch_idx+1}多次重试失败,跳过",
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"批次{batch_idx+1}生成异常(尝试{retry_count+1}/{MAX_RETRIES}): {e}")
|
||||
batch_error_message = f"生成异常: {str(e)}"
|
||||
retry_count += 1
|
||||
if retry_count < MAX_RETRIES:
|
||||
yield await SSEResponse.send_progress(
|
||||
@@ -386,16 +370,15 @@ async def characters_generator(
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
else:
|
||||
yield await SSEResponse.send_progress(
|
||||
f"批次{batch_idx+1}多次重试失败,跳过",
|
||||
batch_progress,
|
||||
"warning"
|
||||
)
|
||||
|
||||
if not all_characters:
|
||||
yield await SSEResponse.send_error("所有批次都生成失败,请重试")
|
||||
return
|
||||
|
||||
# 检查批次是否成功
|
||||
if not batch_success:
|
||||
error_msg = f"批次{batch_idx+1}在{MAX_RETRIES}次重试后仍然失败"
|
||||
if batch_error_message:
|
||||
error_msg += f": {batch_error_message}"
|
||||
logger.error(error_msg)
|
||||
yield await SSEResponse.send_error(error_msg)
|
||||
return
|
||||
|
||||
# 保存到数据库 - 分阶段处理以保证一致性
|
||||
yield await SSEResponse.send_progress("验证角色数据...", 82)
|
||||
@@ -665,6 +648,10 @@ async def characters_generator(
|
||||
logger.info(f" - 创建角色关系:{relationships_created} 条")
|
||||
logger.info(f" - 创建组织成员:{members_created} 条")
|
||||
|
||||
# 更新项目的角色数量
|
||||
project.character_count = len(created_characters)
|
||||
logger.info(f"✅ 更新项目角色数量: {project.character_count}")
|
||||
|
||||
await db.commit()
|
||||
db_committed = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user