feat: add data source
This commit is contained in:
+65
-17
@@ -17,9 +17,12 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -33,7 +36,7 @@ _cache_lock = threading.Lock()
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
|
||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||
|
||||
@@ -113,6 +116,8 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(file_path)
|
||||
if suffix in [".xls", ".xlsx"]:
|
||||
return pd.read_excel(file_path)
|
||||
if suffix == ".parquet":
|
||||
return pd.read_parquet(file_path)
|
||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||
|
||||
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||
@@ -153,6 +158,10 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
# If source is ds:ID, that's already a good key
|
||||
if source.startswith("ds:"):
|
||||
return source
|
||||
|
||||
if source == "postgres":
|
||||
return f"postgres:{getattr(connector, 'db_url', '')}"
|
||||
if source == "clickhouse":
|
||||
@@ -193,6 +202,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
connector = None
|
||||
schema = {}
|
||||
upload_df: Optional[pd.DataFrame] = None
|
||||
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
@@ -204,6 +214,21 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
schema = upload_payload["schema"]
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
elif request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load data source: {e}")
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
@@ -216,11 +241,14 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
if connector and not schema:
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
@@ -291,19 +319,39 @@ Let's think step by step.
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
# ClickHouse returns list of tuples, we need column names
|
||||
# But execute_query in ClickHouseConnector just returns raw results from client.execute
|
||||
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
|
||||
# Actually, without column names it's hard to format as dict.
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
|
||||
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
|
||||
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
|
||||
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
|
||||
# Let's check what connector.execute_query returns.
|
||||
# PostgresConnector returns list of dicts.
|
||||
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
|
||||
# Let's handle the ClickHouse case explicitly if possible or make it generic.
|
||||
|
||||
# If results is list of tuples/lists, we need headers.
|
||||
# Postgres returns list of dicts, so we are good.
|
||||
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
|
||||
# it returns `(result_rows, column_types_list)`.
|
||||
# So `results` here would be a tuple, not a list.
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# Likely ClickHouse (rows, columns)
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import jwt, JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
from app.database import get_db
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.datasource import DataSourceCreate, DataSourceUpdate, DataSource as DataSourceSchema, DataSourceTestRequest
|
||||
from app.core.security import SECRET_KEY, ALGORITHM
|
||||
from app.connectors.factory import get_connector_from_config
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
security = HTTPBearer()
|
||||
|
||||
class CurrentUser(BaseModel):
    """Identity of the authenticated caller as used by the route dependencies."""

    id: int
    username: str
    is_admin: bool = False  # callers are regular users unless stated otherwise
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
    """Resolve the caller from the bearer token.

    Decodes the JWT with ``SECRET_KEY``/``ALGORITHM`` and builds a
    ``CurrentUser`` from its ``id``, ``sub`` and ``is_admin`` claims.

    Raises:
        HTTPException: 401 when the token cannot be decoded, when the ``id``
            or ``sub`` claim is missing, or when the claims fail
            ``CurrentUser`` validation.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise unauthorized from exc

    user_id = payload.get("id")
    username = payload.get("sub")
    is_admin = bool(payload.get("is_admin", False))
    if user_id is None or username is None:
        raise unauthorized

    try:
        return CurrentUser(id=user_id, username=username, is_admin=is_admin)
    except ValueError as exc:
        # pydantic validation errors derive from ValueError. A claim of the
        # wrong type (e.g. a non-numeric "id") means the token is bad, so it
        # must surface as 401 rather than as an unhandled server error.
        raise unauthorized from exc
|
||||
|
||||
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Dependency that lets only administrators through.

    Returns the authenticated user unchanged when ``is_admin`` is set.

    Raises:
        HTTPException: 403 when the authenticated user is not an admin.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
|
||||
@router.get("/datasources", response_model=List[DataSourceSchema])
def list_datasources(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """List data sources, paginated with ``skip``/``limit``.

    Any authenticated user may list data sources so they can pick one to
    query. Admins receive the stored config untouched; everyone else
    receives a copy with the credential-bearing config keys removed.
    """
    datasources = db.query(DataSource).offset(skip).limit(limit).all()
    if current_user.is_admin:
        return datasources

    # "connection_string" is hidden as well: a connection URL can embed the
    # password, so leaving it in would defeat the point of the filter.
    hidden_keys = {"password", "api_key", "secret", "connection_string"}
    sanitized = []
    for ds in datasources:
        ds_dict = DataSourceSchema.from_orm(ds).dict()
        if ds_dict.get("config"):
            ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in hidden_keys}
        sanitized.append(ds_dict)
    return sanitized
|
||||
|
||||
@router.post("/datasources", response_model=DataSourceSchema)
def create_datasource(
    datasource: DataSourceCreate,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Persist a new data source (admin only) and return the stored row."""
    fields = datasource.dict()
    record = DataSource(**fields)
    db.add(record)
    db.commit()
    # Reload so that database-generated values are present on the returned object.
    db.refresh(record)
    return record
