Update 2026-05-13 16:43:53

This commit is contained in:
yi
2026-05-13 16:43:53 +08:00
parent 6af5c584f4
commit afd7c5fe85
490 changed files with 850 additions and 922 deletions
View File
File diff suppressed because it is too large Load Diff
+139
View File
@@ -0,0 +1,139 @@
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.datasource import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSource as DataSourceSchema, DataSourceTestRequest
from app.core.security import get_current_user, get_admin_user, CurrentUser
from app.connectors.factory import get_connector_from_config
from pydantic import BaseModel
router = APIRouter()
@router.get("/datasources", response_model=List[DataSourceSchema])
def list_datasources(
    project_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # List data sources, optionally restricted to one project.
    #
    # Admins see every matching row with its full config. Non-admins only see
    # data sources of projects they own, and the "password", "api_key" and
    # "secret" entries are stripped from each config before it is returned.
    #
    # Raises 403 when a non-admin asks for a project that does not exist or
    # that they do not own.

    # Imported inside the function, as the original code did for Project.
    from app.models.project import Project

    query = db.query(DataSource)
    # "is not None" so that project_id=0 is still treated as an explicit filter
    # instead of silently falling back to "all projects".
    if project_id is not None:
        query = query.filter(DataSource.project_id == project_id)

    if not current_user.is_admin:
        if project_id is not None:
            project = db.query(Project).filter(Project.id == project_id).first()
            if not project or project.owner_id != current_user.id:
                raise HTTPException(status_code=403, detail="Not enough permissions for this project")
        else:
            # Without a project filter a non-admin must not receive other
            # owners' data sources, so scope the listing to owned projects.
            query = query.join(Project, Project.id == DataSource.project_id).filter(
                Project.owner_id == current_user.id
            )

    datasources = query.offset(skip).limit(limit).all()
    if current_user.is_admin:
        return datasources

    # Non-admins only need enough to pick a data source, so drop secrets.
    hidden_keys = ("password", "api_key", "secret")
    sanitized = []
    for ds in datasources:
        ds_dict = DataSourceSchema.from_orm(ds).dict()
        if ds_dict.get("config"):
            ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in hidden_keys}
        sanitized.append(ds_dict)
    return sanitized
@router.post("/datasources", response_model=DataSourceSchema)
def create_datasource(
    datasource: DataSourceCreate,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Create a data source inside an existing project.
    # Responds 404 when the project is missing and 403 when the caller is
    # neither an admin nor the project's owner.
    from app.models.project import Project

    target_project = (
        db.query(Project)
        .filter(Project.id == datasource.project_id)
        .first()
    )
    if not target_project:
        raise HTTPException(status_code=404, detail="Project not found")

    may_write = current_user.is_admin or target_project.owner_id == current_user.id
    if not may_write:
        raise HTTPException(status_code=403, detail="Not enough permissions for this project")

    record = DataSource(**datasource.dict())
    db.add(record)
    db.commit()
    # Reload so database-generated columns are populated on the returned row.
    db.refresh(record)
    return record
@router.get("/datasources/{datasource_id}", response_model=DataSourceSchema)
def read_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Fetch one data source by id.
    #
    # Admins get the stored row unchanged. Non-admins must own the project the
    # data source belongs to (403 otherwise) and receive the config without its
    # "password", "api_key" and "secret" entries. Responds 404 when the id is
    # unknown.
    db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if db_datasource is None:
        raise HTTPException(status_code=404, detail="Data source not found")
    if current_user.is_admin:
        return db_datasource

    # Same ownership rule the list/create handlers apply; without it any
    # authenticated user could read any data source by guessing its id.
    from app.models.project import Project

    project = db.query(Project).filter(Project.id == db_datasource.project_id).first()
    if not project or project.owner_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not enough permissions for this project")

    ds_dict = DataSourceSchema.from_orm(db_datasource).dict()
    if ds_dict.get("config"):
        ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in ["password", "api_key", "secret"]}
    return ds_dict
@router.put("/datasources/{datasource_id}", response_model=DataSourceSchema)
def update_datasource(
    datasource_id: int,
    datasource: DataSourceUpdate,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    # Admin-only partial update: only fields present in the request body are
    # written to the row. Responds 404 when the id is unknown.
    record = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    changes = datasource.dict(exclude_unset=True)
    for field_name in changes:
        setattr(record, field_name, changes[field_name])

    db.commit()
    db.refresh(record)
    return record
@router.delete("/datasources/{datasource_id}")
def delete_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    # Admin-only removal of a data source; 404 when the id is unknown.
    lookup = db.query(DataSource).filter(DataSource.id == datasource_id)
    record = lookup.first()
    if record is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    db.delete(record)
    db.commit()
    return {"ok": True}
@router.post("/datasources/test")
def test_datasource_connection(
    request: DataSourceTestRequest,
    _: CurrentUser = Depends(get_admin_user)
):
    # Admin-only connectivity probe for a not-yet-saved data source config.
    #
    # Returns a success payload when the connector reports a working
    # connection. Responds 400 with "Connection failed" when the connector
    # reports failure, and 400 with the underlying error text when building or
    # probing the connector raises.
    import sys
    import traceback

    # Only connector work lives inside the try: the "plain failure" 400 below
    # must not be caught and re-wrapped (with a traceback) by the handler.
    try:
        connector = get_connector_from_config(request.type, request.config)
        connected = connector.test_connection()
    except Exception as e:
        print(f"Datasource Test Error: {str(e)}\n{traceback.format_exc()}", file=sys.stderr)
        raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}") from e

    if not connected:
        raise HTTPException(status_code=400, detail="Connection failed")
    return {"success": True, "message": "Connection successful"}
+96
View File
@@ -0,0 +1,96 @@
from typing import List, Optional
from fastapi import APIRouter, HTTPException, Depends
from openai import OpenAI
from app.schemas.embedding_model import (
EmbeddingModelConfig,
EmbeddingModelConfigCreate,
EmbeddingModelConfigUpdate,
EmbeddingModelConnectionTestRequest
)
from app.services.embedding_model_store import embedding_model_store
from app.services.openai_compat import normalize_openai_base_url
from app.api.llm import get_admin_user, get_current_user, CurrentUser
router = APIRouter()
def _mask_api_key(value: Optional[str]) -> Optional[str]:
if not value:
return None
if len(value) <= 8:
return "*" * len(value)
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
@router.get("/embedding-models", response_model=List[EmbeddingModelConfig])
def list_embedding_models(current_user: CurrentUser = Depends(get_current_user)):
    # List embedding model configs. Admins get them as stored; everyone else
    # gets the same entries with "api_key" blanked out.
    models = embedding_model_store.list_models()
    if current_user.is_admin:
        return models
    # Blank the key on copies rather than on the objects the store handed out:
    # if the store returns references to its own state, an in-place write would
    # erase the real key for later callers.
    # NOTE(review): assumes each entry is a dict-like mapping, as the original
    # item assignment implies.
    return [{**m, "api_key": None} for m in models]
