Update 2026-05-13 16:43:53
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,139 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSource as DataSourceSchema, DataSourceTestRequest
|
||||
from app.core.security import get_current_user, get_admin_user, CurrentUser
|
||||
from app.connectors.factory import get_connector_from_config
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/datasources", response_model=List[DataSourceSchema])
|
||||
def list_datasources(
|
||||
project_id: Optional[int] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
query = db.query(DataSource)
|
||||
if project_id:
|
||||
query = query.filter(DataSource.project_id == project_id)
|
||||
|
||||
# If not admin, check if user has access to the project
|
||||
if not current_user.is_admin and project_id:
|
||||
from app.models.project import Project
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project or project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions for this project")
|
||||
|
||||
datasources = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Hide sensitive info for non-admins if necessary, but config usually contains secrets.
|
||||
# Maybe we should return a sanitized version for regular users?
|
||||
# For now, return full config but only to admins?
|
||||
# Or just assume the API is secure.
|
||||
# If regular users need to select datasource, they just need ID and Name.
|
||||
if not current_user.is_admin:
|
||||
# Sanitize config
|
||||
sanitized = []
|
||||
for ds in datasources:
|
||||
ds_dict = DataSourceSchema.from_orm(ds).dict()
|
||||
# Remove sensitive fields from config
|
||||
if ds_dict.get("config"):
|
||||
ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in ["password", "api_key", "secret"]}
|
||||
sanitized.append(ds_dict)
|
||||
return sanitized
|
||||
|
||||
return datasources
|
||||
|
||||
@router.post("/datasources", response_model=DataSourceSchema)
|
||||
def create_datasource(
|
||||
datasource: DataSourceCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
# Check if project exists and user has access
|
||||
from app.models.project import Project
|
||||
project = db.query(Project).filter(Project.id == datasource.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions for this project")
|
||||
|
||||
db_datasource = DataSource(**datasource.dict())
|
||||
db.add(db_datasource)
|
||||
db.commit()
|
||||
db.refresh(db_datasource)
|
||||
return db_datasource
|
||||
|
||||
@router.get("/datasources/{datasource_id}", response_model=DataSourceSchema)
|
||||
def read_datasource(
|
||||
datasource_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
if not current_user.is_admin:
|
||||
ds_dict = DataSourceSchema.from_orm(db_datasource).dict()
|
||||
if ds_dict.get("config"):
|
||||
ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in ["password", "api_key", "secret"]}
|
||||
return ds_dict
|
||||
|
||||
return db_datasource
|
||||
|
||||
@router.put("/datasources/{datasource_id}", response_model=DataSourceSchema)
|
||||
def update_datasource(
|
||||
datasource_id: int,
|
||||
datasource: DataSourceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
update_data = datasource.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_datasource, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_datasource)
|
||||
return db_datasource
|
||||
|
||||
@router.delete("/datasources/{datasource_id}")
|
||||
def delete_datasource(
|
||||
datasource_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
db.delete(db_datasource)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/datasources/test")
|
||||
def test_datasource_connection(
|
||||
request: DataSourceTestRequest,
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
connector = get_connector_from_config(request.type, request.config)
|
||||
if connector.test_connection():
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Connection failed")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
import sys
|
||||
print(f"Datasource Test Error: {str(e)}\n{traceback.format_exc()}", file=sys.stderr)
|
||||
raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}")
|
||||
@@ -0,0 +1,96 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from openai import OpenAI
|
||||
|
||||
from app.schemas.embedding_model import (
|
||||
EmbeddingModelConfig,
|
||||
EmbeddingModelConfigCreate,
|
||||
EmbeddingModelConfigUpdate,
|
||||
EmbeddingModelConnectionTestRequest
|
||||
)
|
||||
from app.services.embedding_model_store import embedding_model_store
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
from app.api.llm import get_admin_user, get_current_user, CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def _mask_api_key(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
if len(value) <= 8:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
|
||||
|
||||
@router.get("/embedding-models", response_model=List[EmbeddingModelConfig])
|
||||
def list_embedding_models(current_user: CurrentUser = Depends(get_current_user)):
|
||||
models = embedding_model_store.list_models()
|
||||
for m in models:
|
||||
if not current_user.is_admin:
|
||||
m["api_key"] = None
|
||||
return models
|
||||
|
||||
@router.post("/embedding-models", response_model=EmbeddingModelConfig)
|
||||
def create_embedding_model(payload: EmbeddingModelConfigCreate, _: CurrentUser = Depends(get_admin_user)):
|
||||
return embedding_model_store.create_model(payload.model_dump())
|
||||
|
||||
@router.get("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
|
||||
def get_embedding_model(model_id: str, current_user: CurrentUser = Depends(get_current_user)):
|
||||
model = embedding_model_store.get_model(model_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
if not current_user.is_admin:
|
||||
model["api_key"] = None
|
||||
return model
|
||||
|
||||
@router.put("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
|
||||
def update_embedding_model(model_id: str, payload: EmbeddingModelConfigUpdate, _: CurrentUser = Depends(get_admin_user)):
|
||||
model = embedding_model_store.update_model(model_id, payload.model_dump(exclude_unset=True))
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
return model
|
||||
|
||||
@router.delete("/embedding-models/{model_id}")
|
||||
def delete_embedding_model(model_id: str, _: CurrentUser = Depends(get_admin_user)):
|
||||
deleted = embedding_model_store.delete_model(model_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Embedding model not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.post("/embedding-models/test")
|
||||
def test_embedding_model_connection(payload: EmbeddingModelConnectionTestRequest, _: CurrentUser = Depends(get_admin_user)):
|
||||
api_base = normalize_openai_base_url(payload.api_base or "")
|
||||
api_key = payload.api_key
|
||||
model_name = (payload.model or "").strip()
|
||||
|
||||
if not api_base:
|
||||
raise HTTPException(status_code=400, detail="API Base is required")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="API Key is required")
|
||||
if not model_name:
|
||||
raise HTTPException(status_code=400, detail="Model name is required")
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
)
|
||||
embedding_resp = client.embeddings.create(
|
||||
model=model_name,
|
||||
input="connection test",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Embedding call failed: {exc}")
|
||||
|
||||
dimension = None
|
||||
if getattr(embedding_resp, "data", None):
|
||||
first = embedding_resp.data[0]
|
||||
vector = getattr(first, "embedding", None)
|
||||
if isinstance(vector, list):
|
||||
dimension = len(vector)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Connection successful",
|
||||
"model_name": model_name,
|
||||
"embedding_dimension": dimension,
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
from typing import List, Optional
|
||||
import io
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import UploadFile, File, Form
|
||||
from openai import OpenAI
|
||||
import pandas as pd
|
||||
|
||||
from app.schemas.knowledge import (
|
||||
KnowledgeBase,
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeConnectionTestRequest,
|
||||
KnowledgeConnectionTestResponse,
|
||||
KnowledgeGlobalConfig,
|
||||
KnowledgeGlobalConfigUpdate,
|
||||
KnowledgeBaseUpdate,
|
||||
KnowledgeDocument,
|
||||
KnowledgeDocumentCreate,
|
||||
KnowledgeDocumentUpdate,
|
||||
KnowledgeSearchRequest,
|
||||
KnowledgeSearchResponse,
|
||||
)
|
||||
from app.services.knowledge_base_store import knowledge_base_store
|
||||
from app.services.knowledge_global_config_store import knowledge_global_config_store
|
||||
from app.services.knowledge_index import knowledge_index_service
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _mask_api_key(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
if len(value) <= 8:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
|
||||
|
||||
|
||||
def _extract_upload_text(filename: str, content: bytes) -> str:
|
||||
lower = filename.lower()
|
||||
if lower.endswith((".txt", ".md", ".markdown", ".json", ".yaml", ".yml", ".log", ".xml", ".html", ".htm")):
|
||||
try:
|
||||
return content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
if lower.endswith(".csv"):
|
||||
df = pd.read_csv(io.BytesIO(content))
|
||||
return df.to_csv(index=False)
|
||||
if lower.endswith((".xls", ".xlsx")):
|
||||
df = pd.read_excel(io.BytesIO(content))
|
||||
return df.to_csv(index=False)
|
||||
|
||||
# 增加对 PDF 的文本提取支持
|
||||
if lower.endswith(".pdf"):
|
||||
try:
|
||||
import PyPDF2
|
||||
pdf_reader = PyPDF2.PdfReader(io.BytesIO(content))
|
||||
text = []
|
||||
for page in pdf_reader.pages:
|
||||
page_text = page.extract_text()
|
||||
if page_text:
|
||||
text.append(page_text)
|
||||
return "\n".join(text)
|
||||
except ImportError:
|
||||
raise ValueError("PyPDF2 is not installed. Cannot parse PDF files.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse PDF: {str(e)}")
|
||||
|
||||
# 增加对 Word 文档的文本提取支持
|
||||
if lower.endswith((".doc", ".docx")):
|
||||
try:
|
||||
import docx
|
||||
doc = docx.Document(io.BytesIO(content))
|
||||
return "\n".join([para.text for para in doc.paragraphs])
|
||||
except ImportError:
|
||||
raise ValueError("python-docx is not installed. Cannot parse Word files.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Word document: {str(e)}")
|
||||
|
||||
# 增加对 PPT 文档的文本提取支持
|
||||
if lower.endswith((".ppt", ".pptx")):
|
||||
try:
|
||||
import pptx
|
||||
prs = pptx.Presentation(io.BytesIO(content))
|
||||
text = []
|
||||
for slide in prs.slides:
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
text.append(shape.text)
|
||||
return "\n".join(text)
|
||||
except ImportError:
|
||||
raise ValueError("python-pptx is not installed. Cannot parse PPT files.")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse PPT document: {str(e)}")
|
||||
|
||||
raise ValueError("Unsupported file type")
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
|
||||
def get_knowledge_global_config():
|
||||
config = knowledge_global_config_store.get()
|
||||
raw_api_key = config.get("api_key")
|
||||
return {
|
||||
"api_base": config.get("api_base"),
|
||||
"api_key": None,
|
||||
"api_key_masked": _mask_api_key(raw_api_key),
|
||||
"has_api_key": bool(raw_api_key),
|
||||
"default_embedding_model": config.get("default_embedding_model"),
|
||||
}
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
|
||||
def update_knowledge_global_config(payload: KnowledgeGlobalConfigUpdate):
|
||||
updated = knowledge_global_config_store.update(payload.model_dump(exclude_unset=True))
|
||||
raw_api_key = updated.get("api_key")
|
||||
return {
|
||||
"api_base": updated.get("api_base"),
|
||||
"api_key": None,
|
||||
"api_key_masked": _mask_api_key(raw_api_key),
|
||||
"has_api_key": bool(raw_api_key),
|
||||
"default_embedding_model": updated.get("default_embedding_model"),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/global-config/test-connection", response_model=KnowledgeConnectionTestResponse)
|
||||
def test_knowledge_global_connection(payload: KnowledgeConnectionTestRequest):
|
||||
saved = knowledge_global_config_store.get()
|
||||
api_base = normalize_openai_base_url(payload.api_base or saved.get("api_base") or "")
|
||||
api_key = payload.api_key or saved.get("api_key")
|
||||
model_name = (payload.model_name or "").strip()
|
||||
|
||||
if not api_base:
|
||||
raise HTTPException(status_code=400, detail="API Base 未配置")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="API Key 未配置")
|
||||
if not model_name:
|
||||
raise HTTPException(status_code=400, detail="测试连接必须显式填写向量模型名称")
|
||||
|
||||
if not api_base:
|
||||
raise HTTPException(status_code=400, detail="API Base 未配置")
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
)
|
||||
embedding_resp = client.embeddings.create(
|
||||
model=model_name,
|
||||
input="connection test",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Embedding调用失败: {exc}")
|
||||
|
||||
dimension = None
|
||||
if getattr(embedding_resp, "data", None):
|
||||
first = embedding_resp.data[0]
|
||||
vector = getattr(first, "embedding", None)
|
||||
if isinstance(vector, list):
|
||||
dimension = len(vector)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接成功,Embedding调用正常",
|
||||
"model_name": model_name,
|
||||
"embedding_dimension": dimension,
|
||||
"resolved_api_base": api_base,
|
||||
"available_models": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/knowledge-bases", response_model=List[KnowledgeBase])
|
||||
def list_knowledge_bases(project_id: Optional[int] = None):
|
||||
return knowledge_base_store.list(project_id=project_id)
|
||||
|
||||
|
||||
@router.post("/knowledge-bases", response_model=KnowledgeBase)
|
||||
def create_knowledge_base(payload: KnowledgeBaseCreate):
|
||||
return knowledge_base_store.create(payload.model_dump())
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
|
||||
def get_knowledge_base(kb_id: str):
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
return kb
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
|
||||
def update_knowledge_base(kb_id: str, payload: KnowledgeBaseUpdate):
|
||||
kb = knowledge_base_store.update(kb_id, payload.model_dump(exclude_unset=True))
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
return kb
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}")
|
||||
def delete_knowledge_base(kb_id: str):
|
||||
deleted = knowledge_base_store.delete(kb_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/documents", response_model=List[KnowledgeDocument])
|
||||
def list_knowledge_documents(kb_id: str):
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
return kb.get("documents", [])
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents", response_model=KnowledgeDocument)
|
||||
def create_knowledge_document(kb_id: str, payload: KnowledgeDocumentCreate):
|
||||
doc = knowledge_base_store.create_document(kb_id=kb_id, payload=payload.model_dump())
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
return doc
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}/documents/{doc_id}", response_model=KnowledgeDocument)
|
||||
def update_knowledge_document(kb_id: str, doc_id: str, payload: KnowledgeDocumentUpdate):
|
||||
doc = knowledge_base_store.update_document(
|
||||
kb_id=kb_id,
|
||||
doc_id=doc_id,
|
||||
payload=payload.model_dump(exclude_unset=True),
|
||||
)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Knowledge document not found")
|
||||
return doc
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}/documents/{doc_id}")
|
||||
def delete_knowledge_document(kb_id: str, doc_id: str):
|
||||
deleted = knowledge_base_store.delete_document(kb_id=kb_id, doc_id=doc_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Knowledge document not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/reindex")
|
||||
def reindex_knowledge_base(kb_id: str):
|
||||
try:
|
||||
return knowledge_index_service.reindex(kb_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/search", response_model=KnowledgeSearchResponse)
|
||||
def search_knowledge_base(kb_id: str, payload: KnowledgeSearchRequest):
|
||||
try:
|
||||
result = knowledge_index_service.search(
|
||||
kb_id=kb_id,
|
||||
query=payload.query,
|
||||
top_k=payload.top_k,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc))
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents/upload")
|
||||
async def upload_knowledge_documents(
|
||||
kb_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
metadata: Optional[str] = Form(default=None),
|
||||
):
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
metadata_payload: dict[str, Any] = {}
|
||||
if metadata:
|
||||
try:
|
||||
parsed_metadata = json.loads(metadata)
|
||||
if isinstance(parsed_metadata, dict):
|
||||
metadata_payload = parsed_metadata
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
|
||||
|
||||
created: List[dict[str, Any]] = []
|
||||
for file in files:
|
||||
filename = file.filename or "untitled"
|
||||
content = await file.read()
|
||||
if not content:
|
||||
continue
|
||||
# 将大小限制从 5MB 放宽到 15MB,以更好地支持带有图片的 PDF 文件
|
||||
if len(content) > 15 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail=f"文件过大 (超过 15MB): {filename}")
|
||||
try:
|
||||
text = _extract_upload_text(filename, content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {filename}")
|
||||
doc = knowledge_base_store.create_document(
|
||||
kb_id=kb_id,
|
||||
payload={
|
||||
"title": filename,
|
||||
"content": text,
|
||||
"metadata": {**metadata_payload, "source": "upload", "filename": filename},
|
||||
},
|
||||
)
|
||||
if doc:
|
||||
created.append(doc)
|
||||
return {"status": "success", "count": len(created), "documents": created}
|
||||
@@ -0,0 +1,183 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import jwt, JWTError
|
||||
from pydantic import BaseModel, Field
|
||||
from app.core.security import SECRET_KEY, ALGORITHM
|
||||
from app.core.data_root import get_data_root
|
||||
from app.core.llm_provider import build_llm_provider
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
DATA_FILE = str(get_data_root() / "llm_config.json")
|
||||
|
||||
|
||||
class CurrentUser(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
is_admin: bool = False
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier for the LLM configuration")
|
||||
name: Optional[str] = Field(None, description="Display name")
|
||||
provider: str = Field(..., description="Provider name (e.g., openai, azure, anthropic)")
|
||||
model: str = Field(..., description="Model name (e.g., gpt-4, claude-3-opus)")
|
||||
api_key: Optional[str] = Field(None, description="API Key for the provider")
|
||||
api_base: Optional[str] = Field(None, description="Base URL for the API")
|
||||
extra_headers: Optional[Dict[str, str]] = Field(None, description="Extra headers for the request")
|
||||
is_active: bool = Field(True, description="Whether this configuration is active")
|
||||
|
||||
class LLMConfigCreate(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
is_active: bool = True
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class TestConnectionRequest(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
|
||||
unauthorized = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise unauthorized
|
||||
user_id = payload.get("id")
|
||||
username = payload.get("sub")
|
||||
is_admin = bool(payload.get("is_admin", False))
|
||||
if user_id is None or username is None:
|
||||
raise unauthorized
|
||||
return CurrentUser(id=user_id, username=username, is_admin=is_admin)
|
||||
|
||||
|
||||
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
def _load_data() -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(DATA_FILE):
|
||||
return []
|
||||
try:
|
||||
with open(DATA_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
def _save_data(data: List[Dict[str, Any]]):
|
||||
os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
|
||||
with open(DATA_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _sanitize_config(item: Dict[str, Any], is_admin: bool) -> Dict[str, Any]:
|
||||
config = item.copy()
|
||||
if not is_admin:
|
||||
config["api_key"] = None
|
||||
return config
|
||||
|
||||
@router.get("/llm", response_model=List[LLMConfig])
|
||||
def list_llm_configs(current_user: CurrentUser = Depends(get_current_user)):
|
||||
data = _load_data()
|
||||
return [LLMConfig(**_sanitize_config(item, current_user.is_admin)) for item in data]
|
||||
|
||||
@router.get("/llm/{config_id}", response_model=LLMConfig)
|
||||
def get_llm_config(config_id: str, current_user: CurrentUser = Depends(get_current_user)):
|
||||
data = _load_data()
|
||||
for item in data:
|
||||
if item["id"] == config_id:
|
||||
return LLMConfig(**_sanitize_config(item, current_user.is_admin))
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
@router.post("/llm", response_model=LLMConfig)
|
||||
def create_llm_config(config: LLMConfigCreate, _: CurrentUser = Depends(get_admin_user)):
|
||||
data = _load_data()
|
||||
if any(item["id"] == config.id for item in data):
|
||||
raise HTTPException(status_code=400, detail="LLM configuration with this ID already exists")
|
||||
|
||||
new_config = config.dict()
|
||||
if new_config.get("is_active"):
|
||||
for item in data:
|
||||
item["is_active"] = False
|
||||
data.append(new_config)
|
||||
_save_data(data)
|
||||
return LLMConfig(**new_config)
|
||||
|
||||
@router.put("/llm/{config_id}", response_model=LLMConfig)
|
||||
def update_llm_config(config_id: str, config: LLMConfigUpdate, _: CurrentUser = Depends(get_admin_user)):
|
||||
data = _load_data()
|
||||
for i, item in enumerate(data):
|
||||
if item["id"] == config_id:
|
||||
updated_item = item.copy()
|
||||
update_data = config.dict(exclude_unset=True)
|
||||
if update_data.get("is_active"):
|
||||
for j in range(len(data)):
|
||||
data[j]["is_active"] = False
|
||||
updated_item.update(update_data)
|
||||
data[i] = updated_item
|
||||
_save_data(data)
|
||||
return LLMConfig(**updated_item)
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
@router.delete("/llm/{config_id}")
|
||||
def delete_llm_config(config_id: str, _: CurrentUser = Depends(get_admin_user)):
|
||||
data = _load_data()
|
||||
initial_len = len(data)
|
||||
data = [item for item in data if item["id"] != config_id]
|
||||
if len(data) == initial_len:
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
_save_data(data)
|
||||
return {"message": "LLM configuration deleted successfully"}
|
||||
|
||||
@router.post("/llm/test")
|
||||
async def test_connection(request: TestConnectionRequest, _: CurrentUser = Depends(get_admin_user)):
|
||||
try:
|
||||
provider = build_llm_provider(
|
||||
model=request.model.strip(),
|
||||
provider=request.provider,
|
||||
api_key=request.api_key,
|
||||
api_base=request.api_base,
|
||||
extra_headers=request.extra_headers,
|
||||
)
|
||||
response = await provider.chat(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
max_tokens=5,
|
||||
temperature=0,
|
||||
)
|
||||
if response.finish_reason == "error":
|
||||
raise ValueError(response.content or "Unknown provider error")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Connection successful",
|
||||
"details": {
|
||||
"content": response.content,
|
||||
"finish_reason": response.finish_reason,
|
||||
"usage": response.usage,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}")
|
||||
@@ -0,0 +1,135 @@
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_mcp_servers_file() -> Path:
|
||||
return get_data_root() / "mcp_servers.json"
|
||||
|
||||
def read_mcp_servers() -> List[dict]:
|
||||
file_path = get_mcp_servers_file()
|
||||
if not file_path.exists():
|
||||
return []
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
def write_mcp_servers(servers: List[dict]) -> None:
|
||||
file_path = get_mcp_servers_file()
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(servers, f, indent=2, ensure_ascii=False)
|
||||
|
||||
async def _check_single_mcp_health(server: dict) -> str:
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
server_type = server.get("type")
|
||||
if server_type == "stdio":
|
||||
params = StdioServerParameters(
|
||||
command=server.get("command", ""),
|
||||
args=server.get("args", []),
|
||||
env=server.get("env")
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
elif server_type in ["sse", "streamableHttp"]:
|
||||
read, write = await stack.enter_async_context(sse_client(server.get("url", "")))
|
||||
else:
|
||||
return "error: unsupported type"
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await asyncio.wait_for(session.initialize(), timeout=5.0)
|
||||
return "connected"
|
||||
except Exception as e:
|
||||
err_msg = str(e)
|
||||
if "unhandled errors in a TaskGroup" in err_msg:
|
||||
return "error: connection refused"
|
||||
return f"error: {err_msg or 'unknown'}"
|
||||
|
||||
@router.get("/mcp", response_model=List[MCPServer])
|
||||
async def list_mcp_servers(project_id: Optional[int] = None):
|
||||
servers = read_mcp_servers()
|
||||
if project_id is not None:
|
||||
servers = [s for s in servers if s.get("project_id") == project_id]
|
||||
|
||||
if not servers:
|
||||
return []
|
||||
|
||||
tasks = [_check_single_mcp_health(s) for s in servers]
|
||||
statuses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
needs_update = False
|
||||
for server, status in zip(servers, statuses):
|
||||
new_status = status if isinstance(status, str) else f"error: {str(status)}"
|
||||
if server.get("status") != new_status:
|
||||
server["status"] = new_status
|
||||
needs_update = True
|
||||
|
||||
if needs_update:
|
||||
# Write back to persist the new statuses
|
||||
all_servers = read_mcp_servers()
|
||||
for s in all_servers:
|
||||
for checked_s in servers:
|
||||
if s.get("id") == checked_s.get("id"):
|
||||
s["status"] = checked_s["status"]
|
||||
write_mcp_servers(all_servers)
|
||||
|
||||
return servers
|
||||
|
||||
@router.post("/mcp", response_model=MCPServer)
|
||||
def create_mcp_server(server_in: MCPServerCreate):
|
||||
servers = read_mcp_servers()
|
||||
|
||||
server_data = server_in.dict()
|
||||
server_data["id"] = str(uuid.uuid4())
|
||||
if "status" not in server_data or not server_data["status"]:
|
||||
server_data["status"] = "disconnected"
|
||||
|
||||
servers.append(server_data)
|
||||
write_mcp_servers(servers)
|
||||
return server_data
|
||||
|
||||
@router.get("/mcp/{server_id}", response_model=MCPServer)
|
||||
def get_mcp_server(server_id: str):
|
||||
servers = read_mcp_servers()
|
||||
for server in servers:
|
||||
if server.get("id") == server_id:
|
||||
return server
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
@router.put("/mcp/{server_id}", response_model=MCPServer)
|
||||
def update_mcp_server(server_id: str, server_in: MCPServerUpdate):
|
||||
servers = read_mcp_servers()
|
||||
for i, server in enumerate(servers):
|
||||
if server.get("id") == server_id:
|
||||
update_data = server_in.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
server[key] = value
|
||||
servers[i] = server
|
||||
write_mcp_servers(servers)
|
||||
return server
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
@router.delete("/mcp/{server_id}")
|
||||
def delete_mcp_server(server_id: str):
|
||||
servers = read_mcp_servers()
|
||||
filtered_servers = [s for s in servers if s.get("id") != server_id]
|
||||
|
||||
if len(servers) == len(filtered_servers):
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
write_mcp_servers(filtered_servers)
|
||||
return {"status": "success"}
|
||||
@@ -0,0 +1,92 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.project import Project
|
||||
from app.schemas.project import ProjectCreate, ProjectUpdate, Project as ProjectSchema
|
||||
from app.core.security import get_current_user, CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/projects", response_model=List[ProjectSchema])
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
# Users can only see their own projects, unless they are admin (who can see all?)
|
||||
# For simplicity, let's allow users to see their own projects.
|
||||
query = db.query(Project)
|
||||
if not current_user.is_admin:
|
||||
query = query.filter(Project.owner_id == current_user.id)
|
||||
|
||||
projects = query.offset(skip).limit(limit).all()
|
||||
return projects
|
||||
|
||||
@router.post("/projects", response_model=ProjectSchema)
|
||||
def create_project(
|
||||
project: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_project = Project(**project.dict(), owner_id=current_user.id)
|
||||
db.add(db_project)
|
||||
db.commit()
|
||||
db.refresh(db_project)
|
||||
return db_project
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ProjectSchema)
|
||||
def read_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if db_project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and db_project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
return db_project
|
||||
|
||||
@router.put("/projects/{project_id}", response_model=ProjectSchema)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
project: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if db_project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and db_project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
project_data = project.dict(exclude_unset=True)
|
||||
for key, value in project_data.items():
|
||||
setattr(db_project, key, value)
|
||||
|
||||
db.add(db_project)
|
||||
db.commit()
|
||||
db.refresh(db_project)
|
||||
return db_project
|
||||
|
||||
@router.delete("/projects/{project_id}")
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if db_project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and db_project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
db.delete(db_project)
|
||||
db.commit()
|
||||
return {"status": "success"}
|
||||
@@ -0,0 +1,146 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.mdl import MDLManifest
|
||||
from app.services.mdl import MDLService
|
||||
from app.connectors.factory import get_connector
|
||||
|
||||
router = APIRouter(tags=["semantic"])
|
||||
|
||||
class GenerateMDLRequest(BaseModel):
|
||||
selected_tables: Optional[List[str]] = None
|
||||
selected_columns: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
class ModelDetailResponse(BaseModel):
|
||||
model: Dict[str, Any]
|
||||
relationships: List[Dict[str, Any]]
|
||||
preview_rows: List[Dict[str, Any]]
|
||||
|
||||
def _normalize_query_result(results: Any) -> List[Dict[str, Any]]:
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
return results
|
||||
if results and isinstance(results[0], (list, tuple)):
|
||||
return [dict(enumerate(row)) for row in results]
|
||||
return []
|
||||
if isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
return [dict(zip(col_names, row)) for row in rows]
|
||||
return []
|
||||
|
||||
@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]])
|
||||
def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)):
|
||||
# Check if datasource exists
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="DataSource not found")
|
||||
|
||||
try:
|
||||
raw_schema = MDLService.get_raw_schema(ds)
|
||||
result = {}
|
||||
for table, data in raw_schema.items():
|
||||
if isinstance(data, dict) and "columns" in data:
|
||||
result[table] = data["columns"]
|
||||
elif isinstance(data, list):
|
||||
result[table] = data
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/semantic/{datasource_id}", response_model=MDLManifest)
|
||||
def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)):
|
||||
# Check if datasource exists
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="DataSource not found")
|
||||
|
||||
# Get or generate MDL
|
||||
try:
|
||||
mdl = MDLService.get_or_create_mdl(datasource_id)
|
||||
return mdl
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.put("/semantic/{datasource_id}", response_model=MDLManifest)
|
||||
def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)):
|
||||
# Check if datasource exists
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="DataSource not found")
|
||||
|
||||
try:
|
||||
MDLService.save_mdl(datasource_id, mdl)
|
||||
return mdl
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest)
|
||||
def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)):
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="DataSource not found")
|
||||
|
||||
try:
|
||||
selected_tables = request.selected_tables if request else None
|
||||
selected_columns = request.selected_columns if request else None
|
||||
mdl = MDLService.generate_default_mdl(
|
||||
ds,
|
||||
selected_tables=selected_tables,
|
||||
selected_columns=selected_columns,
|
||||
)
|
||||
MDLService.save_mdl(datasource_id, mdl)
|
||||
return mdl
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse)
|
||||
def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)):
|
||||
ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if not ds:
|
||||
raise HTTPException(status_code=404, detail="DataSource not found")
|
||||
|
||||
mdl = MDLService.get_or_create_mdl(datasource_id)
|
||||
model = next((m for m in mdl.models if m.name == model_name), None)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
relationships = [
|
||||
{
|
||||
"name": rel.name,
|
||||
"models": rel.models,
|
||||
"joinType": rel.joinType,
|
||||
"condition": rel.condition,
|
||||
"properties": rel.properties,
|
||||
}
|
||||
for rel in mdl.relationships
|
||||
if model_name in rel.models
|
||||
]
|
||||
|
||||
preview_rows: List[Dict[str, Any]] = []
|
||||
try:
|
||||
connector = get_connector(ds)
|
||||
table_name = model.tableReference.table if model.tableReference else model.name
|
||||
query = f'SELECT * FROM "{table_name}" LIMIT {max(1, min(limit, 100))}'
|
||||
raw = connector.execute_query(query)
|
||||
preview_rows = _normalize_query_result(raw)
|
||||
except Exception:
|
||||
preview_rows = []
|
||||
|
||||
model_payload = {
|
||||
"name": model.name,
|
||||
"tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None,
|
||||
"primaryKey": model.primaryKey,
|
||||
"properties": model.properties,
|
||||
"columns": [c.model_dump(by_alias=True) for c in model.columns],
|
||||
}
|
||||
|
||||
return ModelDetailResponse(
|
||||
model=model_payload,
|
||||
relationships=relationships,
|
||||
preview_rows=preview_rows,
|
||||
)
|
||||
@@ -0,0 +1,562 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tarfile
|
||||
import re
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.data_root import get_data_root, get_workspace_root
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR as NANOBOT_BUILTIN_SKILLS_DIR
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
DATA_FILE = str(get_data_root() / "skills.json")
|
||||
SKILL_HUB_DIR = str(get_workspace_root() / "skills")
|
||||
BACKEND_BUILTIN_SKILLS_DIR = str(Path(__file__).resolve().parents[1] / "skills_builtin")
|
||||
|
||||
SOURCE_LOCAL_IMPORT = "local_import"
|
||||
SOURCE_SYSTEM_BUILTIN = "system_builtin"
|
||||
SOURCE_BACKEND_GENERATED = "backend_generated"
|
||||
SOURCE_UPLOADED_FILE = "uploaded_file"
|
||||
|
||||
STATUS_SAFE = "safe"
|
||||
STATUS_LOW_RISK = "low_risk"
|
||||
|
||||
_SOURCE_ALIASES = {
|
||||
SOURCE_LOCAL_IMPORT: SOURCE_LOCAL_IMPORT,
|
||||
"本地导入": SOURCE_LOCAL_IMPORT,
|
||||
"Local Import": SOURCE_LOCAL_IMPORT,
|
||||
SOURCE_SYSTEM_BUILTIN: SOURCE_SYSTEM_BUILTIN,
|
||||
"系统内置": SOURCE_SYSTEM_BUILTIN,
|
||||
"System Built-in": SOURCE_SYSTEM_BUILTIN,
|
||||
SOURCE_BACKEND_GENERATED: SOURCE_BACKEND_GENERATED,
|
||||
"后台生成": SOURCE_BACKEND_GENERATED,
|
||||
"Backend Generated": SOURCE_BACKEND_GENERATED,
|
||||
SOURCE_UPLOADED_FILE: SOURCE_UPLOADED_FILE,
|
||||
"文件上传": SOURCE_UPLOADED_FILE,
|
||||
"File Upload": SOURCE_UPLOADED_FILE,
|
||||
}
|
||||
|
||||
_STATUS_ALIASES = {
|
||||
STATUS_SAFE: STATUS_SAFE,
|
||||
"安全": STATUS_SAFE,
|
||||
"Safe": STATUS_SAFE,
|
||||
STATUS_LOW_RISK: STATUS_LOW_RISK,
|
||||
"低风险": STATUS_LOW_RISK,
|
||||
"Low Risk": STATUS_LOW_RISK,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_source(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return SOURCE_LOCAL_IMPORT
|
||||
return _SOURCE_ALIASES.get(value, value)
|
||||
|
||||
|
||||
def _normalize_status(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return STATUS_SAFE
|
||||
return _STATUS_ALIASES.get(value, value)
|
||||
|
||||
def _ensure_skill_hub_dir() -> None:
|
||||
os.makedirs(SKILL_HUB_DIR, exist_ok=True)
|
||||
|
||||
class Skill(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier for the skill")
|
||||
name: str = Field(..., description="Name of the skill")
|
||||
description: Optional[str] = Field(None, description="Description of what the skill does")
|
||||
content: str = Field(..., description="The content/prompt/logic of the skill")
|
||||
type: str = Field("python", description="Type of the skill (python, sql, api)")
|
||||
project_id: Optional[int] = Field(None, description="The ID of the project this skill belongs to")
|
||||
source: str = Field(SOURCE_LOCAL_IMPORT, description="Stable source key of the skill")
|
||||
installation_time: str = Field(default_factory=lambda: datetime.now().strftime("%Y年%m月%d日"), description="Time when the skill was installed")
|
||||
status: str = Field(STATUS_SAFE, description="Stable security status key")
|
||||
file_path: Optional[str] = Field(None, description="Path to the skill folder in skill-hub")
|
||||
is_builtin: bool = Field(False, description="Whether this is a system builtin skill")
|
||||
|
||||
class SkillCreate(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
content: str
|
||||
type: str = "python"
|
||||
project_id: Optional[int] = None
|
||||
source: str = SOURCE_LOCAL_IMPORT
|
||||
installation_time: Optional[str] = None
|
||||
status: str = STATUS_SAFE
|
||||
file_path: Optional[str] = None
|
||||
|
||||
class SkillUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
project_id: Optional[int] = None
|
||||
source: Optional[str] = None
|
||||
installation_time: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
file_path: Optional[str] = None
|
||||
|
||||
def _parse_skill_md(file_path: str) -> Dict[str, Any]:
|
||||
"""Parse SKILL.md for metadata and content according to agentskills.io standard."""
|
||||
if not os.path.exists(file_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
return {}
|
||||
|
||||
# Split YAML frontmatter and Markdown body
|
||||
# Support both --- and +++ for frontmatter
|
||||
metadata = {}
|
||||
body = content
|
||||
|
||||
if content.startswith('---'):
|
||||
parts = content.split('---', 2)
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
metadata = yaml.safe_load(parts[1]) or {}
|
||||
body = parts[2].strip()
|
||||
except Exception as e:
|
||||
print(f"Error parsing YAML frontmatter: {e}")
|
||||
|
||||
# Extract name and description, fallback to some defaults
|
||||
name = metadata.get("name")
|
||||
description = metadata.get("description")
|
||||
|
||||
# If name not in metadata, try to find the first H1 in markdown body
|
||||
if not name:
|
||||
for line in body.split('\n'):
|
||||
if line.startswith('# '):
|
||||
name = line[2:].strip()
|
||||
break
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"content": body,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
def _load_data() -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(DATA_FILE):
|
||||
return []
|
||||
try:
|
||||
with open(DATA_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
return []
|
||||
|
||||
def _save_data(data: List[Dict[str, Any]]):
|
||||
os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
|
||||
with open(DATA_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def _dedupe_skills(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
deduped: Dict[str, Dict[str, Any]] = {}
|
||||
for item in data:
|
||||
skill_id = str(item.get("id") or "").strip()
|
||||
project_id = item.get("project_id")
|
||||
if not skill_id:
|
||||
continue
|
||||
|
||||
# Use a composite key of (id, project_id) for deduplication
|
||||
# so that different projects can theoretically have the same skill_id
|
||||
dedupe_key = f"{skill_id}_{project_id}"
|
||||
|
||||
existing = deduped.get(dedupe_key)
|
||||
if existing is None:
|
||||
deduped[dedupe_key] = item
|
||||
continue
|
||||
|
||||
# If they somehow have the exact same dedupe_key, we just keep the later one
|
||||
deduped[dedupe_key] = item
|
||||
|
||||
return list(deduped.values())
|
||||
|
||||
def _safe_skill_dir_name(value: str) -> str:
|
||||
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
|
||||
return safe or "skill"
|
||||
|
||||
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
skill_md_path = os.path.join(skill_dir, "SKILL.md")
|
||||
final_description = description or "No description provided"
|
||||
body = content or ""
|
||||
markdown = (
|
||||
f"---\n"
|
||||
f"name: {skill_name}\n"
|
||||
f"description: {final_description}\n"
|
||||
f"---\n\n"
|
||||
f"{body}\n"
|
||||
)
|
||||
with open(skill_md_path, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
return skill_md_path
|
||||
|
||||
def _scan_builtin_skills(data: List[Dict[str, Any]], registered_paths: set, source_dir: str, source_name: str):
|
||||
if not os.path.exists(source_dir):
|
||||
return
|
||||
for item in os.listdir(source_dir):
|
||||
skill_dir = os.path.abspath(os.path.join(source_dir, item))
|
||||
if os.path.isdir(skill_dir):
|
||||
skill_md_path = os.path.join(skill_dir, "SKILL.md")
|
||||
if os.path.exists(skill_md_path):
|
||||
metadata_res = _parse_skill_md(skill_md_path)
|
||||
skill_name = metadata_res.get("name") or item
|
||||
|
||||
existing = None
|
||||
for d in data:
|
||||
if (d.get("id") == item and d.get("is_builtin")) or d.get("file_path") == skill_dir:
|
||||
existing = d
|
||||
break
|
||||
|
||||
if existing:
|
||||
existing["name"] = skill_name
|
||||
existing["description"] = metadata_res.get("description") or "No description provided"
|
||||
existing["content"] = metadata_res.get("content") or ""
|
||||
existing["file_path"] = skill_dir
|
||||
existing["is_builtin"] = True
|
||||
existing["source"] = source_name
|
||||
existing["status"] = STATUS_SAFE
|
||||
registered_paths.add(skill_dir)
|
||||
else:
|
||||
new_skill = {
|
||||
"id": item,
|
||||
"name": skill_name,
|
||||
"description": metadata_res.get("description") or "No description provided",
|
||||
"content": metadata_res.get("content") or "",
|
||||
"type": "agentskill",
|
||||
"project_id": None,
|
||||
"source": source_name,
|
||||
"installation_time": datetime.now().strftime("%Y年%m月%d日"),
|
||||
"status": STATUS_SAFE,
|
||||
"file_path": skill_dir,
|
||||
"is_builtin": True
|
||||
}
|
||||
data.append(new_skill)
|
||||
registered_paths.add(skill_dir)
|
||||
|
||||
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
_ensure_skill_hub_dir()
|
||||
data = _load_data()
|
||||
|
||||
registered_paths = set()
|
||||
|
||||
# Sync registered skills with their SKILL.md if available
|
||||
for item in data:
|
||||
item["source"] = _normalize_source(item.get("source"))
|
||||
item["status"] = _normalize_status(item.get("status"))
|
||||
if item.get("id") in ("nl2sql", "visualization") or item.get("is_builtin"):
|
||||
item["is_builtin"] = True
|
||||
else:
|
||||
item.setdefault("is_builtin", False)
|
||||
|
||||
if item.get("file_path"):
|
||||
abs_path = os.path.abspath(item["file_path"])
|
||||
registered_paths.add(abs_path)
|
||||
skill_md_path = os.path.join(abs_path, "SKILL.md")
|
||||
if os.path.exists(skill_md_path):
|
||||
metadata_res = _parse_skill_md(skill_md_path)
|
||||
if metadata_res.get("name"):
|
||||
item["name"] = metadata_res["name"]
|
||||
if metadata_res.get("description"):
|
||||
item["description"] = metadata_res["description"]
|
||||
if metadata_res.get("content"):
|
||||
item["content"] = metadata_res["content"]
|
||||
|
||||
# Scan builtin skills
|
||||
_scan_builtin_skills(data, registered_paths, NANOBOT_BUILTIN_SKILLS_DIR, SOURCE_SYSTEM_BUILTIN)
|
||||
_scan_builtin_skills(data, registered_paths, BACKEND_BUILTIN_SKILLS_DIR, SOURCE_SYSTEM_BUILTIN)
|
||||
|
||||
# Scan for unregistered skills in SKILL_HUB_DIR (1-level deep to match nanobot's behavior)
|
||||
if os.path.exists(SKILL_HUB_DIR):
|
||||
for item in os.listdir(SKILL_HUB_DIR):
|
||||
skill_dir = os.path.abspath(os.path.join(SKILL_HUB_DIR, item))
|
||||
if os.path.isdir(skill_dir):
|
||||
skill_md_path = os.path.join(skill_dir, "SKILL.md")
|
||||
if os.path.exists(skill_md_path) and skill_dir not in registered_paths:
|
||||
metadata_res = _parse_skill_md(skill_md_path)
|
||||
skill_name = metadata_res.get("name") or item
|
||||
|
||||
# Try to deduce project_id from directory prefix (e.g., p123_skillname)
|
||||
deduced_project_id = None
|
||||
match = re.match(r'^p(\d+)_', item)
|
||||
if match:
|
||||
deduced_project_id = int(match.group(1))
|
||||
|
||||
new_skill = {
|
||||
"id": item,
|
||||
"name": skill_name,
|
||||
"description": metadata_res.get("description") or "No description provided",
|
||||
"content": metadata_res.get("content") or "",
|
||||
"type": "agentskill",
|
||||
"project_id": deduced_project_id,
|
||||
"source": SOURCE_BACKEND_GENERATED,
|
||||
"installation_time": datetime.now().strftime("%Y年%m月%d日"),
|
||||
"status": STATUS_SAFE,
|
||||
"file_path": skill_dir,
|
||||
"is_builtin": item in ("nl2sql", "visualization")
|
||||
}
|
||||
data.append(new_skill)
|
||||
registered_paths.add(skill_dir)
|
||||
|
||||
deduped = _dedupe_skills(data)
|
||||
if project_id is not None:
|
||||
return [item for item in deduped if item.get("project_id") == project_id or item.get("project_id") is None]
|
||||
return deduped
|
||||
|
||||
@router.get("/skills", response_model=List[Skill])
|
||||
def list_skills(project_id: Optional[int] = None):
|
||||
data = load_skills(project_id)
|
||||
return [Skill(**item) for item in data]
|
||||
|
||||
@router.get("/skills/{skill_id}", response_model=Skill)
|
||||
def get_skill(skill_id: str, project_id: Optional[int] = None):
|
||||
data = load_skills()
|
||||
for item in data:
|
||||
if item["id"] == skill_id:
|
||||
if project_id is not None and item.get("project_id") != project_id:
|
||||
continue
|
||||
return Skill(**item)
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
|
||||
@router.post("/skills/upload")
|
||||
async def upload_skill(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None)
|
||||
):
|
||||
"""Upload a skill file (SKILL.md) or a packaged skill (zip/tar.gz)."""
|
||||
filename = file.filename
|
||||
print(f"Uploading skill: {filename}, project_id: {project_id}")
|
||||
_ensure_skill_hub_dir()
|
||||
|
||||
# Create a unique temp directory
|
||||
temp_dir_name = f"temp_{datetime.now().timestamp()}_{os.urandom(4).hex()}"
|
||||
temp_dir = os.path.join(SKILL_HUB_DIR, temp_dir_name)
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
file_path = os.path.join(temp_dir, filename)
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
skill_source_dir = None
|
||||
|
||||
# Handle different file types
|
||||
if filename.endswith(".zip"):
|
||||
try:
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(temp_dir)
|
||||
os.remove(file_path)
|
||||
# Find the directory containing SKILL.md
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
if "SKILL.md" in files:
|
||||
skill_source_dir = root
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Zip extraction failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to extract zip: {str(e)}")
|
||||
|
||||
elif filename.endswith((".tar.gz", ".tgz")):
|
||||
try:
|
||||
with tarfile.open(file_path, 'r:gz') as tar_ref:
|
||||
tar_ref.extractall(temp_dir)
|
||||
os.remove(file_path)
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
if "SKILL.md" in files:
|
||||
skill_source_dir = root
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Tarball extraction failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to extract tarball: {str(e)}")
|
||||
|
||||
elif filename == "SKILL.md":
|
||||
skill_source_dir = temp_dir
|
||||
else:
|
||||
print(f"Unsupported file type: {filename}")
|
||||
raise HTTPException(status_code=400, detail="Only SKILL.md or packaged skills (zip/tar.gz) are supported")
|
||||
|
||||
if not skill_source_dir or not os.path.exists(os.path.join(skill_source_dir, "SKILL.md")):
|
||||
print(f"SKILL.md not found in {filename}")
|
||||
raise HTTPException(status_code=400, detail="SKILL.md not found in the uploaded file")
|
||||
|
||||
# Parse metadata
|
||||
skill_md_path = os.path.join(skill_source_dir, "SKILL.md")
|
||||
metadata_res = _parse_skill_md(skill_md_path)
|
||||
|
||||
# Use metadata name, or fallback to folder name or filename
|
||||
skill_name = metadata_res.get("name")
|
||||
if not skill_name:
|
||||
if filename == "SKILL.md":
|
||||
skill_name = "unnamed_skill"
|
||||
else:
|
||||
# Use filename without extension
|
||||
skill_name = os.path.splitext(filename)[0]
|
||||
|
||||
# Create a safe directory name for the skill
|
||||
safe_name = _safe_skill_dir_name(skill_name)
|
||||
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
if project_id is not None:
|
||||
# Prefix the folder name with p{project_id}_ to distinguish projects in storage
|
||||
# without breaking nanobot's 1-level-deep skill loader
|
||||
final_skill_dir = os.path.join(SKILL_HUB_DIR, f"p{project_id}_{final_skill_id}")
|
||||
final_skill_id = f"p{project_id}_{final_skill_id}"
|
||||
else:
|
||||
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
|
||||
|
||||
print(f"Finalizing skill: {skill_name} -> {final_skill_dir}")
|
||||
|
||||
# Move the skill content to final destination
|
||||
os.makedirs(final_skill_dir, exist_ok=True)
|
||||
for item in os.listdir(skill_source_dir):
|
||||
s = os.path.join(skill_source_dir, item)
|
||||
d = os.path.join(final_skill_dir, item)
|
||||
if os.path.isdir(s):
|
||||
shutil.copytree(s, d, dirs_exist_ok=True)
|
||||
else:
|
||||
shutil.copy2(s, d)
|
||||
|
||||
# Register in skills.json
|
||||
data = load_skills()
|
||||
new_skill = {
|
||||
"id": final_skill_id,
|
||||
"name": skill_name,
|
||||
"description": metadata_res.get("description") or "No description provided",
|
||||
"content": metadata_res.get("content") or "",
|
||||
"type": "agentskill",
|
||||
"project_id": project_id,
|
||||
"source": SOURCE_UPLOADED_FILE,
|
||||
"installation_time": datetime.now().strftime("%Y年%m月%d日"),
|
||||
"status": STATUS_SAFE,
|
||||
"file_path": final_skill_dir
|
||||
}
|
||||
|
||||
data.append(new_skill)
|
||||
_save_data(data)
|
||||
print(f"Skill registered successfully: {final_skill_id}")
|
||||
|
||||
return new_skill
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
||||
finally:
|
||||
# Cleanup temp directory
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
@router.post("/skills", response_model=Skill)
|
||||
def create_skill(skill: SkillCreate):
|
||||
_ensure_skill_hub_dir()
|
||||
data = load_skills()
|
||||
if any(item["id"] == skill.id and item.get("project_id") == skill.project_id for item in data):
|
||||
raise HTTPException(status_code=400, detail="Skill with this ID already exists in this project")
|
||||
|
||||
new_skill_dict = skill.dict()
|
||||
new_skill_dict["source"] = _normalize_source(new_skill_dict.get("source"))
|
||||
new_skill_dict["status"] = _normalize_status(new_skill_dict.get("status"))
|
||||
if not new_skill_dict.get("installation_time"):
|
||||
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d日")
|
||||
if not new_skill_dict.get("file_path"):
|
||||
project_id = new_skill_dict.get("project_id")
|
||||
base_dir_name = _safe_skill_dir_name(new_skill_dict["id"])
|
||||
if project_id is not None:
|
||||
# Add prefix for project storage distinction
|
||||
if not base_dir_name.startswith(f"p{project_id}_"):
|
||||
base_dir_name = f"p{project_id}_{base_dir_name}"
|
||||
skill_dir = os.path.join(SKILL_HUB_DIR, base_dir_name)
|
||||
else:
|
||||
skill_dir = os.path.join(SKILL_HUB_DIR, base_dir_name)
|
||||
|
||||
_write_skill_markdown(
|
||||
skill_dir=skill_dir,
|
||||
skill_name=new_skill_dict["name"],
|
||||
description=new_skill_dict.get("description"),
|
||||
content=new_skill_dict.get("content", ""),
|
||||
)
|
||||
new_skill_dict["file_path"] = skill_dir
|
||||
new_skill_dict["id"] = base_dir_name
|
||||
|
||||
data.append(new_skill_dict)
|
||||
_save_data(data)
|
||||
return Skill(**new_skill_dict)
|
||||
|
||||
@router.put("/skills/{skill_id}", response_model=Skill)
|
||||
def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] = None):
|
||||
data = load_skills()
|
||||
for i, item in enumerate(data):
|
||||
if item["id"] == skill_id:
|
||||
if project_id is not None and item.get("project_id") != project_id:
|
||||
continue
|
||||
updated_item = item.copy()
|
||||
update_data = skill.dict(exclude_unset=True)
|
||||
if "source" in update_data:
|
||||
update_data["source"] = _normalize_source(update_data.get("source"))
|
||||
if "status" in update_data:
|
||||
update_data["status"] = _normalize_status(update_data.get("status"))
|
||||
updated_item.update(update_data)
|
||||
if updated_item.get("file_path"):
|
||||
_write_skill_markdown(
|
||||
skill_dir=updated_item["file_path"],
|
||||
skill_name=updated_item.get("name") or item.get("name") or "skill",
|
||||
description=updated_item.get("description"),
|
||||
content=updated_item.get("content", ""),
|
||||
)
|
||||
data[i] = updated_item
|
||||
_save_data(data)
|
||||
return Skill(**updated_item)
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
|
||||
@router.delete("/skills/{skill_id}")
|
||||
def delete_skill(skill_id: str, project_id: Optional[int] = None):
|
||||
data = load_skills()
|
||||
initial_len = len(data)
|
||||
|
||||
# If project_id is provided, we only delete if it matches
|
||||
new_data = []
|
||||
found = False
|
||||
skill_to_delete = None
|
||||
|
||||
for item in data:
|
||||
if item["id"] == skill_id:
|
||||
if item.get("is_builtin"):
|
||||
raise HTTPException(status_code=400, detail="Builtin skills cannot be deleted")
|
||||
if project_id is not None and item.get("project_id") not in (project_id, None):
|
||||
new_data.append(item)
|
||||
continue
|
||||
found = True
|
||||
skill_to_delete = item
|
||||
else:
|
||||
new_data.append(item)
|
||||
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
|
||||
# Clean up file_path if it exists
|
||||
if skill_to_delete and skill_to_delete.get("file_path"):
|
||||
file_path = skill_to_delete["file_path"]
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
if os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
else:
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
print(f"Error deleting skill files at {file_path}: {e}")
|
||||
|
||||
_save_data(new_data)
|
||||
return {"message": "Skill deleted successfully"}
|
||||
@@ -0,0 +1,106 @@
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.subagent import Subagent
|
||||
from app.models.project import Project
|
||||
from app.schemas.subagent import SubagentCreate, SubagentUpdate, Subagent as SubagentSchema
|
||||
from app.core.security import get_current_user, CurrentUser
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/projects/{project_id}/subagents", response_model=List[SubagentSchema])
|
||||
def list_subagents(
|
||||
project_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
subagents = db.query(Subagent).filter(Subagent.project_id == project_id).offset(skip).limit(limit).all()
|
||||
return subagents
|
||||
|
||||
@router.post("/projects/{project_id}/subagents", response_model=SubagentSchema)
|
||||
def create_subagent(
|
||||
project_id: int,
|
||||
subagent: SubagentCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
db_subagent = Subagent(**subagent.dict(), project_id=project_id)
|
||||
db.add(db_subagent)
|
||||
db.commit()
|
||||
db.refresh(db_subagent)
|
||||
return db_subagent
|
||||
|
||||
@router.get("/subagents/{subagent_id}", response_model=SubagentSchema)
|
||||
def read_subagent(
|
||||
subagent_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
|
||||
if db_subagent is None:
|
||||
raise HTTPException(status_code=404, detail="Subagent not found")
|
||||
|
||||
project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
return db_subagent
|
||||
|
||||
@router.put("/subagents/{subagent_id}", response_model=SubagentSchema)
|
||||
def update_subagent(
|
||||
subagent_id: int,
|
||||
subagent: SubagentUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
|
||||
if db_subagent is None:
|
||||
raise HTTPException(status_code=404, detail="Subagent not found")
|
||||
|
||||
project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
subagent_data = subagent.dict(exclude_unset=True)
|
||||
for key, value in subagent_data.items():
|
||||
setattr(db_subagent, key, value)
|
||||
|
||||
db.add(db_subagent)
|
||||
db.commit()
|
||||
db.refresh(db_subagent)
|
||||
return db_subagent
|
||||
|
||||
@router.delete("/subagents/{subagent_id}")
|
||||
def delete_subagent(
|
||||
subagent_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
|
||||
if db_subagent is None:
|
||||
raise HTTPException(status_code=404, detail="Subagent not found")
|
||||
|
||||
project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
|
||||
if not current_user.is_admin and project.owner_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
db.delete(db_subagent)
|
||||
db.commit()
|
||||
return {"status": "success"}
|
||||
@@ -0,0 +1,73 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
import pandas as pd
|
||||
import duckdb
|
||||
import io
|
||||
import uuid
|
||||
|
||||
from app.core.data_root import get_uploads_root
|
||||
|
||||
router = APIRouter()
|
||||
upload_dir = get_uploads_root()
|
||||
|
||||
@router.post("/upload/file")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
allowed_extensions = ('.csv', '.xls', '.xlsx', '.parquet', '.db', '.sqlite', '.sqlite3')
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith(allowed_extensions):
|
||||
raise HTTPException(status_code=400, detail="Invalid file type. Allowed: CSV, Excel, Parquet, SQLite.")
|
||||
|
||||
try:
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="Empty file is not allowed.")
|
||||
file_obj = io.BytesIO(content)
|
||||
|
||||
unique_filename = f"{uuid.uuid4()}-{file.filename}"
|
||||
save_path = upload_dir / unique_filename
|
||||
save_path.write_bytes(content)
|
||||
file_url = f"local://{unique_filename}"
|
||||
|
||||
file_obj.seek(0)
|
||||
|
||||
try:
|
||||
if filename_lower.endswith('.csv'):
|
||||
df = pd.read_csv(file_obj)
|
||||
elif filename_lower.endswith(('.xls', '.xlsx')):
|
||||
df = pd.read_excel(file_obj)
|
||||
elif filename_lower.endswith('.parquet'):
|
||||
df = pd.read_parquet(file_obj)
|
||||
elif filename_lower.endswith(('.db', '.sqlite', '.sqlite3')):
|
||||
# For SQLite, we don't load into DF immediately for analysis here
|
||||
# Just return success
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"url": file_url,
|
||||
"rows": 0,
|
||||
"columns": [],
|
||||
"summary": "SQLite database uploaded"
|
||||
}
|
||||
|
||||
# For DF supported types
|
||||
duckdb_conn = duckdb.connect(database=':memory:')
|
||||
duckdb_conn.register('uploaded_file', df)
|
||||
summary = duckdb_conn.execute("DESCRIBE uploaded_file").fetchall()
|
||||
row_count = len(df)
|
||||
columns = list(df.columns)
|
||||
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"url": file_url,
|
||||
"rows": row_count,
|
||||
"columns": columns,
|
||||
"summary": str(summary)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"url": file_url,
|
||||
"analysis_error": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,215 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.user import User, EmailVerification
|
||||
from app.schemas.user import UserCreate, UserUpdate, UserResponse, ResendVerificationRequest
|
||||
from app.core.security import get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
from app.core.email import send_verification_email
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def generate_verification_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@router.post("/auth/login")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username, "is_admin": user.is_admin, "id": user.id},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"avatar": user.avatar,
|
||||
"is_admin": user.is_admin
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/auth/register", response_model=UserResponse)
|
||||
def register_user(user: UserCreate, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.username == user.username).first()
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
db_user_email = db.query(User).filter(User.email == user.email).first()
|
||||
if db_user_email:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
# If this is the first user, make them an admin
|
||||
is_first_user = db.query(User).count() == 0
|
||||
is_admin = is_first_user or user.is_admin
|
||||
is_active = True if is_first_user else False
|
||||
|
||||
db_user = User(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
avatar=user.avatar,
|
||||
hashed_password=hashed_password,
|
||||
is_active=is_active,
|
||||
is_admin=is_admin
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
if not is_active:
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=db_user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 将用户的 email 保存到局部变量中,防止在后台任务执行前 session 关闭导致延迟加载失败
|
||||
user_email = db_user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
|
||||
return db_user
|
||||
|
||||
@router.get("/auth/verify-email")
|
||||
def verify_email(token: str, db: Session = Depends(get_db)):
|
||||
hashed = hash_token(token)
|
||||
verification = db.query(EmailVerification).filter(
|
||||
EmailVerification.token_hash == hashed,
|
||||
EmailVerification.is_used == False
|
||||
).first()
|
||||
|
||||
if not verification:
|
||||
raise HTTPException(status_code=400, detail="Invalid or used token")
|
||||
|
||||
# Check if expired (make timezone-aware if naive)
|
||||
expires_at = verification.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if expires_at < datetime.now(timezone.utc):
|
||||
raise HTTPException(status_code=400, detail="Token expired")
|
||||
|
||||
user = db.query(User).filter(User.id == verification.user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
user.is_active = True
|
||||
verification.is_used = True
|
||||
db.commit()
|
||||
|
||||
return {"status": "success", "message": "Email verified successfully"}
|
||||
|
||||
@router.post("/auth/resend-verification")
|
||||
def resend_verification(request: ResendVerificationRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.username == request.username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
if user.is_active:
|
||||
raise HTTPException(status_code=400, detail="User already active")
|
||||
|
||||
token = generate_verification_token()
|
||||
hashed = hash_token(token)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
verification = EmailVerification(
|
||||
user_id=user.id,
|
||||
token_hash=hashed,
|
||||
expires_at=expires_at
|
||||
)
|
||||
db.add(verification)
|
||||
db.commit()
|
||||
|
||||
# 提取 email,避免后台任务访问已断开的 db session
|
||||
user_email = user.email
|
||||
background_tasks.add_task(send_verification_email, user_email, token)
|
||||
return {"status": "success", "message": "Verification email sent"}
|
||||
|
||||
@router.get("/users", response_model=List[UserResponse])
|
||||
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
users = db.query(User).offset(skip).limit(limit).all()
|
||||
return users
|
||||
|
||||
@router.get("/users/{user_id}", response_model=UserResponse)
|
||||
def read_user(user_id: int, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.id == user_id).first()
|
||||
if db_user is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return db_user
|
||||
|
||||
@router.post("/users", response_model=UserResponse)
|
||||
def create_user(user: UserCreate, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.username == user.username).first()
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="Username already registered")
|
||||
|
||||
db_user_email = db.query(User).filter(User.email == user.email).first()
|
||||
if db_user_email:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
|
||||
db_user = User(
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
avatar=user.avatar,
|
||||
hashed_password=get_password_hash(user.password),
|
||||
is_active=user.is_active,
|
||||
is_admin=user.is_admin
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
@router.put("/users/{user_id}", response_model=UserResponse)
|
||||
def update_user(user_id: int, user: UserUpdate, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
update_data = user.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
if key == "password" and value:
|
||||
db_user.hashed_password = get_password_hash(value)
|
||||
elif key != "password":
|
||||
setattr(db_user, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
def delete_user(user_id: int, db: Session = Depends(get_db)):
|
||||
db_user = db.query(User).filter(User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
db.delete(db_user)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.llm import get_current_user, get_admin_user, CurrentUser
|
||||
from app.services.web_search_config_store import get_web_search_config, save_web_search_config
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class WebSearchConfigModel(BaseModel):
|
||||
provider: str = Field(default="duckduckgo", description="Web search provider (brave, tavily, duckduckgo, searxng, jina)")
|
||||
api_key: Optional[str] = Field(default="", description="API Key for the provider")
|
||||
base_url: Optional[str] = Field(default="", description="Base URL for SearXNG")
|
||||
max_results: int = Field(default=5, description="Maximum number of search results")
|
||||
|
||||
def _sanitize_config(config: Dict[str, Any], is_admin: bool) -> Dict[str, Any]:
|
||||
sanitized = config.copy()
|
||||
if not is_admin:
|
||||
sanitized["api_key"] = None
|
||||
return sanitized
|
||||
|
||||
@router.get("/web-search/config", response_model=WebSearchConfigModel)
|
||||
def get_config(current_user: CurrentUser = Depends(get_current_user)):
|
||||
config = get_web_search_config()
|
||||
return WebSearchConfigModel(**_sanitize_config(config, current_user.is_admin))
|
||||
|
||||
@router.put("/web-search/config", response_model=WebSearchConfigModel)
|
||||
def update_config(config: WebSearchConfigModel, _: CurrentUser = Depends(get_admin_user)):
|
||||
config_dict = config.dict()
|
||||
save_web_search_config(config_dict)
|
||||
return WebSearchConfigModel(**config_dict)
|
||||
Reference in New Issue
Block a user