|
||||
|
||||
@router.get("/datasources/{datasource_id}", response_model=DataSourceSchema)
def read_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    current_user: CurrentUser = Depends(get_current_user)
):
    """Return one data source by id.

    Admins receive the stored config untouched; other authenticated users
    receive a copy with the credential-bearing config keys removed.

    Raises:
        HTTPException: 404 when no data source has ``datasource_id``.
    """
    db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if db_datasource is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    if current_user.is_admin:
        return db_datasource

    # "connection_string" is hidden as well: a connection URL can embed the
    # password, so leaving it in would defeat the point of the filter.
    hidden_keys = {"password", "api_key", "secret", "connection_string"}
    ds_dict = DataSourceSchema.from_orm(db_datasource).dict()
    if ds_dict.get("config"):
        ds_dict["config"] = {k: v for k, v in ds_dict["config"].items() if k not in hidden_keys}
    return ds_dict
|
||||
|
||||
@router.put("/datasources/{datasource_id}", response_model=DataSourceSchema)
def update_datasource(
    datasource_id: int,
    datasource: DataSourceUpdate,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Apply a partial update to a data source (admin only).

    Only fields that the client sent with a non-null value are written;
    omitted or null fields keep their stored value.

    Raises:
        HTTPException: 404 when no data source has ``datasource_id``.
    """
    db_datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if db_datasource is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    # exclude_none: every updatable field is optional in the request schema,
    # and writing an explicit null onto the row would be rejected at commit
    # time instead of being treated as "leave unchanged".
    update_data = datasource.dict(exclude_unset=True, exclude_none=True)
    for key, value in update_data.items():
        setattr(db_datasource, key, value)

    db.commit()
    db.refresh(db_datasource)
    return db_datasource
|
||||
|
||||
@router.delete("/datasources/{datasource_id}")
def delete_datasource(
    datasource_id: int,
    db: Session = Depends(get_db),
    _: CurrentUser = Depends(get_admin_user)
):
    """Delete a data source by id (admin only).

    Raises:
        HTTPException: 404 when no data source has ``datasource_id``.
    """
    matching = db.query(DataSource).filter(DataSource.id == datasource_id)
    target = matching.first()
    if target is None:
        raise HTTPException(status_code=404, detail="Data source not found")

    db.delete(target)
    db.commit()
    return {"ok": True}
|
||||
|
||||
@router.post("/datasources/test")
def test_datasource_connection(
    request: DataSourceTestRequest,
    _: CurrentUser = Depends(get_admin_user)
):
    """Try to connect with an unsaved data source config (admin only).

    Returns a success payload when the connector reports a working
    connection.

    Raises:
        HTTPException: 400 when the connector cannot be built, when the
            connection check raises, or when it reports failure.
    """
    try:
        connector = get_connector_from_config(request.type, request.config)
        connected = connector.test_connection()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Connection failed: {str(e)}") from e

    # Raised outside the try block so this HTTPException is not itself caught
    # by the handler above and re-wrapped into a garbled message.
    if not connected:
        raise HTTPException(status_code=400, detail="Connection failed")
    return {"success": True, "message": "Connection successful"}
|
||||
@@ -19,7 +19,7 @@ class ClickHouseConnector:
|
||||
|
||||
def execute_query(self, query: str):
    """Run ``query`` on the ClickHouse client, asking for column metadata too.

    Returns whatever ``self.client.execute(query, with_column_types=True)``
    returns; presumably a ``(rows, columns_with_types)`` pair, confirm
    against the client library in use.

    Raises:
        Exception: any error from the client is logged and re-raised.
    """
    try:
        return self.client.execute(query, with_column_types=True)
    except Exception as e:
        print(f"ClickHouse Query Error: {e}")
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
import functools
|
||||
from app.connectors.postgres import PostgresConnector
|
||||
from app.connectors.clickhouse import ClickHouseConnector
|
||||
from app.connectors.parquet import ParquetConnector
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
@functools.lru_cache(maxsize=32)
|
||||
def _get_cached_connector(ds_type: str, config_json: str):
|
||||
config = json.loads(config_json)
|
||||
|
||||
if ds_type in ["postgres", "postgresql", "supabase"]:
|
||||
# Supabase is just postgres
|
||||
db_url = config.get("connection_string") or \
|
||||
f"postgresql://{config.get('user')}:{config.get('password')}@{config.get('host')}:{config.get('port', 5432)}/{config.get('database')}"
|
||||
return PostgresConnector(db_url=db_url)
|
||||
|
||||
elif ds_type == "sqlite":
|
||||
# SQLite uses connection string usually file path
|
||||
db_url = config.get("connection_string")
|
||||
if not db_url and config.get("file_path"):
|
||||
db_url = f"sqlite:///{config.get('file_path')}"
|
||||
return PostgresConnector(db_url=db_url)
|
||||
|
||||
elif ds_type == "clickhouse":
|
||||
return ClickHouseConnector(
|
||||
host=config.get("host"),
|
||||
port=config.get("port", 9000),
|
||||
user=config.get("user", "default"),
|
||||
password=config.get("password", ""),
|
||||
database=config.get("database", "default")
|
||||
)
|
||||
|
||||
elif ds_type == "parquet":
|
||||
return ParquetConnector(file_path=config.get("file_path"))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source type: {ds_type}")
|
||||
|
||||
def get_connector(datasource: DataSource):
    """Return the connector for a stored data source.

    The config is serialized with sorted keys so that equal configs always
    map to the same cache entry in ``_get_cached_connector``.
    """
    cache_key = json.dumps(datasource.config, sort_keys=True)
    normalized_type = datasource.type.lower()
    return _get_cached_connector(normalized_type, cache_key)
|
||||
|
||||
def get_connector_from_config(ds_type: str, config: Dict[str, Any]):
    """Return a connector for a type/config pair that is not saved in the DB.

    Used for connection tests. The cached factory is reused deliberately:
    an identical config yields the same connector either way.
    """
    cache_key = json.dumps(config, sort_keys=True)
    normalized_type = ds_type.lower()
    return _get_cached_connector(normalized_type, cache_key)
|
||||
@@ -0,0 +1,58 @@
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
|
||||
class ParquetConnector:
    """Query a single Parquet file through an in-memory DuckDB connection.

    The file is exposed as a view named after the file's base name without
    its extension; the same name is the key of the dict returned by
    ``get_schema``.
    """

    def __init__(self, file_path: str):
        """Remember ``file_path``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"Parquet file not found: {self.file_path}")

    @property
    def _table_name(self) -> str:
        """View name: the file's base name with the extension removed."""
        return os.path.splitext(os.path.basename(self.file_path))[0]

    def _quoted_identifier(self) -> str:
        """The view name as a double-quoted SQL identifier."""
        # File names may contain characters (dashes, spaces, quotes) that are
        # not valid in a bare identifier.
        return '"' + self._table_name.replace('"', '""') + '"'

    def _quoted_path(self) -> str:
        """The file path as a single-quoted SQL string literal."""
        return "'" + self.file_path.replace("'", "''") + "'"

    def _open_with_view(self):
        """Open an in-memory DuckDB connection with the file's view registered."""
        conn = duckdb.connect(":memory:")
        try:
            conn.execute(
                f"CREATE OR REPLACE VIEW {self._quoted_identifier()} AS "
                f"SELECT * FROM read_parquet({self._quoted_path()})"
            )
        except Exception:
            # Do not leak the connection when the view cannot be created.
            conn.close()
            raise
        return conn

    def execute_query(self, query: str) -> List[Dict[str, Any]]:
        """Run ``query`` against the file's view and return rows as dicts.

        Raises:
            Exception: any DuckDB error is logged and re-raised.
        """
        conn = self._open_with_view()
        try:
            # DuckDB hands back a DataFrame, which is converted to records.
            return conn.execute(query).df().to_dict(orient="records")
        except Exception as e:
            print(f"Parquet Query Error: {e}")
            raise
        finally:
            conn.close()

    def get_schema(self) -> Dict[str, List[str]]:
        """Return ``{view_name: ["column (type)", ...]}``.

        Returns an empty dict when describing the view fails.
        """
        conn = self._open_with_view()
        try:
            columns = conn.execute(f"DESCRIBE {self._quoted_identifier()}").fetchall()
            return {self._table_name: [f"{col[0]} ({col[1]})" for col in columns]}
        except Exception as e:
            print(f"Error getting schema: {e}")
            return {}
        finally:
            conn.close()

    def test_connection(self) -> bool:
        """Return True when one row can be read from the file, else False."""
        try:
            conn = duckdb.connect(":memory:")
            try:
                conn.execute(f"SELECT * FROM read_parquet({self._quoted_path()}) LIMIT 1")
            finally:
                # Close on failure too; previously a failed read leaked it.
                conn.close()
            return True
        except Exception as e:
            print(f"Parquet Connection Error: {e}")
            return False