@router.post("/embedding-models", response_model=EmbeddingModelConfig)
def create_embedding_model(payload: EmbeddingModelConfigCreate, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only: hand the validated payload to the store and return what the
    # store reports as the created entry.
    fields = payload.model_dump()
    created = embedding_model_store.create_model(fields)
    return created
@router.get("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
def get_embedding_model(model_id: str, current_user: CurrentUser = Depends(get_current_user)):
    # Fetch one embedding model config; 404 when the store has no such id.
    # Non-admins receive the entry with "api_key" blanked out.
    model = embedding_model_store.get_model(model_id)
    if not model:
        raise HTTPException(status_code=404, detail="Embedding model not found")
    if current_user.is_admin:
        return model
    # Return a redacted copy so the object owned by the store is never
    # modified (an in-place write could wipe the stored key).
    # NOTE(review): assumes the store returns a dict-like mapping.
    return {**model, "api_key": None}
@router.put("/embedding-models/{model_id}", response_model=EmbeddingModelConfig)
def update_embedding_model(model_id: str, payload: EmbeddingModelConfigUpdate, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only partial update; only fields the client actually sent are
    # forwarded to the store. A falsy store result is reported as 404.
    changes = payload.model_dump(exclude_unset=True)
    updated = embedding_model_store.update_model(model_id, changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Embedding model not found")
@router.delete("/embedding-models/{model_id}")
def delete_embedding_model(model_id: str, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only delete; a falsy store result is reported as 404.
    if embedding_model_store.delete_model(model_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Embedding model not found")
@router.post("/embedding-models/test")
def test_embedding_model_connection(payload: EmbeddingModelConnectionTestRequest, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only probe: issue one embeddings call against the given
    # OpenAI-compatible endpoint and report the vector length it returns.
    # Responds 400 when a required field is missing or the call raises.
    api_base = normalize_openai_base_url(payload.api_base or "")
    api_key = payload.api_key
    model_name = (payload.model or "").strip()

    # Checked in this order so the first missing field decides the message.
    required = (
        (api_base, "API Base is required"),
        (api_key, "API Key is required"),
        (model_name, "Model name is required"),
    )
    for supplied, message in required:
        if not supplied:
            raise HTTPException(status_code=400, detail=message)

    try:
        client = OpenAI(api_key=api_key, base_url=api_base)
        embedding_resp = client.embeddings.create(model=model_name, input="connection test")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Embedding call failed: {exc}")

    # The dimension stays None unless the response carries a list-valued
    # embedding on its first data item.
    dimension = None
    items = getattr(embedding_resp, "data", None)
    if items:
        vector = getattr(items[0], "embedding", None)
        if isinstance(vector, list):
            dimension = len(vector)

    return {
        "success": True,
        "message": "Connection successful",
        "model_name": model_name,
        "embedding_dimension": dimension,
    }
+302
View File
@@ -0,0 +1,302 @@
from typing import List, Optional
import io
import json
from fastapi import APIRouter, HTTPException
from fastapi import UploadFile, File, Form
from openai import OpenAI
import pandas as pd
from app.schemas.knowledge import (
KnowledgeBase,
KnowledgeBaseCreate,
KnowledgeConnectionTestRequest,
KnowledgeConnectionTestResponse,
KnowledgeGlobalConfig,
KnowledgeGlobalConfigUpdate,
KnowledgeBaseUpdate,
KnowledgeDocument,
KnowledgeDocumentCreate,
KnowledgeDocumentUpdate,
KnowledgeSearchRequest,
KnowledgeSearchResponse,
)
from app.services.knowledge_base_store import knowledge_base_store
from app.services.knowledge_global_config_store import knowledge_global_config_store
from app.services.knowledge_index import knowledge_index_service
from app.services.openai_compat import normalize_openai_base_url
router = APIRouter()
def _mask_api_key(value: Optional[str]) -> Optional[str]:
if not value:
return None
if len(value) <= 8:
return "*" * len(value)
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
def _extract_upload_text(filename: str, content: bytes) -> str:
lower = filename.lower()
if lower.endswith((".txt", ".md", ".markdown", ".json", ".yaml", ".yml", ".log", ".xml", ".html", ".htm")):
try:
return content.decode("utf-8")
except UnicodeDecodeError:
return content.decode("utf-8", errors="ignore")
if lower.endswith(".csv"):
df = pd.read_csv(io.BytesIO(content))
return df.to_csv(index=False)
if lower.endswith((".xls", ".xlsx")):
df = pd.read_excel(io.BytesIO(content))
return df.to_csv(index=False)
# 增加对 PDF 的文本提取支持
if lower.endswith(".pdf"):
try:
import PyPDF2
pdf_reader = PyPDF2.PdfReader(io.BytesIO(content))
text = []
for page in pdf_reader.pages:
page_text = page.extract_text()
if page_text:
text.append(page_text)
return "\n".join(text)
except ImportError:
raise ValueError("PyPDF2 is not installed. Cannot parse PDF files.")
except Exception as e:
raise ValueError(f"Failed to parse PDF: {str(e)}")
# 增加对 Word 文档的文本提取支持
if lower.endswith((".doc", ".docx")):
try:
import docx
doc = docx.Document(io.BytesIO(content))
return "\n".join([para.text for para in doc.paragraphs])
except ImportError:
raise ValueError("python-docx is not installed. Cannot parse Word files.")
except Exception as e:
raise ValueError(f"Failed to parse Word document: {str(e)}")
# 增加对 PPT 文档的文本提取支持
if lower.endswith((".ppt", ".pptx")):
try:
import pptx
prs = pptx.Presentation(io.BytesIO(content))
text = []
for slide in prs.slides:
for shape in slide.shapes:
if hasattr(shape, "text"):
text.append(shape.text)
return "\n".join(text)
except ImportError:
raise ValueError("python-pptx is not installed. Cannot parse PPT files.")
except Exception as e:
raise ValueError(f"Failed to parse PPT document: {str(e)}")
raise ValueError("Unsupported file type")
@router.get("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def get_knowledge_global_config():
    # Return the stored global knowledge config. The raw API key is never
    # echoed back: only a masked form and a has-key flag are exposed.
    # NOTE(review): no auth dependency is declared on this handler - confirm
    # access control is applied where the router is mounted.
    stored = knowledge_global_config_store.get()
    secret = stored.get("api_key")
    return dict(
        api_base=stored.get("api_base"),
        api_key=None,
        api_key_masked=_mask_api_key(secret),
        has_api_key=bool(secret),
        default_embedding_model=stored.get("default_embedding_model"),
    )
@router.put("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def update_knowledge_global_config(payload: KnowledgeGlobalConfigUpdate):
    # Apply the fields the client sent to the global knowledge config and
    # return the result with the API key masked instead of echoed back.
    # NOTE(review): no auth dependency is declared on this handler - confirm
    # access control is applied where the router is mounted.
    changes = payload.model_dump(exclude_unset=True)
    current = knowledge_global_config_store.update(changes)
    secret = current.get("api_key")
    return dict(
        api_base=current.get("api_base"),
        api_key=None,
        api_key_masked=_mask_api_key(secret),
        has_api_key=bool(secret),
        default_embedding_model=current.get("default_embedding_model"),
    )
@router.post("/knowledge-bases/global-config/test-connection", response_model=KnowledgeConnectionTestResponse)
def test_knowledge_global_connection(payload: KnowledgeConnectionTestRequest):
    # Probe an OpenAI-compatible embeddings endpoint with a single call.
    #
    # API base and key come from the request, falling back to the saved global
    # config; the model name must be given explicitly in the request.
    # Responds 400 when any of the three is missing or the call raises.
    saved = knowledge_global_config_store.get()
    api_base = normalize_openai_base_url(payload.api_base or saved.get("api_base") or "")
    api_key = payload.api_key or saved.get("api_key")
    model_name = (payload.model_name or "").strip()

    if not api_base:
        raise HTTPException(status_code=400, detail="API Base 未配置")
    if not api_key:
        raise HTTPException(status_code=400, detail="API Key 未配置")
    if not model_name:
        raise HTTPException(status_code=400, detail="测试连接必须显式填写向量模型名称")

    try:
        client = OpenAI(
            api_key=api_key,
            base_url=api_base,
        )
        embedding_resp = client.embeddings.create(
            model=model_name,
            input="connection test",
        )
    except Exception as exc:
        # Chain the cause so the original failure stays visible in logs.
        raise HTTPException(status_code=400, detail=f"Embedding调用失败: {exc}") from exc

    # The dimension stays None unless the first data item carries a list.
    dimension = None
    if getattr(embedding_resp, "data", None):
        first = embedding_resp.data[0]
        vector = getattr(first, "embedding", None)
        if isinstance(vector, list):
            dimension = len(vector)

    return {
        "success": True,
        "message": "连接成功,Embedding调用正常",
        "model_name": model_name,
        "embedding_dimension": dimension,
        "resolved_api_base": api_base,
        "available_models": [],
    }
@router.get("/knowledge-bases", response_model=List[KnowledgeBase])
def list_knowledge_bases(project_id: Optional[int] = None):
    # List knowledge bases; project_id is passed straight to the store
    # (None when the client gives no filter).
    knowledge_bases = knowledge_base_store.list(project_id=project_id)
    return knowledge_bases
@router.post("/knowledge-bases", response_model=KnowledgeBase)
def create_knowledge_base(payload: KnowledgeBaseCreate):
    # Create a knowledge base from the validated payload and return the
    # store's result.
    fields = payload.model_dump()
    created = knowledge_base_store.create(fields)
    return created
@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def get_knowledge_base(kb_id: str):
    # Fetch one knowledge base; a falsy store result is reported as 404.
    found = knowledge_base_store.get(kb_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.put("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def update_knowledge_base(kb_id: str, payload: KnowledgeBaseUpdate):
    # Partial update: only fields present in the request reach the store.
    # A falsy store result is reported as 404.
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update(kb_id, changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(kb_id: str):
    # Delete a knowledge base; a falsy store result is reported as 404.
    if knowledge_base_store.delete(kb_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.get("/knowledge-bases/{kb_id}/documents", response_model=List[KnowledgeDocument])
def list_knowledge_documents(kb_id: str):
    # Return the "documents" entry of a knowledge base (empty list when the
    # entry is absent); 404 when the knowledge base itself is not found.
    parent = knowledge_base_store.get(kb_id)
    if parent:
        return parent.get("documents", [])
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.post("/knowledge-bases/{kb_id}/documents", response_model=KnowledgeDocument)
def create_knowledge_document(kb_id: str, payload: KnowledgeDocumentCreate):
    # Add a document to a knowledge base. A falsy store result is reported as
    # "Knowledge base not found" (404).
    fields = payload.model_dump()
    created = knowledge_base_store.create_document(kb_id=kb_id, payload=fields)
    if created:
        return created
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@router.put("/knowledge-bases/{kb_id}/documents/{doc_id}", response_model=KnowledgeDocument)
def update_knowledge_document(kb_id: str, doc_id: str, payload: KnowledgeDocumentUpdate):
    # Partial update of one document; only fields the client sent are
    # forwarded. A falsy store result is reported as 404.
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update_document(kb_id=kb_id, doc_id=doc_id, payload=changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge document not found")
@router.delete("/knowledge-bases/{kb_id}/documents/{doc_id}")
def delete_knowledge_document(kb_id: str, doc_id: str):
    # Remove one document; a falsy store result is reported as 404.
    if knowledge_base_store.delete_document(kb_id=kb_id, doc_id=doc_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge document not found")
@router.post("/knowledge-bases/{kb_id}/reindex")
def reindex_knowledge_base(kb_id: str):
    # Rebuild the index for a knowledge base and return the service's result.
    # A ValueError from the index service is surfaced as 404 with its message.
    try:
        outcome = knowledge_index_service.reindex(kb_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
    return outcome
@router.post("/knowledge-bases/{kb_id}/search", response_model=KnowledgeSearchResponse)
def search_knowledge_base(kb_id: str, payload: KnowledgeSearchRequest):
    # Run a search against one knowledge base via the index service.
    # A ValueError from the service is surfaced as 404 with its message.
    search_args = {"kb_id": kb_id, "query": payload.query, "top_k": payload.top_k}
    try:
        return knowledge_index_service.search(**search_args)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
@router.post("/knowledge-bases/{kb_id}/documents/upload")
async def upload_knowledge_documents(
    kb_id: str,
    files: List[UploadFile] = File(...),
    metadata: Optional[str] = Form(default=None),
):
    # Upload one or more files and store each as a document of the given
    # knowledge base.
    #
    # metadata, when given, must be a JSON object (JSON null counts as "no
    # metadata"); it is merged into every created document's metadata together
    # with "source" and "filename". Empty files are skipped. Responds 404 for
    # an unknown knowledge base and 400 for bad metadata, a file above 15MB, or
    # a file whose text cannot be extracted.
    kb = knowledge_base_store.get(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    metadata_payload: dict = {}
    if metadata:
        try:
            parsed_metadata = json.loads(metadata)
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
        # Valid JSON that is not an object (list, number, string...) used to be
        # ignored silently; reject it so the caller learns it was not applied.
        if parsed_metadata is not None and not isinstance(parsed_metadata, dict):
            raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
        if parsed_metadata:
            metadata_payload = parsed_metadata

    # Limit is 15MB (raised from 5MB) to better support PDFs with images.
    max_bytes = 15 * 1024 * 1024

    # Phase 1: read and extract every file before touching the store, so a
    # rejected file cannot leave earlier files of the same request half-saved.
    extracted = []
    for file in files:
        filename = file.filename or "untitled"
        content = await file.read()
        if not content:
            continue
        if len(content) > max_bytes:
            raise HTTPException(status_code=400, detail=f"文件过大 (超过 15MB): {filename}")
        try:
            text = _extract_upload_text(filename, content)
        except Exception as exc:
            raise HTTPException(status_code=400, detail=f"不支持的文件类型: {filename}") from exc
        extracted.append((filename, text))

    # Phase 2: persist the documents.
    created: List[dict] = []
    for filename, text in extracted:
        doc = knowledge_base_store.create_document(
            kb_id=kb_id,
            payload={
                "title": filename,
                "content": text,
                "metadata": {**metadata_payload, "source": "upload", "filename": filename},
            },
        )
        if doc:
            created.append(doc)
    return {"status": "success", "count": len(created), "documents": created}
+183
View File
@@ -0,0 +1,183 @@
import json
import os
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import jwt, JWTError
from pydantic import BaseModel, Field
from app.core.security import SECRET_KEY, ALGORITHM
from app.core.data_root import get_data_root
from app.core.llm_provider import build_llm_provider
router = APIRouter()
security = HTTPBearer()
DATA_FILE = str(get_data_root() / "llm_config.json")
# Identity of the authenticated caller, as built by get_current_user from the
# decoded JWT claims.
class CurrentUser(BaseModel):
    # Filled from the token's "id" claim.
    id: int
    # Filled from the token's "sub" claim.
    username: str
    # Filled from the token's "is_admin" claim; non-admin when not supplied.
    is_admin: bool = False
# Response shape for one LLM configuration entry. The route handlers build it
# from entries of the JSON config file; api_key is blanked for non-admin
# callers before this model is constructed.
class LLMConfig(BaseModel):
    id: str = Field(..., description="Unique identifier for the LLM configuration")
    name: Optional[str] = Field(None, description="Display name")
    provider: str = Field(..., description="Provider name (e.g., openai, azure, anthropic)")
    model: str = Field(..., description="Model name (e.g., gpt-4, claude-3-opus)")
    api_key: Optional[str] = Field(None, description="API Key for the provider")
    api_base: Optional[str] = Field(None, description="Base URL for the API")
    extra_headers: Optional[Dict[str, str]] = Field(None, description="Extra headers for the request")
    # The create/update handlers clear this flag on all other entries when an
    # entry is saved as active.
    is_active: bool = Field(True, description="Whether this configuration is active")
# Request body for creating an LLM configuration. The id is supplied by the
# client; the create handler rejects an id that already exists.
class LLMConfigCreate(BaseModel):
    id: str
    name: Optional[str] = None
    provider: str
    model: str
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    extra_headers: Optional[Dict[str, str]] = None
    # Defaults to active, which makes the create handler deactivate all
    # previously stored entries.
    is_active: bool = True
# Request body for a partial update of an LLM configuration. Every field is
# optional; the update handler applies only the fields the client actually
# sent (exclude_unset), so an explicit null overwrites the stored value.
class LLMConfigUpdate(BaseModel):
    name: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    extra_headers: Optional[Dict[str, str]] = None
    is_active: Optional[bool] = None
# Request body for the connection test endpoint: the provider settings to try,
# which are passed to build_llm_provider without being saved.
class TestConnectionRequest(BaseModel):
    provider: str
    # Surrounding whitespace is stripped by the handler before use.
    model: str
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    extra_headers: Optional[Dict[str, str]] = None
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
    """Resolve the caller from the bearer token.

    Decodes the JWT with the configured secret and algorithm and builds a
    ``CurrentUser`` from its "id", "sub" and "is_admin" claims.

    Raises:
        HTTPException: 401 when the token cannot be decoded or lacks the
            "id" or "sub" claim.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    token = credentials.credentials
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    user_id, username = claims.get("id"), claims.get("sub")
    if user_id is None or username is None:
        raise unauthorized

    return CurrentUser(
        id=user_id,
        username=username,
        is_admin=bool(claims.get("is_admin", False)),
    )
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Dependency that lets only admin callers through.

    Raises:
        HTTPException: 403 when the authenticated user is not an admin.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
def _load_data() -> List[Dict[str, Any]]:
    """Read the stored LLM configurations from ``DATA_FILE``.

    Returns:
        The parsed JSON content, or an empty list when the file does not exist
        or does not contain valid JSON.
    """
    # Open directly instead of checking existence first: the file could vanish
    # between the check and the open. The explicit encoding keeps reads
    # independent of the process locale.
    try:
        with open(DATA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError:
        return []
def _save_data(data: List[Dict[str, Any]]):
    """Persist the LLM configurations to ``DATA_FILE`` as indented JSON.

    The parent directory is created when missing. The content is written to a
    sibling temporary file and then moved over the target, so an interrupted
    write cannot leave a truncated file behind (which the loader would treat as
    "no configurations").
    """
    os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
    tmp_path = f"{DATA_FILE}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    # os.replace overwrites the destination in a single rename step.
    os.replace(tmp_path, DATA_FILE)
def _sanitize_config(item: Dict[str, Any], is_admin: bool) -> Dict[str, Any]:
config = item.copy()
if not is_admin:
config["api_key"] = None
return config
@router.get("/llm", response_model=List[LLMConfig])
def list_llm_configs(current_user: CurrentUser = Depends(get_current_user)):
    # List all stored LLM configurations; API keys are blanked for non-admins.
    show_secrets = current_user.is_admin
    configs = []
    for entry in _load_data():
        configs.append(LLMConfig(**_sanitize_config(entry, show_secrets)))
    return configs
@router.get("/llm/{config_id}", response_model=LLMConfig)
def get_llm_config(config_id: str, current_user: CurrentUser = Depends(get_current_user)):
    # Return the first stored configuration with the given id (API key blanked
    # for non-admins); 404 when none matches.
    entries = _load_data()
    match = next((entry for entry in entries if entry["id"] == config_id), None)
    if match is None:
        raise HTTPException(status_code=404, detail="LLM configuration not found")
    return LLMConfig(**_sanitize_config(match, current_user.is_admin))
@router.post("/llm", response_model=LLMConfig)
def create_llm_config(config: LLMConfigCreate, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only: append a new configuration. Responds 400 when the id is
    # already taken. Saving an active entry deactivates all existing ones.
    stored = _load_data()
    for entry in stored:
        if entry["id"] == config.id:
            raise HTTPException(status_code=400, detail="LLM configuration with this ID already exists")

    new_config = config.dict()
    if new_config.get("is_active"):
        # Existing entries are switched off before the new one is appended.
        for entry in stored:
            entry["is_active"] = False

    stored.append(new_config)
    _save_data(stored)
    return LLMConfig(**new_config)
@router.put("/llm/{config_id}", response_model=LLMConfig)
def update_llm_config(config_id: str, config: LLMConfigUpdate, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only partial update of the first entry with the given id; 404 when
    # none matches. Turning an entry active deactivates every other entry.
    stored = _load_data()
    position = next((i for i, entry in enumerate(stored) if entry["id"] == config_id), None)
    if position is None:
        raise HTTPException(status_code=404, detail="LLM configuration not found")

    # Copy first: the loop below also clears the flag on the stored original,
    # and the merged copy then receives the requested value.
    merged = stored[position].copy()
    changes = config.dict(exclude_unset=True)
    if changes.get("is_active"):
        for entry in stored:
            entry["is_active"] = False
    merged.update(changes)

    stored[position] = merged
    _save_data(stored)
    return LLMConfig(**merged)
@router.delete("/llm/{config_id}")
def delete_llm_config(config_id: str, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only: drop every entry with the given id; 404 when nothing matched.
    stored = _load_data()
    remaining = [entry for entry in stored if entry["id"] != config_id]
    nothing_removed = len(remaining) == len(stored)
    if nothing_removed:
        raise HTTPException(status_code=404, detail="LLM configuration not found")
    _save_data(remaining)
    return {"message": "LLM configuration deleted successfully"}
@router.post("/llm/test")
async def test_connection(request: TestConnectionRequest, _: CurrentUser = Depends(get_admin_user)):
    # Admin-only probe: build a provider from the submitted settings, send a
    # tiny chat request and report the outcome. Any failure - including a
    # response whose finish_reason is "error" - becomes a 400.
    try:
        provider_settings = {
            "model": request.model.strip(),
            "provider": request.provider,
            "api_key": request.api_key,
            "api_base": request.api_base,
            "extra_headers": request.extra_headers,
        }
        provider = build_llm_provider(**provider_settings)
        response = await provider.chat(
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=5,
            temperature=0,
        )
        if response.finish_reason == "error":
            raise ValueError(response.content or "Unknown provider error")

        details = {
            "content": response.content,
            "finish_reason": response.finish_reason,
            "usage": response.usage,
        }
        return {"success": True, "message": "Connection successful", "details": details}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}")
+135
View File
@@ -0,0 +1,135 @@
import json
import uuid
import asyncio
from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
from app.core.data_root import get_data_root
router = APIRouter()
def get_mcp_servers_file() -> Path:
    """Return the path of the MCP server list file under the data root."""
    data_root = get_data_root()
    return data_root / "mcp_servers.json"
def read_mcp_servers() -> List[dict]:
    """Load the stored MCP server list.

    Returns an empty list when the file is missing or is not valid JSON.
    """
    path = get_mcp_servers_file()
    if path.exists():
        try:
            with open(path, "r", encoding="utf-8") as handle:
                return json.load(handle)
        except json.JSONDecodeError:
            # Unparseable content is treated like "no servers configured".
            pass
    return []
def write_mcp_servers(servers: List[dict]) -> None:
    """Persist the MCP server list as indented UTF-8 JSON.

    The parent directory is created when missing. The content is written to a
    sibling temporary file and then moved over the target, so an interrupted
    write cannot leave a truncated file (which the reader would treat as an
    empty server list).
    """
    file_path = get_mcp_servers_file()
    file_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = file_path.with_name(file_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(servers, f, indent=2, ensure_ascii=False)
    # Path.replace overwrites the destination in a single rename step.
    tmp_path.replace(file_path)
async def _check_single_mcp_health(server: dict) -> str:
    """Try to open and initialise an MCP session for one server entry.

    Returns:
        "connected" on success, otherwise a string starting with "error: ".
        Failures never propagate as exceptions.
    """
    # NOTE(review): for "stdio" entries this launches the stored command;
    # confirm that only trusted callers can create or edit server entries.
    try:
        async with AsyncExitStack() as stack:
            server_type = server.get("type")
            if server_type == "stdio":
                params = StdioServerParameters(
                    # "or" also covers keys stored with an explicit null.
                    command=server.get("command") or "",
                    args=server.get("args") or [],
                    env=server.get("env")
                )
                read, write = await stack.enter_async_context(stdio_client(params))
            elif server_type in ["sse", "streamableHttp"]:
                read, write = await stack.enter_async_context(sse_client(server.get("url", "")))
            else:
                return "error: unsupported type"

            session = await stack.enter_async_context(ClientSession(read, write))
            # Only the initialise handshake is bounded by this timeout.
            await asyncio.wait_for(session.initialize(), timeout=5.0)
            return "connected"
    except Exception as e:
        err_msg = str(e)
        if "unhandled errors in a TaskGroup" in err_msg:
            return "error: connection refused"
        # Some exceptions (e.g. the wait_for timeout) carry no message; report
        # their class name instead of an uninformative "unknown".
        return f"error: {err_msg or type(e).__name__}"
@router.get("/mcp", response_model=List[MCPServer])
async def list_mcp_servers(project_id: Optional[int] = None):
    # List MCP servers (optionally only those of one project), health-checking
    # all of them concurrently first. Statuses that changed are written back to
    # the server file.
    servers = read_mcp_servers()
    if project_id is not None:
        servers = [entry for entry in servers if entry.get("project_id") == project_id]
    if not servers:
        return []

    outcomes = await asyncio.gather(
        *(_check_single_mcp_health(entry) for entry in servers),
        return_exceptions=True,
    )

    changed = False
    for entry, outcome in zip(servers, outcomes):
        # gather(return_exceptions=True) may hand back an exception object.
        fresh = outcome if isinstance(outcome, str) else f"error: {str(outcome)}"
        if entry.get("status") != fresh:
            entry["status"] = fresh
            changed = True

    if changed:
        # Re-read the full file so servers outside the project filter are kept,
        # then copy the fresh statuses over by id.
        status_by_id = {entry.get("id"): entry["status"] for entry in servers}
        all_servers = read_mcp_servers()
        for stored in all_servers:
            stored_id = stored.get("id")
            if stored_id in status_by_id:
                stored["status"] = status_by_id[stored_id]
        write_mcp_servers(all_servers)

    return servers
@router.post("/mcp", response_model=MCPServer)
def create_mcp_server(server_in: MCPServerCreate):
    # Store a new MCP server entry under a freshly generated UUID. A missing or
    # empty status is recorded as "disconnected".
    existing = read_mcp_servers()

    entry = server_in.dict()
    entry["id"] = str(uuid.uuid4())
    if not entry.get("status"):
        entry["status"] = "disconnected"

    existing.append(entry)
    write_mcp_servers(existing)
    return entry
@router.get("/mcp/{server_id}", response_model=MCPServer)
def get_mcp_server(server_id: str):
    # Return the first stored server with the given id; 404 when none matches.
    match = next((entry for entry in read_mcp_servers() if entry.get("id") == server_id), None)
    if match is None:
        raise HTTPException(status_code=404, detail="MCP Server not found")
    return match
@router.put("/mcp/{server_id}", response_model=MCPServer)
def update_mcp_server(server_id: str, server_in: MCPServerUpdate):
    # Partial update of the first stored server with the given id: only fields
    # the client sent are overwritten. 404 when no entry matches.
    servers = read_mcp_servers()
    target = next((entry for entry in servers if entry.get("id") == server_id), None)
    if target is None:
        raise HTTPException(status_code=404, detail="MCP Server not found")

    # target is the dict held in the list, so updating it updates the list.
    target.update(server_in.dict(exclude_unset=True))
    write_mcp_servers(servers)
    return target
@router.delete("/mcp/{server_id}")
def delete_mcp_server(server_id: str):
    # Remove every stored server with the given id; 404 when nothing matched.
    servers = read_mcp_servers()
    kept = [entry for entry in servers if entry.get("id") != server_id]
    nothing_removed = len(kept) == len(servers)
    if nothing_removed:
        raise HTTPException(status_code=404, detail="MCP Server not found")
    write_mcp_servers(kept)
    return {"status": "success"}
+92
View File
@@ -0,0 +1,92 @@
from typing import List
from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.project import Project
from app.schemas.project import ProjectCreate, ProjectUpdate, Project as ProjectSchema
from app.core.security import get_current_user, CurrentUser
router = APIRouter()
@router.get("/projects", response_model=List[ProjectSchema])
def list_projects(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Page through projects: admins see all of them, other users only the
    # projects they own.
    visible = db.query(Project)
    if not current_user.is_admin:
        visible = visible.filter(Project.owner_id == current_user.id)
    page = visible.offset(skip).limit(limit)
    return page.all()
@router.post("/projects", response_model=ProjectSchema)
def create_project(
    project: ProjectCreate,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Create a project owned by the authenticated caller.
    fields = project.dict()
    record = Project(**fields, owner_id=current_user.id)
    db.add(record)
    db.commit()
    # Reload so database-generated columns are populated on the returned row.
    db.refresh(record)
    return record
@router.get("/projects/{project_id}", response_model=ProjectSchema)
def read_project(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Fetch one project. 404 when it does not exist; 403 when the caller is
    # neither an admin nor its owner.
    record = db.query(Project).filter(Project.id == project_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Project not found")

    allowed = current_user.is_admin or record.owner_id == current_user.id
    if not allowed:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return record
@router.put("/projects/{project_id}", response_model=ProjectSchema)
def update_project(
    project_id: int,
    project: ProjectUpdate,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Partial update of a project: only fields present in the request are
    # written. 404 when the project does not exist; 403 when the caller is
    # neither an admin nor its owner.
    record = db.query(Project).filter(Project.id == project_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Project not found")

    allowed = current_user.is_admin or record.owner_id == current_user.id
    if not allowed:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    changes = project.dict(exclude_unset=True)
    for field_name in changes:
        setattr(record, field_name, changes[field_name])

    db.add(record)
    db.commit()
    db.refresh(record)
    return record
@router.delete("/projects/{project_id}")
def delete_project(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    # Delete a project. 404 when it does not exist; 403 when the caller is
    # neither an admin nor its owner.
    record = db.query(Project).filter(Project.id == project_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Project not found")

    allowed = current_user.is_admin or record.owner_id == current_user.id
    if not allowed:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    db.delete(record)
    db.commit()
    return {"status": "success"}
+146
View File
@@ -0,0 +1,146 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from pydantic import BaseModel
from app.database import get_db
from app.models.datasource import DataSource
from app.schemas.mdl import MDLManifest
from app.services.mdl import MDLService
from app.connectors.factory import get_connector
router = APIRouter(tags=["semantic"])
# Request body with optional table/column selections; both fields default to
# None when the client does not send them.
class GenerateMDLRequest(BaseModel):
    # presumably the names of the tables to include - confirm against the handler
    selected_tables: Optional[List[str]] = None
    # presumably maps a table name to the columns to include - confirm against the handler
    selected_columns: Optional[Dict[str, List[str]]] = None
# Response body bundling one model definition, its relationships and a list of
# preview rows. All three are free-form dictionaries as far as this schema is
# concerned.
class ModelDetailResponse(BaseModel):
    model: Dict[str, Any]
    relationships: List[Dict[str, Any]]
    # presumably sample rows queried from the data source - confirm against the handler
    preview_rows: List[Dict[str, Any]]
def _normalize_query_result(results: Any) -> List[Dict[str, Any]]:
if isinstance(results, list):
if results and isinstance(results[0], dict):
return results
if results and isinstance(results[0], (list, tuple)):
return [dict(enumerate(row)) for row in results]
return []
if isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
return [dict(zip(col_names, row)) for row in rows]
return []
@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]])
def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)):
    """Return the raw schema of a datasource as a table -> columns mapping.

    Responds with 404 when the datasource does not exist and 500 when the
    schema cannot be read.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        columns_by_table = {}
        for table_name, table_info in MDLService.get_raw_schema(datasource).items():
            # A table entry is either {"columns": [...]} or the column list itself;
            # entries of any other shape are left out.
            if isinstance(table_info, dict) and "columns" in table_info:
                columns_by_table[table_name] = table_info["columns"]
            elif isinstance(table_info, list):
                columns_by_table[table_name] = table_info
        return columns_by_table
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/semantic/{datasource_id}", response_model=MDLManifest)
def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)):
    """Return the MDL manifest of a datasource, creating it if necessary.

    Responds with 404 when the datasource does not exist and 500 when the
    manifest cannot be loaded or generated.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        return MDLService.get_or_create_mdl(datasource_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.put("/semantic/{datasource_id}", response_model=MDLManifest)
def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)):
    """Persist the submitted MDL manifest for a datasource and echo it back.

    Responds with 404 when the datasource does not exist and 500 when
    saving fails.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        MDLService.save_mdl(datasource_id, mdl)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return mdl
@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest)
def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)):
    """Regenerate the default MDL manifest for a datasource and save it.

    The optional body restricts generation to the selected tables and
    columns. Responds with 404 when the datasource does not exist and 500
    when generation or saving fails.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        # Without a body, both selections stay None.
        tables = None
        columns = None
        if request:
            tables = request.selected_tables
            columns = request.selected_columns
        manifest = MDLService.generate_default_mdl(
            datasource,
            selected_tables=tables,
            selected_columns=columns,
        )
        MDLService.save_mdl(datasource_id, manifest)
        return manifest
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse)
def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)):
    """Return one model of a datasource's MDL with its relationships and sample rows.

    ``limit`` is clamped to the range 1..100. The row preview is best
    effort: when the query fails, ``preview_rows`` is empty and the rest
    of the response is still returned. Responds with 404 when the
    datasource or the model does not exist.
    """
    import logging  # function-scope import: this module does not import logging

    ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not ds:
        raise HTTPException(status_code=404, detail="DataSource not found")

    mdl = MDLService.get_or_create_mdl(datasource_id)
    model = next((m for m in mdl.models if m.name == model_name), None)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    relationships = [
        {
            "name": rel.name,
            "models": rel.models,
            "joinType": rel.joinType,
            "condition": rel.condition,
            "properties": rel.properties,
        }
        for rel in mdl.relationships
        if model_name in rel.models
    ]

    preview_rows: List[Dict[str, Any]] = []
    try:
        connector = get_connector(ds)
        table_name = model.tableReference.table if model.tableReference else model.name
        # The table name comes from the stored manifest, which clients can
        # edit. Double any embedded quote so the name cannot terminate the
        # quoted identifier and inject SQL.
        quoted_table = '"' + str(table_name).replace('"', '""') + '"'
        row_limit = max(1, min(limit, 100))
        query = f'SELECT * FROM {quoted_table} LIMIT {row_limit}'
        raw = connector.execute_query(query)
        preview_rows = _normalize_query_result(raw)
    except Exception:
        # Deliberately best effort, but leave a trace instead of failing silently.
        logging.getLogger(__name__).warning(
            "Preview query failed for model %r of datasource %s",
            model_name,
            datasource_id,
            exc_info=True,
        )
        preview_rows = []

    model_payload = {
        "name": model.name,
        "tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None,
        "primaryKey": model.primaryKey,
        "properties": model.properties,
        "columns": [c.model_dump(by_alias=True) for c in model.columns],
    }
    return ModelDetailResponse(
        model=model_payload,
        relationships=relationships,
        preview_rows=preview_rows,
    )
+562
View File
@@ -0,0 +1,562 @@
import json
import os
import shutil
import zipfile
import tarfile
import re
import yaml
from pathlib import Path
from typing import List, Optional, Dict, Any
from datetime import datetime
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from app.core.data_root import get_data_root, get_workspace_root
from nanobot.agent.skills import BUILTIN_SKILLS_DIR as NANOBOT_BUILTIN_SKILLS_DIR
# Router that the skill endpoints below are registered on.
router = APIRouter()
# JSON file holding the skill registry (read by _load_data, written by _save_data).
DATA_FILE = str(get_data_root() / "skills.json")
# Directory in which each skill lives in its own sub-directory.
SKILL_HUB_DIR = str(get_workspace_root() / "skills")
# Built-in skills located in a "skills_builtin" directory one level above this file's directory.
BACKEND_BUILTIN_SKILLS_DIR = str(Path(__file__).resolve().parents[1] / "skills_builtin")
# Stable keys stored in a skill's "source" field.
SOURCE_LOCAL_IMPORT = "local_import"
SOURCE_SYSTEM_BUILTIN = "system_builtin"
SOURCE_BACKEND_GENERATED = "backend_generated"
SOURCE_UPLOADED_FILE = "uploaded_file"
# Stable keys stored in a skill's "status" field.
STATUS_SAFE = "safe"
STATUS_LOW_RISK = "low_risk"
# Maps every accepted source label (the stable key itself plus Chinese and
# English display names) to its stable key.
_SOURCE_ALIASES = {
    SOURCE_LOCAL_IMPORT: SOURCE_LOCAL_IMPORT,
    "本地导入": SOURCE_LOCAL_IMPORT,
    "Local Import": SOURCE_LOCAL_IMPORT,
    SOURCE_SYSTEM_BUILTIN: SOURCE_SYSTEM_BUILTIN,
    "系统内置": SOURCE_SYSTEM_BUILTIN,
    "System Built-in": SOURCE_SYSTEM_BUILTIN,
    SOURCE_BACKEND_GENERATED: SOURCE_BACKEND_GENERATED,
    "后台生成": SOURCE_BACKEND_GENERATED,
    "Backend Generated": SOURCE_BACKEND_GENERATED,
    SOURCE_UPLOADED_FILE: SOURCE_UPLOADED_FILE,
    "文件上传": SOURCE_UPLOADED_FILE,
    "File Upload": SOURCE_UPLOADED_FILE,
}
# Same idea for status labels.
_STATUS_ALIASES = {
    STATUS_SAFE: STATUS_SAFE,
    "安全": STATUS_SAFE,
    "Safe": STATUS_SAFE,
    STATUS_LOW_RISK: STATUS_LOW_RISK,
    "低风险": STATUS_LOW_RISK,
    "Low Risk": STATUS_LOW_RISK,
}
def _normalize_source(value: Optional[str]) -> str:
    """Map a source label to its stable key.

    Empty or missing input defaults to ``SOURCE_LOCAL_IMPORT``; labels not
    found in ``_SOURCE_ALIASES`` are returned unchanged.
    """
    return _SOURCE_ALIASES.get(value, value) if value else SOURCE_LOCAL_IMPORT
def _normalize_status(value: Optional[str]) -> str:
    """Map a status label to its stable key.

    Empty or missing input defaults to ``STATUS_SAFE``; labels not found
    in ``_STATUS_ALIASES`` are returned unchanged.
    """
    return _STATUS_ALIASES.get(value, value) if value else STATUS_SAFE
def _ensure_skill_hub_dir() -> None:
    """Create the skill hub directory (and missing parents) if it does not exist."""
    Path(SKILL_HUB_DIR).mkdir(parents=True, exist_ok=True)
# Full skill record as returned by the skill endpoints. Each field carries
# its own description, which pydantic exposes in the generated schema.
class Skill(BaseModel):
    id: str = Field(..., description="Unique identifier for the skill")
    name: str = Field(..., description="Name of the skill")
    description: Optional[str] = Field(None, description="Description of what the skill does")
    content: str = Field(..., description="The content/prompt/logic of the skill")
    type: str = Field("python", description="Type of the skill (python, sql, api)")
    # None marks a skill that is not tied to a single project.
    project_id: Optional[int] = Field(None, description="The ID of the project this skill belongs to")
    source: str = Field(SOURCE_LOCAL_IMPORT, description="Stable source key of the skill")
    # Defaults to the current local date rendered with the "%Y年%m月%d" pattern.
    installation_time: str = Field(default_factory=lambda: datetime.now().strftime("%Y年%m月%d"), description="Time when the skill was installed")
    status: str = Field(STATUS_SAFE, description="Stable security status key")
    file_path: Optional[str] = Field(None, description="Path to the skill folder in skill-hub")
    is_builtin: bool = Field(False, description="Whether this is a system builtin skill")
# Request body for creating a skill. It has the same fields as Skill except
# is_builtin, which clients cannot set here.
class SkillCreate(BaseModel):
    id: str
    name: str
    description: Optional[str] = None
    content: str
    type: str = "python"
    project_id: Optional[int] = None
    source: str = SOURCE_LOCAL_IMPORT
    # When omitted, create_skill fills in the current date.
    installation_time: Optional[str] = None
    status: str = STATUS_SAFE
    # When omitted, create_skill writes a SKILL.md under the skill hub and stores that directory.
    file_path: Optional[str] = None
# Request body for a partial skill update. Every field is optional;
# update_skill applies only the fields the client actually sent.
class SkillUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    content: Optional[str] = None
    type: Optional[str] = None
    project_id: Optional[int] = None
    source: Optional[str] = None
    installation_time: Optional[str] = None
    status: Optional[str] = None
    file_path: Optional[str] = None
def _parse_skill_md(file_path: str) -> Dict[str, Any]:
"""Parse SKILL.md for metadata and content according to agentskills.io standard."""
if not os.path.exists(file_path):
return {}
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception as e:
print(f"Error reading {file_path}: {e}")
return {}
# Split YAML frontmatter and Markdown body
# Support both --- and +++ for frontmatter
metadata = {}
body = content
if content.startswith('---'):
parts = content.split('---', 2)
if len(parts) >= 3:
try:
metadata = yaml.safe_load(parts[1]) or {}
body = parts[2].strip()
except Exception as e:
print(f"Error parsing YAML frontmatter: {e}")
# Extract name and description, fallback to some defaults
name = metadata.get("name")
description = metadata.get("description")
# If name not in metadata, try to find the first H1 in markdown body
if not name:
for line in body.split('\n'):
if line.startswith('# '):
name = line[2:].strip()
break
return {
"name": name,
"description": description,
"content": body,
"metadata": metadata
}
def _load_data() -> List[Dict[str, Any]]:
    """Read the skill registry from ``DATA_FILE``.

    Returns an empty list when the file does not exist or does not contain
    valid JSON.
    """
    if os.path.exists(DATA_FILE):
        try:
            with open(DATA_FILE, "r") as registry_file:
                return json.load(registry_file)
        except (json.JSONDecodeError, FileNotFoundError):
            pass
    return []
def _save_data(data: List[Dict[str, Any]]):
    """Write the skill registry to ``DATA_FILE`` as indented JSON.

    The parent directory is created when missing. Non-ASCII characters are
    written as-is rather than escaped.
    """
    # Serialize before opening the file. Opening in "w" mode truncates it,
    # so a serialization error raised afterwards would leave a corrupt
    # registry that the loader silently treats as empty.
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
    with open(DATA_FILE, "w") as f:
        f.write(serialized)
def _dedupe_skills(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
deduped: Dict[str, Dict[str, Any]] = {}
for item in data:
skill_id = str(item.get("id") or "").strip()
project_id = item.get("project_id")
if not skill_id:
continue
# Use a composite key of (id, project_id) for deduplication
# so that different projects can theoretically have the same skill_id
dedupe_key = f"{skill_id}_{project_id}"
existing = deduped.get(dedupe_key)
if existing is None:
deduped[dedupe_key] = item
continue
# If they somehow have the exact same dedupe_key, we just keep the later one
deduped[dedupe_key] = item
return list(deduped.values())
def _safe_skill_dir_name(value: str) -> str:
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
return safe or "skill"
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
    """Write ``SKILL.md`` into *skill_dir* and return the file's path.

    The file consists of a YAML frontmatter block holding ``name`` and
    ``description`` (the latter defaults to "No description provided"),
    followed by *content* as the body. The directory is created when
    missing.
    """
    os.makedirs(skill_dir, exist_ok=True)
    skill_md_path = os.path.join(skill_dir, "SKILL.md")
    # Let the YAML library emit the frontmatter so that values containing
    # ":", "#", quotes or line breaks are quoted and read back unchanged.
    # Raw interpolation turned such values into invalid YAML.
    frontmatter = yaml.safe_dump(
        {
            "name": skill_name,
            "description": description or "No description provided",
        },
        allow_unicode=True,
        sort_keys=False,
        default_flow_style=False,
        width=float("inf"),  # never wrap long values across lines
    )
    body = content or ""
    # safe_dump output already ends with a newline.
    markdown = f"---\n{frontmatter}---\n\n{body}\n"
    with open(skill_md_path, "w", encoding="utf-8") as f:
        f.write(markdown)
    return skill_md_path
def _scan_builtin_skills(data: List[Dict[str, Any]], registered_paths: set, source_dir: str, source_name: str):
    """Register every skill folder found directly under *source_dir* as built-in.

    A folder counts as a skill when it contains ``SKILL.md``. A record that
    already exists in *data* (same id and flagged built-in, or same
    file_path) is refreshed in place; otherwise a new record is appended.
    Each handled folder is added to *registered_paths*. Both *data* and
    *registered_paths* are mutated; nothing is returned.
    """
    if not os.path.exists(source_dir):
        return

    for entry_name in os.listdir(source_dir):
        skill_dir = os.path.abspath(os.path.join(source_dir, entry_name))
        if not os.path.isdir(skill_dir):
            continue
        skill_md_path = os.path.join(skill_dir, "SKILL.md")
        if not os.path.exists(skill_md_path):
            continue

        parsed = _parse_skill_md(skill_md_path)
        skill_name = parsed.get("name") or entry_name
        description = parsed.get("description") or "No description provided"
        content = parsed.get("content") or ""

        existing = next(
            (
                record
                for record in data
                if (record.get("id") == entry_name and record.get("is_builtin"))
                or record.get("file_path") == skill_dir
            ),
            None,
        )
        if existing:
            existing.update({
                "name": skill_name,
                "description": description,
                "content": content,
                "file_path": skill_dir,
                "is_builtin": True,
                "source": source_name,
                "status": STATUS_SAFE,
            })
        else:
            data.append({
                "id": entry_name,
                "name": skill_name,
                "description": description,
                "content": content,
                "type": "agentskill",
                "project_id": None,
                "source": source_name,
                "installation_time": datetime.now().strftime("%Y年%m月%d"),
                "status": STATUS_SAFE,
                "file_path": skill_dir,
                "is_builtin": True
            })
        registered_paths.add(skill_dir)
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return all known skills, optionally restricted to one project.

    The result merges three sources:

    1. the registry file, with each record's name, description and content
       refreshed from its ``SKILL.md`` when that file exists;
    2. the built-in skill directories;
    3. unregistered folders directly under the skill hub that contain a
       ``SKILL.md``. A ``p<digits>_`` folder prefix is read as the owning
       project id.

    Records are de-duplicated on (id, project_id). When *project_id* is
    given, only that project's skills and the project-less ones are
    returned.
    """
    always_builtin_ids = ("nl2sql", "visualization")

    _ensure_skill_hub_dir()
    data = _load_data()
    registered_paths = set()

    # Sync registered skills with their SKILL.md if available.
    for item in data:
        item["source"] = _normalize_source(item.get("source"))
        item["status"] = _normalize_status(item.get("status"))
        if item.get("id") in always_builtin_ids or item.get("is_builtin"):
            item["is_builtin"] = True
        else:
            item.setdefault("is_builtin", False)
        if item.get("file_path"):
            abs_path = os.path.abspath(item["file_path"])
            registered_paths.add(abs_path)
            skill_md_path = os.path.join(abs_path, "SKILL.md")
            if os.path.exists(skill_md_path):
                metadata_res = _parse_skill_md(skill_md_path)
                if metadata_res.get("name"):
                    item["name"] = metadata_res["name"]
                if metadata_res.get("description"):
                    item["description"] = metadata_res["description"]
                if metadata_res.get("content"):
                    item["content"] = metadata_res["content"]

    # Scan builtin skills.
    _scan_builtin_skills(data, registered_paths, NANOBOT_BUILTIN_SKILLS_DIR, SOURCE_SYSTEM_BUILTIN)
    _scan_builtin_skills(data, registered_paths, BACKEND_BUILTIN_SKILLS_DIR, SOURCE_SYSTEM_BUILTIN)

    # Scan for unregistered skills in SKILL_HUB_DIR (1-level deep to match nanobot's behavior).
    if os.path.exists(SKILL_HUB_DIR):
        for item in os.listdir(SKILL_HUB_DIR):
            # Skip upload staging folders ("temp_<timestamp>_<8 hex digits>").
            # They are removed as soon as the upload finishes; registering
            # them would leave entries that point at a directory that no
            # longer exists.
            if re.fullmatch(r'temp_\d+(?:\.\d+)?_[0-9a-f]{8}', item):
                continue
            skill_dir = os.path.abspath(os.path.join(SKILL_HUB_DIR, item))
            if not os.path.isdir(skill_dir):
                continue
            skill_md_path = os.path.join(skill_dir, "SKILL.md")
            if not os.path.exists(skill_md_path) or skill_dir in registered_paths:
                continue
            metadata_res = _parse_skill_md(skill_md_path)
            # Try to deduce project_id from the directory prefix (e.g. p123_skillname).
            deduced_project_id = None
            match = re.match(r'^p(\d+)_', item)
            if match:
                deduced_project_id = int(match.group(1))
            data.append({
                "id": item,
                "name": metadata_res.get("name") or item,
                "description": metadata_res.get("description") or "No description provided",
                "content": metadata_res.get("content") or "",
                "type": "agentskill",
                "project_id": deduced_project_id,
                "source": SOURCE_BACKEND_GENERATED,
                "installation_time": datetime.now().strftime("%Y年%m月%d"),
                "status": STATUS_SAFE,
                "file_path": skill_dir,
                "is_builtin": item in always_builtin_ids
            })
            registered_paths.add(skill_dir)

    deduped = _dedupe_skills(data)
    if project_id is not None:
        return [item for item in deduped if item.get("project_id") == project_id or item.get("project_id") is None]
    return deduped
@router.get("/skills", response_model=List[Skill])
def list_skills(project_id: Optional[int] = None):
    """List all skills, or one project's skills plus the project-less ones when ``project_id`` is given."""
    return [Skill(**record) for record in load_skills(project_id)]
@router.get("/skills/{skill_id}", response_model=Skill)
def get_skill(skill_id: str, project_id: Optional[int] = None):
    """Return a single skill by id.

    Without ``project_id`` the first skill with that id is returned. With
    ``project_id`` the project's own skill is preferred; a project-less
    skill with that id is used as a fallback, so every skill that the list
    endpoint shows for a project can also be fetched for it. Responds with
    404 when nothing matches.
    """
    candidates = [item for item in load_skills() if item.get("id") == skill_id]

    if project_id is None:
        found = candidates[0] if candidates else None
    else:
        found = next((c for c in candidates if c.get("project_id") == project_id), None)
        if found is None:
            # The list endpoint also returns project-less skills, so accept them here too.
            found = next((c for c in candidates if c.get("project_id") is None), None)

    if found is None:
        raise HTTPException(status_code=404, detail="Skill not found")
    return Skill(**found)
def _find_skill_root(search_dir: str) -> Optional[str]:
    """Return the first directory under *search_dir* (in os.walk order) that contains SKILL.md."""
    for root, _dirs, files in os.walk(search_dir):
        if "SKILL.md" in files:
            return root
    return None


def _extract_tar_safely(archive: tarfile.TarFile, dest_dir: str) -> None:
    """Extract *archive* into *dest_dir*, refusing members that would land outside it."""
    if hasattr(tarfile, "data_filter"):
        # Interpreters with extraction filters reject absolute paths, ".."
        # traversal, links pointing outside the destination and device files.
        archive.extractall(dest_dir, filter="data")
        return
    # Older interpreters: validate every member by hand before extracting.
    root = os.path.realpath(dest_dir)
    for member in archive.getmembers():
        if not (member.isfile() or member.isdir()):
            raise ValueError(f"Unsupported archive member: {member.name}")
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise ValueError(f"Unsafe path in archive: {member.name}")
    archive.extractall(dest_dir)


@router.post("/skills/upload")
async def upload_skill(
    file: UploadFile = File(...),
    project_id: Optional[int] = Form(None)
):
    """Upload a skill file (SKILL.md) or a packaged skill (zip/tar.gz).

    The upload is unpacked in a temporary directory, the folder containing
    SKILL.md is copied into the skill hub under a unique name (prefixed
    with ``p<project_id>_`` for project uploads), and the skill is added to
    the registry. Responds with 400 for unsupported or malformed uploads
    and 500 for unexpected failures.
    """
    import tempfile  # function-scope import: tempfile is not imported at module level

    # Keep only the last path component of the client-supplied name, so it
    # cannot point outside the staging directory.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    print(f"Uploading skill: {filename}, project_id: {project_id}")

    is_zip = filename.endswith(".zip")
    is_tarball = filename.endswith((".tar.gz", ".tgz"))
    if not (is_zip or is_tarball or filename == "SKILL.md"):
        print(f"Unsupported file type: {filename}")
        raise HTTPException(status_code=400, detail="Only SKILL.md or packaged skills (zip/tar.gz) are supported")

    _ensure_skill_hub_dir()
    # Stage outside the skill hub: the hub is scanned for skill folders, and
    # a staging folder holding a SKILL.md would be registered as a skill.
    temp_dir = tempfile.mkdtemp(prefix="skill_upload_")

    try:
        file_path = os.path.join(temp_dir, filename)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        skill_source_dir = None
        if is_zip:
            try:
                with zipfile.ZipFile(file_path, 'r') as zip_ref:
                    zip_ref.extractall(temp_dir)
                os.remove(file_path)
                skill_source_dir = _find_skill_root(temp_dir)
            except Exception as e:
                print(f"Zip extraction failed: {e}")
                raise HTTPException(status_code=400, detail=f"Failed to extract zip: {str(e)}")
        elif is_tarball:
            try:
                with tarfile.open(file_path, 'r:gz') as tar_ref:
                    _extract_tar_safely(tar_ref, temp_dir)
                os.remove(file_path)
                skill_source_dir = _find_skill_root(temp_dir)
            except Exception as e:
                print(f"Tarball extraction failed: {e}")
                raise HTTPException(status_code=400, detail=f"Failed to extract tarball: {str(e)}")
        else:
            skill_source_dir = temp_dir

        if not skill_source_dir or not os.path.exists(os.path.join(skill_source_dir, "SKILL.md")):
            print(f"SKILL.md not found in {filename}")
            raise HTTPException(status_code=400, detail="SKILL.md not found in the uploaded file")

        # Parse metadata.
        skill_md_path = os.path.join(skill_source_dir, "SKILL.md")
        metadata_res = _parse_skill_md(skill_md_path)

        # Use the metadata name, or fall back to the file name.
        skill_name = metadata_res.get("name")
        if not skill_name:
            if filename == "SKILL.md":
                skill_name = "unnamed_skill"
            else:
                skill_name = os.path.splitext(filename)[0]

        # Build a unique, filesystem-safe id for the skill.
        safe_name = _safe_skill_dir_name(skill_name)
        final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        if project_id is not None:
            # Prefix the folder name with p{project_id}_ to distinguish projects in
            # storage without breaking nanobot's 1-level-deep skill loader.
            final_skill_id = f"p{project_id}_{final_skill_id}"
        final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
        print(f"Finalizing skill: {skill_name} -> {final_skill_dir}")

        # Copy the skill content to its final destination.
        os.makedirs(final_skill_dir, exist_ok=True)
        for item in os.listdir(skill_source_dir):
            s = os.path.join(skill_source_dir, item)
            d = os.path.join(final_skill_dir, item)
            if os.path.isdir(s):
                shutil.copytree(s, d, dirs_exist_ok=True)
            else:
                shutil.copy2(s, d)

        # Register in skills.json. load_skills() has just auto-discovered the
        # new folder; drop that record so the skill is stored exactly once.
        final_abs = os.path.abspath(final_skill_dir)
        data = [
            entry
            for entry in load_skills()
            if not (entry.get("file_path") and os.path.abspath(entry["file_path"]) == final_abs)
        ]
        new_skill = {
            "id": final_skill_id,
            "name": skill_name,
            "description": metadata_res.get("description") or "No description provided",
            "content": metadata_res.get("content") or "",
            "type": "agentskill",
            "project_id": project_id,
            "source": SOURCE_UPLOADED_FILE,
            "installation_time": datetime.now().strftime("%Y年%m月%d"),
            "status": STATUS_SAFE,
            "file_path": final_skill_dir
        }
        data.append(new_skill)
        _save_data(data)
        print(f"Skill registered successfully: {final_skill_id}")
        return new_skill
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
    finally:
        # Best-effort cleanup; a failure here must not replace the real outcome.
        shutil.rmtree(temp_dir, ignore_errors=True)
@router.post("/skills", response_model=Skill)
def create_skill(skill: SkillCreate):
    """Create a skill and add it to the registry.

    When no ``file_path`` is supplied, a ``SKILL.md`` is written to a
    folder under the skill hub whose name is the sanitized id (prefixed
    with ``p<project_id>_`` for project skills), and that folder name
    becomes the stored id. Responds with 400 when the skill already exists
    in the same project.
    """
    _ensure_skill_hub_dir()
    data = load_skills()

    new_skill_dict = skill.dict()
    new_skill_dict["source"] = _normalize_source(new_skill_dict.get("source"))
    new_skill_dict["status"] = _normalize_status(new_skill_dict.get("status"))
    if not new_skill_dict.get("installation_time"):
        new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d")

    # Work out the id and folder the skill will finally be stored under
    # before checking for duplicates. Checking only the raw request id would
    # miss a repeat, because the stored id is the sanitized one.
    final_id = skill.id
    skill_dir = None
    if not new_skill_dict.get("file_path"):
        project_id = new_skill_dict.get("project_id")
        final_id = _safe_skill_dir_name(new_skill_dict["id"])
        # Add the project prefix for storage distinction, unless already present.
        if project_id is not None and not final_id.startswith(f"p{project_id}_"):
            final_id = f"p{project_id}_{final_id}"
        skill_dir = os.path.join(SKILL_HUB_DIR, final_id)

    taken_ids = {skill.id, final_id}
    target_dir = os.path.abspath(skill_dir) if skill_dir is not None else None
    for item in data:
        same_id = item["id"] in taken_ids and item.get("project_id") == skill.project_id
        # Also refuse to overwrite a folder that already belongs to another record.
        same_dir = (
            target_dir is not None
            and bool(item.get("file_path"))
            and os.path.abspath(item["file_path"]) == target_dir
        )
        if same_id or same_dir:
            raise HTTPException(status_code=400, detail="Skill with this ID already exists in this project")

    if skill_dir is not None:
        _write_skill_markdown(
            skill_dir=skill_dir,
            skill_name=new_skill_dict["name"],
            description=new_skill_dict.get("description"),
            content=new_skill_dict.get("content", ""),
        )
        new_skill_dict["file_path"] = skill_dir
        new_skill_dict["id"] = final_id

    data.append(new_skill_dict)
    _save_data(data)
    return Skill(**new_skill_dict)
@router.put("/skills/{skill_id}", response_model=Skill)
def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] = None):
    """Apply a partial update to a skill and return the updated record.

    Only fields present in the request body are changed. When the skill
    has a ``file_path``, its ``SKILL.md`` is rewritten from the updated
    values. With ``project_id`` given, only a skill of exactly that
    project matches. Responds with 404 when nothing matches.
    """
    data = load_skills()
    for index, current in enumerate(data):
        if current["id"] != skill_id:
            continue
        if project_id is not None and current.get("project_id") != project_id:
            continue

        changes = skill.dict(exclude_unset=True)
        if "source" in changes:
            changes["source"] = _normalize_source(changes.get("source"))
        if "status" in changes:
            changes["status"] = _normalize_status(changes.get("status"))
        merged = {**current, **changes}

        if merged.get("file_path"):
            _write_skill_markdown(
                skill_dir=merged["file_path"],
                skill_name=merged.get("name") or current.get("name") or "skill",
                description=merged.get("description"),
                content=merged.get("content", ""),
            )

        data[index] = merged
        _save_data(data)
        return Skill(**merged)

    raise HTTPException(status_code=404, detail="Skill not found")
@router.delete("/skills/{skill_id}")
def delete_skill(skill_id: str, project_id: Optional[int] = None):
    """Delete a skill from the registry and remove its files from the skill hub.

    With ``project_id`` given, only skills of that project or project-less
    skills are deleted. Built-in skills cannot be deleted (400). Responds
    with 404 when nothing matches. Files are removed only when they live
    inside the skill hub; file removal is best effort.
    """
    data = load_skills()
    kept = []
    removed = []
    for item in data:
        if item["id"] != skill_id:
            kept.append(item)
            continue
        if item.get("is_builtin"):
            raise HTTPException(status_code=400, detail="Builtin skills cannot be deleted")
        if project_id is not None and item.get("project_id") not in (project_id, None):
            kept.append(item)
            continue
        removed.append(item)

    if not removed:
        raise HTTPException(status_code=404, detail="Skill not found")

    # Clean up the files of every removed record, not just the last one.
    hub_root = os.path.realpath(SKILL_HUB_DIR)
    for item in removed:
        file_path = item.get("file_path")
        if not file_path or not os.path.exists(file_path):
            continue
        # file_path can come from a request body. Only delete paths strictly
        # inside the skill hub, never an arbitrary directory on the host.
        real_path = os.path.realpath(file_path)
        try:
            inside_hub = real_path != hub_root and os.path.commonpath([hub_root, real_path]) == hub_root
        except ValueError:
            # Raised for paths on different drives.
            inside_hub = False
        if not inside_hub:
            print(f"Skipping file cleanup outside the skill hub: {file_path}")
            continue
        try:
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except Exception as e:
            print(f"Error deleting skill files at {file_path}: {e}")

    _save_data(kept)
    return {"message": "Skill deleted successfully"}
+106
View File
@@ -0,0 +1,106 @@
from typing import List
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.subagent import Subagent
from app.models.project import Project
from app.schemas.subagent import SubagentCreate, SubagentUpdate, Subagent as SubagentSchema
from app.core.security import get_current_user, CurrentUser
# Router that the sub-agent endpoints below are registered on.
router = APIRouter()
@router.get("/projects/{project_id}/subagents", response_model=List[SubagentSchema])
def list_subagents(
    project_id: int,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Return one page of a project's sub-agents.

    Responds with 404 when the project does not exist and 403 when the
    caller is neither an admin nor the project's owner.
    """
    owning_project = db.query(Project).filter(Project.id == project_id).first()
    if not owning_project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not (current_user.is_admin or owning_project.owner_id == current_user.id):
        raise HTTPException(status_code=403, detail="Not enough permissions")

    page = db.query(Subagent).filter(Subagent.project_id == project_id)
    return page.offset(skip).limit(limit).all()
@router.post("/projects/{project_id}/subagents", response_model=SubagentSchema)
def create_subagent(
    project_id: int,
    subagent: SubagentCreate,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Create a sub-agent inside a project and return the stored row.

    Responds with 404 when the project does not exist and 403 when the
    caller is neither an admin nor the project's owner.
    """
    owning_project = db.query(Project).filter(Project.id == project_id).first()
    if not owning_project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not (current_user.is_admin or owning_project.owner_id == current_user.id):
        raise HTTPException(status_code=403, detail="Not enough permissions")

    # The project id always comes from the URL, never from the request body.
    created = Subagent(**subagent.dict(), project_id=project_id)
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
@router.get("/subagents/{subagent_id}", response_model=SubagentSchema)
def read_subagent(
    subagent_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Return a single sub-agent.

    Responds with 404 when the sub-agent does not exist and 403 when the
    caller is neither an admin nor the owner of the sub-agent's project.
    """
    db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
    if db_subagent is None:
        raise HTTPException(status_code=404, detail="Subagent not found")

    project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
    # The project row may be missing. Ownership cannot be established then,
    # so non-admins are refused instead of failing on project.owner_id.
    is_owner = project is not None and project.owner_id == current_user.id
    if not current_user.is_admin and not is_owner:
        raise HTTPException(status_code=403, detail="Not enough permissions")
    return db_subagent
@router.put("/subagents/{subagent_id}", response_model=SubagentSchema)
def update_subagent(
    subagent_id: int,
    subagent: SubagentUpdate,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Apply a partial update to a sub-agent and return the refreshed row.

    Only fields explicitly set in the request body are written. Responds
    with 404 when the sub-agent does not exist and 403 when the caller is
    neither an admin nor the owner of the sub-agent's project.
    """
    db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
    if db_subagent is None:
        raise HTTPException(status_code=404, detail="Subagent not found")

    project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
    # The project row may be missing. Ownership cannot be established then,
    # so non-admins are refused instead of failing on project.owner_id.
    is_owner = project is not None and project.owner_id == current_user.id
    if not current_user.is_admin and not is_owner:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    subagent_data = subagent.dict(exclude_unset=True)
    for key, value in subagent_data.items():
        setattr(db_subagent, key, value)
    db.add(db_subagent)
    db.commit()
    db.refresh(db_subagent)
    return db_subagent
@router.delete("/subagents/{subagent_id}")
def delete_subagent(
    subagent_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Delete a sub-agent.

    Responds with 404 when the sub-agent does not exist and 403 when the
    caller is neither an admin nor the owner of the sub-agent's project.
    """
    db_subagent = db.query(Subagent).filter(Subagent.id == subagent_id).first()
    if db_subagent is None:
        raise HTTPException(status_code=404, detail="Subagent not found")

    project = db.query(Project).filter(Project.id == db_subagent.project_id).first()
    # The project row may be missing. Ownership cannot be established then,
    # so non-admins are refused instead of failing on project.owner_id.
    is_owner = project is not None and project.owner_id == current_user.id
    if not current_user.is_admin and not is_owner:
        raise HTTPException(status_code=403, detail="Not enough permissions")

    db.delete(db_subagent)
    db.commit()
    return {"status": "success"}
+73
View File
@@ -0,0 +1,73 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
import pandas as pd
import duckdb
import io
import uuid
from app.core.data_root import get_uploads_root
# Router that the file-upload endpoint below is registered on.
router = APIRouter()
# Directory that uploaded files are written to, as returned by get_uploads_root().
upload_dir = get_uploads_root()
@router.post("/upload/file")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded data file and return a short analysis of it.

    Accepted types are CSV, Excel, Parquet and SQLite. The file is saved
    under a unique name in the uploads directory. For tabular files the
    response includes row count, column names and a DuckDB ``DESCRIBE``
    summary; for SQLite files no analysis is done. If the analysis fails,
    the file is still stored and the error is reported in
    ``analysis_error``. Responds with 400 for unsupported types or empty
    files and 500 for unexpected failures.
    """
    upload_dir.mkdir(parents=True, exist_ok=True)
    allowed_extensions = ('.csv', '.xls', '.xlsx', '.parquet', '.db', '.sqlite', '.sqlite3')
    # Keep only the last path component of the client-supplied name (it may
    # be missing or contain directory parts), so the saved file always lands
    # directly in the uploads directory.
    original_name = (file.filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    filename_lower = original_name.lower()
    if not filename_lower.endswith(allowed_extensions):
        raise HTTPException(status_code=400, detail="Invalid file type. Allowed: CSV, Excel, Parquet, SQLite.")

    try:
        content = await file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Empty file is not allowed.")

        unique_filename = f"{uuid.uuid4()}-{original_name}"
        save_path = upload_dir / unique_filename
        save_path.write_bytes(content)
        file_url = f"local://{unique_filename}"

        try:
            if filename_lower.endswith(('.db', '.sqlite', '.sqlite3')):
                # SQLite files are not loaded into a DataFrame here; just report success.
                return {
                    "filename": unique_filename,
                    "url": file_url,
                    "rows": 0,
                    "columns": [],
                    "summary": "SQLite database uploaded"
                }

            file_obj = io.BytesIO(content)
            if filename_lower.endswith('.csv'):
                df = pd.read_csv(file_obj)
            elif filename_lower.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file_obj)
            else:
                # Only .parquet is left after the extension check above.
                df = pd.read_parquet(file_obj)

            duckdb_conn = duckdb.connect(database=':memory:')
            try:
                duckdb_conn.register('uploaded_file', df)
                summary = duckdb_conn.execute("DESCRIBE uploaded_file").fetchall()
            finally:
                # Release the in-memory database even when the analysis fails.
                duckdb_conn.close()

            return {
                "filename": unique_filename,
                "url": file_url,
                "rows": len(df),
                "columns": list(df.columns),
                "summary": str(summary)
            }
        except Exception as e:
            # The file is already stored; report the analysis problem without failing the upload.
            return {
                "filename": unique_filename,
                "url": file_url,
                "analysis_error": str(e)
            }
    except HTTPException:
        # Let deliberate HTTP errors (such as the 400 for an empty file)
        # through instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
+215
View File
@@ -0,0 +1,215 @@
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from typing import List
import secrets
import hashlib
from datetime import datetime, timedelta, timezone
from app.database import get_db
from app.models.user import User, EmailVerification
from app.schemas.user import UserCreate, UserUpdate, UserResponse, ResendVerificationRequest
from app.core.security import get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES
from app.core.email import send_verification_email
# Router that the authentication and user endpoints below are registered on.
router = APIRouter()
def generate_verification_token() -> str:
    """Return a fresh URL-safe random token built from 32 random bytes."""
    token_bytes = 32
    return secrets.token_urlsafe(token_bytes)
def hash_token(token: str) -> str:
    """Return the hexadecimal SHA-256 digest of *token*."""
    encoded = token.encode()
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
@router.post("/auth/login")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate with username and password and return a bearer token.

    Responds with 401 for an unknown user or wrong password and 400 for an
    inactive account. On success the response carries the access token and
    a short profile of the user.
    """
    account = db.query(User).filter(User.username == form_data.username).first()
    if not (account and verify_password(form_data.password, account.hashed_password)):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not account.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")

    token_claims = {"sub": account.username, "is_admin": account.is_admin, "id": account.id}
    access_token = create_access_token(
        data=token_claims,
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    profile = {
        "id": account.id,
        "username": account.username,
        "email": account.email,
        "avatar": account.avatar,
        "is_admin": account.is_admin,
    }
    return {"access_token": access_token, "token_type": "bearer", "user": profile}
@router.post("/auth/register", response_model=UserResponse)
def register_user(user: UserCreate, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Register a new account.

    The very first account becomes an active administrator. Every later
    account starts inactive and non-admin and is sent a verification email
    whose token is valid for 24 hours. Responds with 400 when the username
    or email is already registered.
    """
    if db.query(User).filter(User.username == user.username).first():
        raise HTTPException(status_code=400, detail="Username already registered")
    if db.query(User).filter(User.email == user.email).first():
        raise HTTPException(status_code=400, detail="Email already registered")

    hashed_password = get_password_hash(user.password)

    # Only the first account is made an admin. This endpoint takes no
    # credentials, so the client-supplied is_admin flag is ignored here;
    # honouring it would let anyone register as an administrator.
    is_first_user = db.query(User).count() == 0
    is_admin = is_first_user
    is_active = is_first_user

    db_user = User(
        username=user.username,
        email=user.email,
        avatar=user.avatar,
        hashed_password=hashed_password,
        is_active=is_active,
        is_admin=is_admin
    )
    db.add(db_user)
    db.commit()
    db.refresh(db_user)

    if not is_active:
        token = generate_verification_token()
        # Only the hash is stored; the plain token goes out by email.
        verification = EmailVerification(
            user_id=db_user.id,
            token_hash=hash_token(token),
            expires_at=datetime.now(timezone.utc) + timedelta(hours=24)
        )
        db.add(verification)
        db.commit()
        # Copy the email into a local variable so the background task does not
        # depend on a lazy load after the session has been closed.
        user_email = db_user.email
        background_tasks.add_task(send_verification_email, user_email, token)

    return db_user
@router.get("/auth/verify-email")
def verify_email(token: str, db: Session = Depends(get_db)):
    """Activate the account that a verification token belongs to.

    Responds with 400 when the token is unknown, already used or expired,
    and with 404 when the token's user no longer exists. On success the
    user is activated and the token is marked as used.
    """
    token_digest = hash_token(token)
    record = (
        db.query(EmailVerification)
        .filter(
            EmailVerification.token_hash == token_digest,
            EmailVerification.is_used == False
        )
        .first()
    )
    if not record:
        raise HTTPException(status_code=400, detail="Invalid or used token")

    # A stored expiry without timezone information is interpreted as UTC.
    deadline = record.expires_at
    if deadline.tzinfo is None:
        deadline = deadline.replace(tzinfo=timezone.utc)
    if deadline < datetime.now(timezone.utc):
        raise HTTPException(status_code=400, detail="Token expired")

    account = db.query(User).filter(User.id == record.user_id).first()
    if not account:
        raise HTTPException(status_code=404, detail="User not found")

    account.is_active = True
    record.is_used = True
    db.commit()
    return {"status": "success", "message": "Email verified successfully"}
@router.post("/auth/resend-verification")
def resend_verification(request: ResendVerificationRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Issue a new 24-hour verification token and email it to the user.

    Responds with 404 when the username is unknown and 400 when the
    account is already active.
    """
    account = db.query(User).filter(User.username == request.username).first()
    if not account:
        raise HTTPException(status_code=404, detail="User not found")
    if account.is_active:
        raise HTTPException(status_code=400, detail="User already active")

    token = generate_verification_token()
    # Only the hash is stored; the plain token goes out by email.
    record = EmailVerification(
        user_id=account.id,
        token_hash=hash_token(token),
        expires_at=datetime.now(timezone.utc) + timedelta(hours=24)
    )
    db.add(record)
    db.commit()

    # Read the email up front so the background task does not touch a
    # database session that has already been released.
    recipient = account.email
    background_tasks.add_task(send_verification_email, recipient, token)
    return {"status": "success", "message": "Verification email sent"}
@router.get("/users", response_model=List[UserResponse])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """Return one page of users.

    Args:
        skip: Number of rows to skip before the page starts.
        limit: Maximum number of rows in the page.
        db: Database session injected by FastAPI.

    Returns:
        The users in the requested window, ordered by primary key.
    """
    # OFFSET/LIMIT without ORDER BY leaves the row order up to the database,
    # so consecutive pages could overlap or skip rows. Ordering by the primary
    # key makes the pagination window stable.
    users = (
        db.query(User)
        .order_by(User.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    return users
@router.get("/users/{user_id}", response_model=UserResponse)
def read_user(user_id: int, db: Session = Depends(get_db)):
    """Fetch a single user by primary key.

    Raises:
        HTTPException: 404 if no user has the given ``user_id``.
    """
    found = db.query(User).filter(User.id == user_id).first()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="User not found")
@router.post("/users", response_model=UserResponse)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    """Create a user from the submitted payload.

    The username and then the email are checked against existing rows before
    anything is written. The password is stored only as the result of
    ``get_password_hash``; ``is_active`` and ``is_admin`` are taken from the
    payload as given.

    Raises:
        HTTPException: 400 if the username or the email is already in use.
    """
    username_owner = db.query(User).filter(User.username == user.username).first()
    if username_owner:
        raise HTTPException(status_code=400, detail="Username already registered")

    email_owner = db.query(User).filter(User.email == user.email).first()
    if email_owner:
        raise HTTPException(status_code=400, detail="Email already registered")

    new_user = User(
        username=user.username,
        email=user.email,
        avatar=user.avatar,
        hashed_password=get_password_hash(user.password),
        is_active=user.is_active,
        is_admin=user.is_admin,
    )
    db.add(new_user)
    db.commit()
    # Reload so database-populated attributes are present on the returned row.
    db.refresh(new_user)
    return new_user
@router.put("/users/{user_id}", response_model=UserResponse)
def update_user(user_id: int, user: UserUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to an existing user.

    Only fields explicitly present in the request body are touched
    (``exclude_unset=True``). A non-empty ``password`` is hashed into
    ``hashed_password``; an empty or null ``password`` is ignored. Every other
    submitted field is copied onto the row as-is.

    Raises:
        HTTPException: 404 if no user has ``user_id``; 400 if the requested
            ``username`` or ``email`` already belongs to a different user.
    """
    db_user = db.query(User).filter(User.id == user_id).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")

    update_data = user.model_dump(exclude_unset=True)

    # Apply the same duplicate checks create_user performs, so an update
    # cannot take over another account's username or email (or surface as an
    # unhandled database error if the columns carry unique constraints).
    unique_fields = (
        ("username", "Username already registered"),
        ("email", "Email already registered"),
    )
    for field, message in unique_fields:
        new_value = update_data.get(field)
        # Nothing to check when the field was not sent, is null, or is
        # unchanged for this user.
        if new_value is None or new_value == getattr(db_user, field):
            continue
        clash = (
            db.query(User)
            .filter(getattr(User, field) == new_value, User.id != user_id)
            .first()
        )
        if clash:
            raise HTTPException(status_code=400, detail=message)

    for key, value in update_data.items():
        if key == "password":
            # Only the hash is stored; a falsy password means "leave the
            # current password alone".
            if value:
                db_user.hashed_password = get_password_hash(value)
        else:
            setattr(db_user, key, value)

    db.commit()
    db.refresh(db_user)
    return db_user
@router.delete("/users/{user_id}")
def delete_user(user_id: int, db: Session = Depends(get_db)):
    """Delete the user with the given primary key.

    Returns:
        ``{"ok": True}`` once the deletion has been committed.

    Raises:
        HTTPException: 404 if no user has the given ``user_id``.
    """
    doomed = db.query(User).filter(User.id == user_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail="User not found")

    db.delete(doomed)
    db.commit()
    return {"ok": True}
+31
View File
@@ -0,0 +1,31 @@
from typing import Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from app.api.llm import get_current_user, get_admin_user, CurrentUser
from app.services.web_search_config_store import get_web_search_config, save_web_search_config
# Router on which the web-search configuration endpoints below are registered.
router = APIRouter()
class WebSearchConfigModel(BaseModel):
    # Schema used both as the request body and as the response of the
    # web-search configuration endpoints. Every field has a default, so an
    # empty payload is valid.

    # Identifier of the search backend to use.
    provider: str = Field(
        "duckduckgo",
        description="Web search provider (brave, tavily, duckduckgo, searxng, jina)",
    )
    # Credential for the selected provider; may be empty or null.
    api_key: Optional[str] = Field(
        "",
        description="API Key for the provider",
    )
    # Endpoint override; the description ties it to SearXNG.
    base_url: Optional[str] = Field(
        "",
        description="Base URL for SearXNG",
    )
    # Upper bound on the number of results requested from the provider.
    max_results: int = Field(
        5,
        description="Maximum number of search results",
    )
def _sanitize_config(config: Dict[str, Any], is_admin: bool) -> Dict[str, Any]:
sanitized = config.copy()
if not is_admin:
sanitized["api_key"] = None
return sanitized
@router.get("/web-search/config", response_model=WebSearchConfigModel)
def get_config(current_user: CurrentUser = Depends(get_current_user)):
    """Return the stored web-search configuration for the current user.

    The stored configuration is passed through ``_sanitize_config`` with the
    caller's ``is_admin`` flag before being wrapped in the response model.
    """
    stored = get_web_search_config()
    visible = _sanitize_config(stored, current_user.is_admin)
    return WebSearchConfigModel(**visible)
@router.put("/web-search/config", response_model=WebSearchConfigModel)
def update_config(
    config: WebSearchConfigModel,
    _: CurrentUser = Depends(get_admin_user),
):
    """Persist a new web-search configuration and echo it back.

    The ``get_admin_user`` dependency guards the route; its result is not used
    in the body, hence the ``_`` parameter name.
    """
    payload = config.dict()
    save_web_search_config(payload)
    # Rebuild the response from the same dict that was handed to the store.
    return WebSearchConfigModel(**payload)