feat: add data source

This commit is contained in:
qixinbo
2026-03-15 19:36:02 +08:00
parent 219944f059
commit f1db709aae
14 changed files with 851 additions and 22 deletions
+65 -17
View File
@@ -17,9 +17,12 @@ if str(PROJECT_ROOT) not in sys.path:
from nanobot.providers.litellm_provider import LiteLLMProvider
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -33,7 +36,7 @@ _cache_lock = threading.Lock()
class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
session_id: Optional[str] = Field(None, description="Conversation session identifier")
@@ -113,6 +116,8 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
return pd.read_csv(file_path)
if suffix in [".xls", ".xlsx"]:
return pd.read_excel(file_path)
if suffix == ".parquet":
return pd.read_parquet(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
@@ -153,6 +158,10 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
return result_df.to_dict(orient="records")
def _build_schema_cache_key(source: str, connector: Any) -> str:
# If source is ds:ID, that's already a good key
if source.startswith("ds:"):
return source
if source == "postgres":
return f"postgres:{getattr(connector, 'db_url', '')}"
if source == "clickhouse":
@@ -193,6 +202,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = None
schema = {}
upload_df: Optional[pd.DataFrame] = None
if request.source == "postgres":
connector = postgres_connector
elif request.source == "clickhouse":
@@ -204,6 +214,21 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema = upload_payload["schema"]
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
elif request.source.startswith("ds:"):
try:
ds_id = int(request.source.split(":")[1])
db = SessionLocal()
try:
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
if not ds:
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
connector = get_connector(ds)
finally:
db.close()
except ValueError:
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load data source: {e}")
else:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
@@ -216,11 +241,14 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
if connector and not schema:
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
schema_str = json.dumps(schema, indent=2)
# 2. Get the active LLM config
@@ -291,19 +319,39 @@ Let's think step by step.
formatted_results = _execute_upload_sql(sql_query, upload_df)
else:
results = connector.execute_query(sql_query)
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
# Format results
formatted_results = []
if request.source == "postgres":
formatted_results = results
elif request.source == "clickhouse":
# ClickHouse returns list of tuples, we need column names
# But execute_query in ClickHouseConnector just returns raw results from client.execute
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
# Actually, without column names it's hard to format as dict.
# Let's assume we can just return the raw tuples for now or try to fetch column names.
# For now, let's just return as list of lists/tuples if it's not a dict
formatted_results = [list(row) for row in results]
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
# Rows arrived as bare tuples/lists without column names (for example a
# ClickHouse query executed without with_column_types=True), so there are no
# headers to build dicts from; pass them through positionally as lists.
# PostgresConnector returns a list of dicts, which the branch above handles.
# ClickHouseConnector.execute_query now calls
# client.execute(query, with_column_types=True), which returns a
# (result_rows, columns_with_types) tuple rather than a list, so that case is
# handled by the tuple branch below, not here.
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
# Likely ClickHouse (rows, columns)
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
else:
# Unknown format, try to return as is or empty
formatted_results = []
# 7. Generate Chart
chart_response = None
if formatted_results:
+149
View File
@@ -0,0 +1,149 @@
from typing import List, Dict, Any
from fastapi import APIRouter, HTTPException, Depends, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import jwt, JWTError
from sqlalchemy.orm import Session
from app.database import get_db
from app.models.datasource import DataSource
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSource as DataSourceSchema, DataSourceTestRequest
from app.core.security import SECRET_KEY, ALGORITHM
from app.connectors.factory import get_connector_from_config
from pydantic import BaseModel
# Router that collects every /datasources endpoint declared in this module.
router = APIRouter()
# Bearer-token scheme; get_current_user depends on it to obtain the raw credentials.
security = HTTPBearer()
# Identity of the caller as decoded from the bearer token by get_current_user.
class CurrentUser(BaseModel):
    # Taken from the token's "id" claim.
    id: int
    # Taken from the token's "sub" claim.
    username: str
    # Taken from the token's "is_admin" claim; a missing claim means non-admin.
    is_admin: bool = False
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
    """Decode the bearer token and return the caller it identifies.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        A ``CurrentUser`` built from the token's ``id``, ``sub`` and
        ``is_admin`` claims.

    Raises:
        HTTPException: 401 when the token cannot be decoded, when the ``id``
            or ``sub`` claim is missing, or when a claim has the wrong type.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise unauthorized from exc

    user_id = payload.get("id")
    username = payload.get("sub")
    if user_id is None or username is None:
        raise unauthorized

    try:
        return CurrentUser(
            id=user_id,
            username=username,
            is_admin=bool(payload.get("is_admin", False)),
        )
    except ValueError as exc:
        # pydantic's ValidationError derives from ValueError. A token whose
        # claims have the wrong type (e.g. a non-numeric "id") is a bad
        # credential, not a server error, so answer 401 instead of 500.
        raise unauthorized from exc
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Dependency that lets only administrators through.

    Returns the authenticated user unchanged when ``is_admin`` is set and
    raises a 403 ``HTTPException`` otherwise.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
@router.get("/datasources", response_model=List[DataSourceSchema])
def list_datasources(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """List configured data sources.

    Any authenticated user may list data sources so they can pick one to
    query. Administrators receive the stored configuration unchanged; for
    everyone else the credential-bearing entries are removed from ``config``.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Database session.
        current_user: The authenticated caller.

    Returns:
        The ORM rows for administrators, otherwise a list of dicts with
        redacted ``config`` values.
    """
    datasources = db.query(DataSource).offset(skip).limit(limit).all()
    if current_user.is_admin:
        return datasources

    # "connection_string" embeds the user name and password in the URL, so it
    # must be withheld along with the explicit secret fields. Keys are matched
    # case-insensitively so that e.g. "Password" is not leaked either.
    sensitive_keys = {"password", "api_key", "secret", "connection_string", "token"}
    sanitized = []
    for ds in datasources:
        ds_dict = DataSourceSchema.from_orm(ds).dict()
        config = ds_dict.get("config")
        if config:
            ds_dict["config"] = {
                key: value
                for key, value in config.items()
                if str(key).lower() not in sensitive_keys
            }
        sanitized.append(ds_dict)
    return sanitized
@router.post("/datasources", response_model=DataSourceSchema)
def create_datasource(
    datasource: DataSourceCreate,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Create a data source (administrators only) and return the stored row."""
    fields = datasource.dict()
    record = DataSource(**fields)
    db.add(record)
    db.commit()
    # Re-read the row so the returned object reflects what the database stored.
    db.refresh(record)
    return record
@router.get("/datasources/{datasource_id}", response_model=DataSourceSchema)
def read_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Return one data source by id.

    Administrators receive the stored configuration unchanged; for other
    authenticated users the credential-bearing entries are removed from
    ``config``.

    Raises:
        HTTPException: 404 when no data source has the given id.
    """
    db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if db_datasource is None:
        raise HTTPException(status_code=404, detail="Data source not found")
    if current_user.is_admin:
        return db_datasource

    # "connection_string" embeds the user name and password in the URL, so it
    # must be withheld along with the explicit secret fields. Keys are matched
    # case-insensitively so that e.g. "Password" is not leaked either.
    sensitive_keys = {"password", "api_key", "secret", "connection_string", "token"}
    ds_dict = DataSourceSchema.from_orm(db_datasource).dict()
    config = ds_dict.get("config")
    if config:
        ds_dict["config"] = {
            key: value
            for key, value in config.items()
            if str(key).lower() not in sensitive_keys
        }
    return ds_dict
@router.put("/datasources/{datasource_id}", response_model=DataSourceSchema)
def update_datasource(
    datasource_id: int,
    datasource: DataSourceUpdate,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Apply a partial update to a data source (administrators only).

    Raises:
        HTTPException: 404 when no data source has the given id.
    """
    record = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    # exclude_unset keeps only the fields the client actually sent, so
    # omitted fields leave the stored values untouched.
    changes = datasource.dict(exclude_unset=True)
    for field_name in changes:
        setattr(record, field_name, changes[field_name])

    db.commit()
    db.refresh(record)
    return record
@router.delete("/datasources/{datasource_id}")
def delete_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Delete a data source (administrators only).

    Raises:
        HTTPException: 404 when no data source has the given id.
    """
    record = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if record is not None:
        db.delete(record)
        db.commit()
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Data source not found")
@router.post("/datasources/test")
def test_datasource_connection(
    request: DataSourceTestRequest,
    _: CurrentUser = Depends(get_admin_user)
):
    """Check whether a data source configuration can be connected to.

    The configuration is not saved; it is only used to build a connector and
    call its ``test_connection``.

    Returns:
        ``{"success": True, "message": "Connection successful"}`` on success.

    Raises:
        HTTPException: 400 when the connector cannot be built, when the
            connection test raises, or when it reports failure.
    """
    try:
        connector = get_connector_from_config(request.type, request.config)
        reachable = connector.test_connection()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}") from e

    # Raised outside the try block on purpose: inside it, the broad handler
    # above would catch this HTTPException and re-wrap it, producing a detail
    # such as "Connection failed: 400: Connection failed".
    if not reachable:
        raise HTTPException(status_code=400, detail="Connection failed")
    return {"success": True, "message": "Connection successful"}
+1 -1
View File
@@ -19,7 +19,7 @@ class ClickHouseConnector:
def execute_query(self, query: str):
try:
return self.client.execute(query)
return self.client.execute(query, with_column_types=True)
except Exception as e:
print(f"ClickHouse Query Error: {e}")
raise e
+53
View File
@@ -0,0 +1,53 @@
from typing import Dict, Any, Optional
import json
import functools
from app.connectors.postgres import PostgresConnector
from app.connectors.clickhouse import ClickHouseConnector
from app.connectors.parquet import ParquetConnector
from app.models.datasource import DataSource
@functools.lru_cache(maxsize=32)
def _get_cached_connector(ds_type: str, config_json: str):
config = json.loads(config_json)
if ds_type in ["postgres", "postgresql", "supabase"]:
# Supabase is just postgres
db_url = config.get("connection_string") or \
f"postgresql://{config.get('user')}:{config.get('password')}@{config.get('host')}:{config.get('port', 5432)}/{config.get('database')}"
return PostgresConnector(db_url=db_url)
elif ds_type == "sqlite":
# SQLite uses connection string usually file path
db_url = config.get("connection_string")
if not db_url and config.get("file_path"):
db_url = f"sqlite:///{config.get('file_path')}"
return PostgresConnector(db_url=db_url)
elif ds_type == "clickhouse":
return ClickHouseConnector(
host=config.get("host"),
port=config.get("port", 9000),
user=config.get("user", "default"),
password=config.get("password", ""),
database=config.get("database", "default")
)
elif ds_type == "parquet":
return ParquetConnector(file_path=config.get("file_path"))
else:
raise ValueError(f"Unsupported data source type: {ds_type}")
def get_connector(datasource: DataSource):
    """Return the cached connector for a stored data source.

    The config is serialized with sorted keys so that equal configs always
    produce the same cache key regardless of key order.
    """
    cache_key = json.dumps(datasource.config, sort_keys=True)
    normalized_type = datasource.type.lower()
    return _get_cached_connector(normalized_type, cache_key)
def get_connector_from_config(ds_type: str, config: Dict[str, Any]):
    """Return a connector for an unsaved type/config pair.

    Used to test a connection before the data source is stored. The lookup
    goes through the same cache as ``get_connector``, so an identical config
    reuses an existing connector instead of building a fresh one.
    """
    cache_key = json.dumps(config, sort_keys=True)
    normalized_type = ds_type.lower()
    return _get_cached_connector(normalized_type, cache_key)
+58
View File
@@ -0,0 +1,58 @@
import duckdb
import pandas as pd
from typing import List, Dict, Any
import os
class ParquetConnector:
    """Query a single Parquet file through an in-memory DuckDB connection.

    The file is exposed to SQL as a view named after the file's base name
    without its extension (``/data/sales.parquet`` -> ``sales``). Every
    operation opens its own short-lived connection and closes it afterwards.
    """

    def __init__(self, file_path: str):
        """Remember *file_path*.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"Parquet file not found: {self.file_path}")

    @property
    def _table_name(self) -> str:
        """View name under which the file is queryable: base name, no extension."""
        return os.path.splitext(os.path.basename(self.file_path))[0]

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return *value* as a single-quoted SQL string literal."""
        return "'" + value.replace("'", "''") + "'"

    def _open_view(self):
        """Open an in-memory connection with the file registered as a view.

        The view name and the path are quoted so that file names containing
        spaces, hyphens or quote characters still produce valid SQL. The
        connection is closed again if the view cannot be created.
        """
        conn = duckdb.connect(":memory:")
        try:
            conn.execute(
                f"CREATE OR REPLACE VIEW {self._quote_identifier(self._table_name)} "
                f"AS SELECT * FROM read_parquet({self._quote_literal(str(self.file_path))})"
            )
        except Exception:
            conn.close()
            raise
        return conn

    def execute_query(self, query: str) -> List[Dict[str, Any]]:
        """Run *query* against the file's view and return rows as dicts.

        Raises:
            Exception: Whatever DuckDB raises for a failing query; the error
                is printed and then re-raised unchanged.
        """
        conn = self._open_view()
        try:
            # DuckDB hands back a DataFrame, which is converted to row dicts.
            return conn.execute(query).df().to_dict(orient="records")
        except Exception as e:
            print(f"Parquet Query Error: {e}")
            raise
        finally:
            conn.close()

    def get_schema(self) -> Dict[str, List[str]]:
        """Return ``{view_name: ["column (type)", ...]}`` for the file.

        Returns an empty dict if describing the view fails.
        """
        conn = self._open_view()
        table_name = self._table_name
        try:
            columns = conn.execute(
                f"DESCRIBE {self._quote_identifier(table_name)}"
            ).fetchall()
            return {table_name: [f"{col[0]} ({col[1]})" for col in columns]}
        except Exception as e:
            print(f"Error getting schema: {e}")
            return {}
        finally:
            conn.close()

    def test_connection(self) -> bool:
        """Return True if one row can be read from the file, False otherwise."""
        try:
            conn = duckdb.connect(":memory:")
            try:
                conn.execute(
                    f"SELECT * FROM read_parquet({self._quote_literal(str(self.file_path))}) LIMIT 1"
                )
            finally:
                # Close the connection even when the probe query fails.
                conn.close()
            return True
        except Exception as e:
            print(f"Parquet Connection Error: {e}")
            return False
+12
View File
@@ -0,0 +1,12 @@
from sqlalchemy import Column, Integer, String, JSON, DateTime, func
from app.database import Base
class DataSource(Base):
    """ORM model for a stored data source definition (table ``data_sources``)."""

    __tablename__ = "data_sources"

    # Integer primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Display name of the data source.
    name = Column(String, nullable=False)
    # Data source type string, used to select the connector.
    type = Column(String, nullable=False)
    # Connection settings stored as JSON.
    config = Column(JSON, nullable=False)
    # Set to the current time on insert.
    created_at = Column(DateTime, default=func.now())
    # Set to the current time on insert and again on every update.
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
+28
View File
@@ -0,0 +1,28 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from datetime import datetime
# Fields shared by the create payload and the response schema.
class DataSourceBase(BaseModel):
    # Display name of the data source.
    name: str
    # One of: sqlite, postgres, clickhouse, supabase, parquet.
    type: str
    # Connection settings; the expected keys depend on `type`.
    config: Dict[str, Any]
# Request body for creating a data source; adds nothing to DataSourceBase.
class DataSourceCreate(DataSourceBase):
    ...
# Request body for a partial update: every field is optional and defaults to None.
class DataSourceUpdate(BaseModel):
    # New display name, if it should change.
    name: Optional[str] = None
    # New data source type, if it should change.
    type: Optional[str] = None
    # Replacement connection settings, if they should change.
    config: Optional[Dict[str, Any]] = None
# Response schema: the base fields plus the id and timestamps.
class DataSource(DataSourceBase):
    id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow the schema to be populated from object attributes, not only dicts.
        from_attributes = True
# Request body for testing a connection without saving the data source.
class DataSourceTestRequest(BaseModel):
    # Data source type to build a connector for.
    type: str
    # Connection settings to try.
    config: Dict[str, Any]
Binary file not shown.
+9 -1
View File
@@ -7,12 +7,16 @@ import asyncio
import json
from datetime import datetime
from app.api import upload, llm, skills, users
from app.api import upload, llm, skills, users, datasources
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.core.nanobot import nanobot_service
from app.core.session_alias_store import session_alias_store
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
from app.database import engine, Base
# Import all models to ensure they are registered
from app.models.user import User
from app.models.datasource import DataSource
app = FastAPI()
@@ -24,10 +28,14 @@ app.add_middleware(
allow_headers=["*"],
)
# Initialize database tables
Base.metadata.create_all(bind=engine)
app.include_router(upload.router, prefix="/api/v1")
app.include_router(llm.router, prefix="/api/v1")
app.include_router(skills.router, prefix="/api/v1")
app.include_router(users.router, prefix="/api/v1")
app.include_router(datasources.router, prefix="/api/v1")
@app.on_event("startup")
async def startup_event():