chore: layout optimize
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from openai import OpenAI
|
||||
|
||||
from app.schemas.embedding_model import (
|
||||
EmbeddingModelConfig,
|
||||
EmbeddingModelConfigCreate,
|
||||
EmbeddingModelConfigUpdate,
|
||||
EmbeddingModelConnectionTestRequest
|
||||
)
|
||||
from app.services.embedding_model_store import embedding_model_store
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
from app.api.llm import get_admin_user, get_current_user, CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def _mask_api_key(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
if len(value) <= 8:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
|
||||
|
||||
@router.get("/embedding-models", response_model=List[EmbeddingModelConfig])
|
||||
def list_embedding_models(current_user: CurrentUser = Depends(get_current_user)):
|
||||
models = embedding_model_store.list_models()
|
||||
for m in models:
|
||||
if not current_user.is_admin:
|
||||
m["api_key"] = None
|
||||
return models
|
||||
|
||||
@router.post("/embedding-models", response_model=EmbeddingModelConfig)
|
||||
def create_embedding_model(payload: EmbeddingModelConfigCreate, _: CurrentUser = Depends(get_admin_user)):
|
||||
return embedding_model_store.create_model(payload.model_dump())
|
||||
|
||||
@router.get("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
|
||||
def get_embedding_model(model_id: str, current_user: CurrentUser = Depends(get_current_user)):
|
||||
model = embedding_model_store.get_model(model_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
if not current_user.is_admin:
|
||||
model["api_key"] = None
|
||||
return model
|
||||
|
||||
@router.put("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
|
||||
def update_embedding_model(model_id: str, payload: EmbeddingModelConfigUpdate, _: CurrentUser = Depends(get_admin_user)):
|
||||
model = embedding_model_store.update_model(model_id, payload.model_dump(exclude_unset=True))
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
return model
|
||||
|
||||
@router.delete("/embedding-models/{model_id}")
|
||||
def delete_embedding_model(model_id: str, _: CurrentUser = Depends(get_admin_user)):
|
||||
deleted = embedding_model_store.delete_model(model_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.post("/embedding-models/test")
|
||||
def test_embedding_model_connection(payload: EmbeddingModelConnectionTestRequest, _: CurrentUser = Depends(get_admin_user)):
|
||||
api_base = normalize_openai_base_url(payload.api_base or "")
|
||||
api_key = payload.api_key
|
||||
model_name = (payload.model or "").strip()
|
||||
|
||||
if not api_base:
|
||||
raise HTTPException(status_code=400, detail="API Base is required")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="API Key is required")
|
||||
if not model_name:
|
||||
raise HTTPException(status_code=400, detail="Model name is required")
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
)
|
||||
embedding_resp = client.embeddings.create(
|
||||
model=model_name,
|
||||
input="connection test",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Embedding call failed: {exc}")
|
||||
|
||||
dimension = None
|
||||
if getattr(embedding_resp, "data", None):
|
||||
first = embedding_resp.data[0]
|
||||
vector = getattr(first, "embedding", None)
|
||||
if isinstance(vector, list):
|
||||
dimension = len(vector)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Connection successful",
|
||||
"model_name": model_name,
|
||||
"embedding_dimension": dimension,
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class EmbeddingModelConfigBase(BaseModel):
|
||||
name: str = Field(..., description="Display name for the model configuration")
|
||||
provider: str = Field("openai", description="Provider type (e.g. openai)")
|
||||
model: str = Field(..., description="Model name (e.g. text-embedding-3-small)")
|
||||
api_base: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class EmbeddingModelConfigCreate(EmbeddingModelConfigBase):
|
||||
pass
|
||||
|
||||
class EmbeddingModelConfigUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class EmbeddingModelConfig(EmbeddingModelConfigBase):
|
||||
id: str
|
||||
|
||||
class EmbeddingModelConnectionTestRequest(BaseModel):
|
||||
provider: str = Field("openai")
|
||||
model: str = Field(...)
|
||||
api_base: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
class EmbeddingModelStore:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def _file_path() -> Path:
|
||||
return get_data_root() / "embedding_models.json"
|
||||
|
||||
def _read(self) -> List[Dict[str, Any]]:
|
||||
file_path = self._file_path()
|
||||
if not file_path.exists():
|
||||
return []
|
||||
try:
|
||||
with file_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return []
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
return data
|
||||
|
||||
def _write(self, data: List[Dict[str, Any]]) -> None:
|
||||
file_path = self._file_path()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def list_models(self) -> List[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
return self._read()
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for item in data:
|
||||
if item.get("id") == model_id:
|
||||
return item
|
||||
return None
|
||||
|
||||
def create_model(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
new_model = payload.copy()
|
||||
new_model["id"] = uuid.uuid4().hex
|
||||
data.append(new_model)
|
||||
self._write(data)
|
||||
return new_model
|
||||
|
||||
def update_model(self, model_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
for item in data:
|
||||
if item.get("id") == model_id:
|
||||
item.update(payload)
|
||||
self._write(data)
|
||||
return item
|
||||
return None
|
||||
|
||||
def delete_model(self, model_id: str) -> bool:
|
||||
with self._lock:
|
||||
data = self._read()
|
||||
initial_len = len(data)
|
||||
data = [item for item in data if item.get("id") != model_id]
|
||||
if len(data) < initial_len:
|
||||
self._write(data)
|
||||
return True
|
||||
return False
|
||||
|
||||
embedding_model_store = EmbeddingModelStore()
|
||||
@@ -124,10 +124,27 @@ class KnowledgeIndexService:
|
||||
|
||||
@staticmethod
|
||||
def _build_embed_model(kb: Dict[str, Any]) -> Any:
|
||||
global_config = knowledge_global_config_store.get()
|
||||
api_base = global_config.get("api_base")
|
||||
api_key = global_config.get("api_key")
|
||||
model_name = kb.get("embedding_model") or global_config.get("default_embedding_model")
|
||||
from app.services.embedding_model_store import embedding_model_store
|
||||
models = embedding_model_store.list_models()
|
||||
if not models:
|
||||
return None
|
||||
|
||||
target_model = None
|
||||
kb_model_val = kb.get("embedding_model")
|
||||
if kb_model_val:
|
||||
# Try matching by ID first, then by model name
|
||||
target_model = next((m for m in models if m.get("id") == kb_model_val), None)
|
||||
if not target_model:
|
||||
target_model = next((m for m in models if m.get("model") == kb_model_val), None)
|
||||
|
||||
if not target_model:
|
||||
# Fallback to the first model
|
||||
target_model = models[0]
|
||||
|
||||
api_base = target_model.get("api_base")
|
||||
api_key = target_model.get("api_key")
|
||||
model_name = target_model.get("model")
|
||||
|
||||
if not api_base or not api_key or not model_name:
|
||||
return None
|
||||
api_base = _normalize_embedding_api_base(api_base)
|
||||
|
||||
+2
-1
@@ -16,7 +16,7 @@ import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge, embedding_models
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.artifacts import extract_artifacts
|
||||
@@ -71,6 +71,7 @@ app.include_router(semantic.router, prefix="/api/v1")
|
||||
app.include_router(mcp.router, prefix="/api/v1")
|
||||
app.include_router(subagents.router, prefix="/api/v1")
|
||||
app.include_router(knowledge.router, prefix="/api/v1")
|
||||
app.include_router(embedding_models.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
PREVIEWABLE_TEXT_EXTENSIONS = {
|
||||
|
||||
Reference in New Issue
Block a user