|
||||
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import Column, Integer, String, JSON, DateTime, func
|
||||
from app.database import Base
|
||||
|
||||
class DataSource(Base):
    """ORM model for a configured data source, stored in ``data_sources``."""

    __tablename__ = "data_sources"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    type = Column(String, nullable=False)
    # Connector settings, stored as a JSON document.
    config = Column(JSON, nullable=False)
    created_at = Column(DateTime, default=func.now())
    # Set on insert and refreshed on every update.
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
class DataSourceBase(BaseModel):
    """Fields shared by the data source request and response schemas."""

    name: str
    type: str  # sqlite, postgres, clickhouse, supabase, parquet
    config: Dict[str, Any]
|
||||
|
||||
class DataSourceCreate(DataSourceBase):
    """Payload for creating a data source; adds nothing to the base fields."""

    pass
|
||||
|
||||
class DataSourceUpdate(BaseModel):
    """Payload for a partial update; every field is optional."""

    name: Optional[str] = None
    type: Optional[str] = None
    config: Optional[Dict[str, Any]] = None
|
||||
|
||||
class DataSource(DataSourceBase):
    """Response schema: the base fields plus id and timestamps."""

    id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building the schema from attribute-bearing objects rather than only dicts.
        from_attributes = True
|
||||
|
||||
class DataSourceTestRequest(BaseModel):
    """Payload carrying a data source type and config that has not been saved."""

    type: str
    config: Dict[str, Any]
|
||||
Binary file not shown.
+9
-1
@@ -7,12 +7,16 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users
|
||||
from app.api import upload, llm, skills, users, datasources
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
from app.database import engine, Base
|
||||
# Import all models to ensure they are registered
|
||||
from app.models.user import User
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -24,10 +28,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize database tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
app.include_router(upload.router, prefix="/api/v1")
|
||||
app.include_router(llm.router, prefix="/api/v1")
|
||||
app.include_router(skills.router, prefix="/api/v1")
|
||||
app.include_router(users.router, prefix="/api/v1")
|
||||
app.include_router(datasources.router, prefix="/api/v1")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
|
||||
Reference in New Issue
Block a user