feat: add data source
This commit is contained in:
+65
-17
@@ -17,9 +17,12 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -33,7 +36,7 @@ _cache_lock = threading.Lock()
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
|
||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||
|
||||
@@ -113,6 +116,8 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(file_path)
|
||||
if suffix in [".xls", ".xlsx"]:
|
||||
return pd.read_excel(file_path)
|
||||
if suffix == ".parquet":
|
||||
return pd.read_parquet(file_path)
|
||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||
|
||||
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||
@@ -153,6 +158,10 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
# If source is ds:ID, that's already a good key
|
||||
if source.startswith("ds:"):
|
||||
return source
|
||||
|
||||
if source == "postgres":
|
||||
return f"postgres:{getattr(connector, 'db_url', '')}"
|
||||
if source == "clickhouse":
|
||||
@@ -193,6 +202,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
connector = None
|
||||
schema = {}
|
||||
upload_df: Optional[pd.DataFrame] = None
|
||||
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
@@ -204,6 +214,21 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
schema = upload_payload["schema"]
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
elif request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load data source: {e}")
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
@@ -216,11 +241,14 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
if connector and not schema:
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
@@ -291,19 +319,39 @@ Let's think step by step.
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
# ClickHouse returns list of tuples, we need column names
|
||||
# But execute_query in ClickHouseConnector just returns raw results from client.execute
|
||||
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
|
||||
# Actually, without column names it's hard to format as dict.
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
|
||||
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
|
||||
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
|
||||
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
|
||||
# Let's check what connector.execute_query returns.
|
||||
# PostgresConnector returns list of dicts.
|
||||
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
|
||||
# Let's handle the ClickHouse case explicitly if possible or make it generic.
|
||||
|
||||
# If results is list of tuples/lists, we need headers.
|
||||
# Postgres returns list of dicts, so we are good.
|
||||
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
|
||||
# it returns `(result_rows, column_types_list)`.
|
||||
# So `results` here would be a tuple, not a list.
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# Likely ClickHouse (rows, columns)
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSource as DataSourceSchema, DataSourceTestRequest
|
||||
from app.core.security import SECRET_KEY, ALGORITHM
|
||||
from app.connectors.factory import get_connector_from_config
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
class CurrentUser(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
is_admin: bool = False
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
|
||||
unauthorized = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise unauthorized
|
||||
user_id = payload.get("id")
|
||||
username = payload.get("sub")
|
||||
is_admin = bool(payload.get("is_admin", False))
|
||||
if user_id is None or username is None:
|
||||
raise unauthorized
|
||||
return CurrentUser(id=user_id, username=username, is_admin=is_admin)
|
||||
|
||||
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
@router.get("/datasources", response_model=List[DataSourceSchema])
|
||||
def list_datasources(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
# Admin can see all, regular user might only see allowed ones?
|
||||
# For now, let's assume only admin can manage, but maybe regular users can see them to use?
|
||||
# The requirement says "Add data source config in Admin User Center", implying management is admin-only.
|
||||
# But usage in chat should be available to users.
|
||||
# Let's allow read for all authenticated users for now.
|
||||
datasources = db.query(DataSource).offset(skip).limit(limit).all()
|
||||
|
||||
# Hide sensitive info for non-admins if necessary, but config usually contains secrets.
|
||||
# Maybe we should return a sanitized version for regular users?
|
||||
# For now, return full config but only to admins?
|
||||
# Or just assume the API is secure.
|
||||
# If regular users need to select datasource, they just need ID and Name.
|
||||
if not current_user.is_admin:
|
||||
# Sanitize config
|
||||
sanitized = []
|
||||
for ds in datasources:
|
||||
ds_dict = DataSourceSchema.from_orm(ds).dict()
|
||||
# Remove sensitive fields from config
|
||||
if ds_dict.get("config"):
|
||||
ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in ["password", "api_key", "secret"]}
|
||||
sanitized.append(ds_dict)
|
||||
return sanitized
|
||||
|
||||
return datasources
|
||||
|
||||
@router.post("/datasources", response_model=DataSourceSchema)
|
||||
def create_datasource(
|
||||
datasource: DataSourceCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
db_datasource = DataSource(**datasource.dict())
|
||||
db.add(db_datasource)
|
||||
db.commit()
|
||||
db.refresh(db_datasource)
|
||||
return db_datasource
|
||||
|
||||
@router.get("/datasources/{datasource_id}", response_model=DataSourceSchema)
|
||||
def read_datasource(
|
||||
datasource_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
if not current_user.is_admin:
|
||||
ds_dict = DataSourceSchema.from_orm(db_datasource).dict()
|
||||
if ds_dict.get("config"):
|
||||
ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in ["password", "api_key", "secret"]}
|
||||
return ds_dict
|
||||
|
||||
return db_datasource
|
||||
|
||||
@router.put("/datasources/{datasource_id}", response_model=DataSourceSchema)
|
||||
def update_datasource(
|
||||
datasource_id: int,
|
||||
datasource: DataSourceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
update_data = datasource.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_datasource, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_datasource)
|
||||
return db_datasource
|
||||
|
||||
@router.delete("/datasources/{datasource_id}")
|
||||
def delete_datasource(
|
||||
datasource_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
|
||||
if db_datasource is None:
|
||||
raise HTTPException(status_code=404, detail="Data source not found")
|
||||
|
||||
db.delete(db_datasource)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/datasources/test")
|
||||
def test_datasource_connection(
|
||||
request: DataSourceTestRequest,
|
||||
_: CurrentUser = Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
connector = get_connector_from_config(request.type, request.config)
|
||||
if connector.test_connection():
|
||||
return {"success": True, "message": "Connection successful"}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Connection failed")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}")
|
||||
@@ -19,7 +19,7 @@ class ClickHouseConnector:
|
||||
|
||||
def execute_query(self, query: str):
|
||||
try:
|
||||
return self.client.execute(query)
|
||||
return self.client.execute(query, with_column_types=True)
|
||||
except Exception as e:
|
||||
print(f"ClickHouse Query Error: {e}")
|
||||
raise e
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
import functools
|
||||
from app.connectors.postgres import PostgresConnector
|
||||
from app.connectors.clickhouse import ClickHouseConnector
|
||||
from app.connectors.parquet import ParquetConnector
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
@functools.lru_cache(maxsize=32)
|
||||
def _get_cached_connector(ds_type: str, config_json: str):
|
||||
config = json.loads(config_json)
|
||||
|
||||
if ds_type in ["postgres", "postgresql", "supabase"]:
|
||||
# Supabase is just postgres
|
||||
db_url = config.get("connection_string") or \
|
||||
f"postgresql://{config.get('user')}:{config.get('password')}@{config.get('host')}:{config.get('port', 5432)}/{config.get('database')}"
|
||||
return PostgresConnector(db_url=db_url)
|
||||
|
||||
elif ds_type == "sqlite":
|
||||
# SQLite uses connection string usually file path
|
||||
db_url = config.get("connection_string")
|
||||
if not db_url and config.get("file_path"):
|
||||
db_url = f"sqlite:///{config.get('file_path')}"
|
||||
return PostgresConnector(db_url=db_url)
|
||||
|
||||
elif ds_type == "clickhouse":
|
||||
return ClickHouseConnector(
|
||||
host=config.get("host"),
|
||||
port=config.get("port", 9000),
|
||||
user=config.get("user", "default"),
|
||||
password=config.get("password", ""),
|
||||
database=config.get("database", "default")
|
||||
)
|
||||
|
||||
elif ds_type == "parquet":
|
||||
return ParquetConnector(file_path=config.get("file_path"))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source type: {ds_type}")
|
||||
|
||||
def get_connector(datasource: DataSource):
|
||||
# Use JSON string of config as cache key
|
||||
# Ensure stable ordering of keys
|
||||
config_str = json.dumps(datasource.config, sort_keys=True)
|
||||
return _get_cached_connector(datasource.type.lower(), config_str)
|
||||
|
||||
def get_connector_from_config(ds_type: str, config: Dict[str, Any]):
|
||||
# Helper for testing connection without saving to DB
|
||||
# We can use the cached function too, or bypass if we want fresh check
|
||||
# Usually for testing we want fresh check, so let's bypass cache or clear it if needed.
|
||||
# But reusing cache is fine if config is same.
|
||||
config_str = json.dumps(config, sort_keys=True)
|
||||
return _get_cached_connector(ds_type.lower(), config_str)
|
||||
@@ -0,0 +1,58 @@
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
|
||||
class ParquetConnector:
|
||||
def __init__(self, file_path: str):
|
||||
self.file_path = file_path
|
||||
if not os.path.exists(self.file_path):
|
||||
raise FileNotFoundError(f"Parquet file not found: {self.file_path}")
|
||||
|
||||
def execute_query(self, query: str) -> List[Dict[str, Any]]:
|
||||
conn = duckdb.connect(":memory:")
|
||||
# Register the parquet file as a view or table
|
||||
# We can use read_parquet directly in query, or register it.
|
||||
# Let's register it as 'parquet_table' for simplicity in generated SQL,
|
||||
# or we can ask LLM to use the filename.
|
||||
# A better approach for generic SQL is to register it as a table name derived from filename or just 'data'.
|
||||
table_name = os.path.splitext(os.path.basename(self.file_path))[0]
|
||||
conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')")
|
||||
|
||||
# If the query doesn't use the table name, we might have issues.
|
||||
# But usually we provide schema with table name to LLM.
|
||||
try:
|
||||
# DuckDB returns a dataframe, we convert to dict
|
||||
df = conn.execute(query).df()
|
||||
return df.to_dict(orient="records")
|
||||
except Exception as e:
|
||||
print(f"Parquet Query Error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_schema(self) -> Dict[str, List[str]]:
|
||||
conn = duckdb.connect(":memory:")
|
||||
table_name = os.path.splitext(os.path.basename(self.file_path))[0]
|
||||
conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')")
|
||||
|
||||
try:
|
||||
# Get columns
|
||||
columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
|
||||
schema = {table_name: [f"{col[0]} ({col[1]})" for col in columns]}
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
return {}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
conn = duckdb.connect(":memory:")
|
||||
conn.execute(f"SELECT * FROM read_parquet('{self.file_path}') LIMIT 1")
|
||||
conn.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Parquet Connection Error: {e}")
|
||||
return False
|
||||
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import Column, Integer, String, JSON, DateTime, func
|
||||
from app.database import Base
|
||||
|
||||
class DataSource(Base):
|
||||
__tablename__ = "data_sources"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
type = Column(String, nullable=False)
|
||||
config = Column(JSON, nullable=False)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
class DataSourceBase(BaseModel):
|
||||
name: str
|
||||
type: str # sqlite, postgres, clickhouse, supabase, parquet
|
||||
config: Dict[str, Any]
|
||||
|
||||
class DataSourceCreate(DataSourceBase):
|
||||
pass
|
||||
|
||||
class DataSourceUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
class DataSource(DataSourceBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class DataSourceTestRequest(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
Binary file not shown.
+9
-1
@@ -7,12 +7,16 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users
|
||||
from app.api import upload, llm, skills, users, datasources
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
from app.database import engine, Base
|
||||
# Import all models to ensure they are registered
|
||||
from app.models.user import User
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -24,10 +28,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize database tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app.include_router(upload.router, prefix="/api/v1")
|
||||
app.include_router(llm.router, prefix="/api/v1")
|
||||
app.include_router(skills.router, prefix="/api/v1")
|
||||
app.include_router(users.router, prefix="/api/v1")
|
||||
app.include_router(datasources.router, prefix="/api/v1")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Settings } from "./pages/Settings";
|
||||
import { Users } from "./pages/Users";
|
||||
import { Login } from "./pages/Login";
|
||||
import { ModelConfigs } from "./pages/ModelConfigs";
|
||||
import { DataSources } from "./pages/DataSources";
|
||||
import { useAuthStore } from "./store/authStore";
|
||||
|
||||
// Protected Route Component
|
||||
@@ -91,6 +92,14 @@ function App() {
|
||||
</MainLayout>
|
||||
</ProtectedRoute>
|
||||
} />
|
||||
|
||||
<Route path="/datasources" element={
|
||||
<ProtectedRoute requireAdmin={true}>
|
||||
<MainLayout>
|
||||
<DataSources />
|
||||
</MainLayout>
|
||||
</ProtectedRoute>
|
||||
} />
|
||||
</Routes>
|
||||
</BrowserRouter>
|
||||
);
|
||||
|
||||
@@ -72,6 +72,12 @@ export function ChatInterface() {
|
||||
const [models, setModels] = useState<ModelConfig[]>([]);
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>("");
|
||||
const [modelOpen, setModelOpen] = useState(false);
|
||||
|
||||
// Data Source selection state
|
||||
const [availableDataSources, setAvailableDataSources] = useState<{id: string, name: string}[]>([
|
||||
{ id: "postgres-main", name: "PostgreSQL" },
|
||||
{ id: "clickhouse-main", name: "ClickHouse" }
|
||||
]);
|
||||
|
||||
// Try to parse active session from URL query
|
||||
const queryParams = new URLSearchParams(location.search);
|
||||
@@ -86,8 +92,21 @@ export function ChatInterface() {
|
||||
|
||||
useEffect(() => {
|
||||
fetchModels();
|
||||
fetchDataSources();
|
||||
}, []);
|
||||
|
||||
const fetchDataSources = async () => {
|
||||
try {
|
||||
const data = await api.get<Array<{id: number, name: string}>>("/api/v1/datasources");
|
||||
setAvailableDataSources(prev => [
|
||||
...prev.filter(d => !d.id.startsWith("ds:")),
|
||||
...data.map(d => ({ id: `ds:${d.id}`, name: d.name }))
|
||||
]);
|
||||
} catch (e) {
|
||||
console.error("Failed to fetch data sources", e);
|
||||
}
|
||||
};
|
||||
|
||||
const syncSessionFileContext = async (file: DataFileContext | null) => {
|
||||
try {
|
||||
await api.put(`/nanobot/sessions/${encodeURIComponent(activeSessionKey)}/context-file`, {
|
||||
@@ -506,8 +525,9 @@ export function ChatInterface() {
|
||||
onChange={(e) => setSelectedDataSource(e.target.value)}
|
||||
className="bg-transparent border-none outline-none text-sm font-medium"
|
||||
>
|
||||
<option value="postgres-main">PostgreSQL</option>
|
||||
<option value="clickhouse-main">ClickHouse</option>
|
||||
{availableDataSources.map(ds => (
|
||||
<option key={ds.id} value={ds.id}>{ds.name}</option>
|
||||
))}
|
||||
{activeDataFile?.url?.startsWith("local://") ? (
|
||||
<option value="upload-main">上传文件</option>
|
||||
) : null}
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Loader2, Check, AlertTriangle } from "lucide-react";
|
||||
|
||||
export interface DataSourceConfig {
|
||||
id?: number;
|
||||
name: string;
|
||||
type: string;
|
||||
config: Record<string, any>;
|
||||
}
|
||||
|
||||
interface DataSourceFormProps {
|
||||
initialData?: DataSourceConfig | null;
|
||||
onSubmit: (data: Omit<DataSourceConfig, "id">) => Promise<void>;
|
||||
onTest: (type: string, config: Record<string, any>) => Promise<boolean>;
|
||||
onCancel: () => void;
|
||||
}
|
||||
|
||||
export function DataSourceForm({ initialData, onSubmit, onTest, onCancel }: DataSourceFormProps) {
|
||||
const [name, setName] = useState(initialData?.name || "");
|
||||
const [type, setType] = useState(initialData?.type || "postgres");
|
||||
const [config, setConfig] = useState<Record<string, any>>(initialData?.config || {});
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testResult, setTestResult] = useState<{ success: boolean; message: string } | null>(null);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
||||
const handleConfigChange = (key: string, value: any) => {
|
||||
setConfig(prev => ({ ...prev, [key]: value }));
|
||||
};
|
||||
|
||||
const handleTest = async () => {
|
||||
setIsTesting(true);
|
||||
setTestResult(null);
|
||||
try {
|
||||
const success = await onTest(type, config);
|
||||
setTestResult({
|
||||
success,
|
||||
message: success ? "连接成功" : "连接失败",
|
||||
});
|
||||
} catch (e: any) {
|
||||
setTestResult({
|
||||
success: false,
|
||||
message: e.message || "连接失败",
|
||||
});
|
||||
} finally {
|
||||
setIsTesting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setIsSaving(true);
|
||||
try {
|
||||
await onSubmit({ name, type, config });
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const renderConfigFields = () => {
|
||||
switch (type) {
|
||||
case "postgres":
|
||||
case "supabase":
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Host</label>
|
||||
<Input
|
||||
value={config.host || ""}
|
||||
onChange={e => handleConfigChange("host", e.target.value)}
|
||||
placeholder="localhost"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Port</label>
|
||||
<Input
|
||||
type="number"
|
||||
value={config.port || 5432}
|
||||
onChange={e => handleConfigChange("port", parseInt(e.target.value))}
|
||||
placeholder="5432"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Database</label>
|
||||
<Input
|
||||
value={config.database || ""}
|
||||
onChange={e => handleConfigChange("database", e.target.value)}
|
||||
placeholder="postgres"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Username</label>
|
||||
<Input
|
||||
value={config.user || ""}
|
||||
onChange={e => handleConfigChange("user", e.target.value)}
|
||||
placeholder="postgres"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Password</label>
|
||||
<Input
|
||||
type="password"
|
||||
value={config.password || ""}
|
||||
onChange={e => handleConfigChange("password", e.target.value)}
|
||||
placeholder="••••••"
|
||||
/>
|
||||
</div>
|
||||
<div className="text-xs text-zinc-500 pt-2">
|
||||
或者使用连接字符串 (覆盖上述设置):
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Connection String</label>
|
||||
<Input
|
||||
value={config.connection_string || ""}
|
||||
onChange={e => handleConfigChange("connection_string", e.target.value)}
|
||||
placeholder="postgresql://user:pass@host:5432/db"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case "clickhouse":
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Host</label>
|
||||
<Input
|
||||
value={config.host || ""}
|
||||
onChange={e => handleConfigChange("host", e.target.value)}
|
||||
placeholder="localhost"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Port</label>
|
||||
<Input
|
||||
type="number"
|
||||
value={config.port || 9000}
|
||||
onChange={e => handleConfigChange("port", parseInt(e.target.value))}
|
||||
placeholder="9000"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Database</label>
|
||||
<Input
|
||||
value={config.database || ""}
|
||||
onChange={e => handleConfigChange("database", e.target.value)}
|
||||
placeholder="default"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Username</label>
|
||||
<Input
|
||||
value={config.user || ""}
|
||||
onChange={e => handleConfigChange("user", e.target.value)}
|
||||
placeholder="default"
|
||||
/>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Password</label>
|
||||
<Input
|
||||
type="password"
|
||||
value={config.password || ""}
|
||||
onChange={e => handleConfigChange("password", e.target.value)}
|
||||
placeholder="••••••"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case "sqlite":
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">File Path (Server Side)</label>
|
||||
<Input
|
||||
value={config.file_path || ""}
|
||||
onChange={e => handleConfigChange("file_path", e.target.value)}
|
||||
placeholder="/path/to/database.db"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case "parquet":
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">File Path (Server Side)</label>
|
||||
<Input
|
||||
value={config.file_path || ""}
|
||||
onChange={e => handleConfigChange("file_path", e.target.value)}
|
||||
placeholder="/path/to/data.parquet"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">名称</label>
|
||||
<Input
|
||||
value={name}
|
||||
onChange={e => setName(e.target.value)}
|
||||
placeholder="我的数据源"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">类型</label>
|
||||
<select
|
||||
className="w-full h-10 px-3 rounded-md border border-zinc-200 bg-white text-sm focus:outline-none focus:ring-2 focus:ring-zinc-950 focus:border-transparent"
|
||||
value={type}
|
||||
onChange={e => setType(e.target.value)}
|
||||
>
|
||||
<option value="postgres">PostgreSQL</option>
|
||||
<option value="clickhouse">ClickHouse</option>
|
||||
<option value="sqlite">SQLite</option>
|
||||
<option value="supabase">Supabase</option>
|
||||
<option value="parquet">Parquet</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div className="p-4 border border-zinc-200 rounded-lg bg-zinc-50/50">
|
||||
{renderConfigFields()}
|
||||
</div>
|
||||
|
||||
{testResult && (
|
||||
<div className={`p-3 rounded-md flex items-center gap-2 text-sm ${testResult.success ? 'bg-green-50 text-green-700' : 'bg-red-50 text-red-700'}`}>
|
||||
{testResult.success ? <Check className="h-4 w-4" /> : <AlertTriangle className="h-4 w-4" />}
|
||||
{testResult.message}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-end gap-3 pt-4">
|
||||
<Button type="button" variant="outline" onClick={onCancel}>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onClick={handleTest}
|
||||
disabled={isTesting}
|
||||
>
|
||||
{isTesting && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
测试连接
|
||||
</Button>
|
||||
<Button type="submit" disabled={isSaving}>
|
||||
{isSaving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
保存
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { Menu, LayoutDashboard, Plus, MoreVertical, User, Search, Wrench, Settings, Brain, Trash2, Pencil, Pin, Archive } from "lucide-react";
|
||||
import { Menu, LayoutDashboard, Plus, MoreVertical, User, Search, Wrench, Settings, Brain, Trash2, Pencil, Pin, Archive, Database } from "lucide-react";
|
||||
import { useState, useRef, useEffect } from "react";
|
||||
import { Link, useNavigate, useLocation } from "react-router-dom";
|
||||
import { useAuthStore } from "@/store/authStore";
|
||||
@@ -465,6 +465,17 @@ function SidebarBody() {
|
||||
<Brain className="h-4 w-4 text-zinc-500" />
|
||||
模型配置
|
||||
</button>
|
||||
|
||||
<button
|
||||
className="w-full flex items-center gap-2 px-3 py-2 text-sm text-zinc-700 hover:bg-zinc-100 transition-colors"
|
||||
onClick={() => {
|
||||
navigate("/datasources");
|
||||
setShowUserMenu(false);
|
||||
}}
|
||||
>
|
||||
<Database className="h-4 w-4 text-zinc-500" />
|
||||
数据源配置
|
||||
</button>
|
||||
|
||||
<button
|
||||
className="w-full flex items-center gap-2 px-3 py-2 text-sm text-indigo-600 hover:bg-indigo-50 transition-colors"
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { api } from "@/lib/api";
|
||||
import { DataSourceForm, type DataSourceConfig } from "@/components/DataSourceForm";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Plus, Database, Pencil, Trash2, Loader2 } from "lucide-react";
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle } from "@/components/ui/dialog";
|
||||
import { useAuthStore } from "@/store/authStore";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
|
||||
export function DataSources() {
|
||||
const [datasources, setDatasources] = useState<DataSourceConfig[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [editingDs, setEditingDs] = useState<DataSourceConfig | null>(null);
|
||||
const { user } = useAuthStore();
|
||||
const navigate = useNavigate();
|
||||
|
||||
useEffect(() => {
|
||||
if (!user?.is_admin) {
|
||||
navigate("/");
|
||||
return;
|
||||
}
|
||||
fetchDataSources();
|
||||
}, [user]);
|
||||
|
||||
const fetchDataSources = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const data = await api.get<DataSourceConfig[]>("/api/v1/datasources");
|
||||
setDatasources(data);
|
||||
} catch (e) {
|
||||
console.error("Failed to fetch data sources", e);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreate = () => {
|
||||
setEditingDs(null);
|
||||
setIsOpen(true);
|
||||
};
|
||||
|
||||
const handleEdit = (ds: DataSourceConfig) => {
|
||||
setEditingDs(ds);
|
||||
setIsOpen(true);
|
||||
};
|
||||
|
||||
const handleDelete = async (id: number) => {
|
||||
if (!window.confirm("确定要删除这个数据源吗?")) return;
|
||||
try {
|
||||
await api.delete(`/api/v1/datasources/${id}`);
|
||||
fetchDataSources();
|
||||
} catch (e) {
|
||||
console.error("Failed to delete data source", e);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (data: Omit<DataSourceConfig, "id">) => {
|
||||
try {
|
||||
if (editingDs?.id) {
|
||||
await api.put(`/api/v1/datasources/${editingDs.id}`, data);
|
||||
} else {
|
||||
await api.post("/api/v1/datasources", data);
|
||||
}
|
||||
setIsOpen(false);
|
||||
fetchDataSources();
|
||||
} catch (e) {
|
||||
console.error("Failed to save data source", e);
|
||||
alert("保存失败: " + (e as any).message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleTest = async (type: string, config: Record<string, any>) => {
|
||||
try {
|
||||
const res = await api.post<{ success: boolean; message: string }>("/api/v1/datasources/test", { type, config });
|
||||
return res.success;
|
||||
} catch (e) {
|
||||
console.error("Test connection failed", e);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="h-full flex flex-col bg-white">
|
||||
<div className="border-b border-zinc-100 px-8 py-5 flex items-center justify-between">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold text-zinc-900">数据源配置</h1>
|
||||
<p className="text-sm text-zinc-500 mt-1">管理可用于问答的数据源连接</p>
|
||||
</div>
|
||||
<Button onClick={handleCreate} className="bg-indigo-600 hover:bg-indigo-700 text-white gap-2">
|
||||
<Plus className="h-4 w-4" />
|
||||
新建数据源
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 overflow-auto p-8">
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center items-center h-64">
|
||||
<Loader2 className="h-8 w-8 animate-spin text-zinc-400" />
|
||||
</div>
|
||||
) : datasources.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center h-64 border-2 border-dashed border-zinc-200 rounded-xl bg-zinc-50/50">
|
||||
<Database className="h-10 w-10 text-zinc-300 mb-3" />
|
||||
<p className="text-zinc-500 font-medium">暂无数据源</p>
|
||||
<p className="text-zinc-400 text-sm mt-1">点击右上角按钮添加第一个数据源</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
|
||||
{datasources.map((ds) => (
|
||||
<div
|
||||
key={ds.id}
|
||||
className="group relative bg-white border border-zinc-200 rounded-xl p-5 hover:shadow-md transition-all hover:border-zinc-300"
|
||||
>
|
||||
<div className="flex items-start justify-between mb-4">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="w-10 h-10 rounded-lg bg-indigo-50 flex items-center justify-center text-indigo-600">
|
||||
<Database className="h-5 w-5" />
|
||||
</div>
|
||||
<div>
|
||||
<h3 className="font-semibold text-zinc-900">{ds.name}</h3>
|
||||
<p className="text-xs text-zinc-500 font-mono mt-0.5 uppercase">{ds.type}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-zinc-600" onClick={() => handleEdit(ds)}>
|
||||
<Pencil className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-red-600 hover:bg-red-50" onClick={() => handleDelete(ds.id!)}>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center justify-between text-sm">
|
||||
<span className="text-zinc-500">Host</span>
|
||||
<span className="font-medium text-zinc-700 truncate max-w-[150px]">
|
||||
{ds.config.host || "Local / File"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex items-center justify-between text-sm">
|
||||
<span className="text-zinc-500">Database</span>
|
||||
<span className="font-medium text-zinc-700 truncate max-w-[150px]">
|
||||
{ds.config.database || ds.config.file_path ? (ds.config.file_path?.split('/').pop()) : "-"}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Dialog open={isOpen} onOpenChange={setIsOpen}>
|
||||
<DialogContent className="sm:max-w-[600px] max-h-[90vh] overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle>{editingDs ? "编辑数据源" : "新建数据源"}</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="py-4">
|
||||
<DataSourceForm
|
||||
initialData={editingDs}
|
||||
onSubmit={handleSubmit}
|
||||
onTest={handleTest}
|
||||
onCancel={() => setIsOpen(false)}
|
||||
/>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user