feat: add data source
This commit is contained in:
+65
-17
@@ -17,9 +17,12 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -33,7 +36,7 @@ _cache_lock = threading.Lock()
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
|
||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||
|
||||
@@ -113,6 +116,8 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(file_path)
|
||||
if suffix in [".xls", ".xlsx"]:
|
||||
return pd.read_excel(file_path)
|
||||
if suffix == ".parquet":
|
||||
return pd.read_parquet(file_path)
|
||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||
|
||||
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||
@@ -153,6 +158,10 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
# If source is ds:ID, that's already a good key
|
||||
if source.startswith("ds:"):
|
||||
return source
|
||||
|
||||
if source == "postgres":
|
||||
return f"postgres:{getattr(connector, 'db_url', '')}"
|
||||
if source == "clickhouse":
|
||||
@@ -193,6 +202,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
connector = None
|
||||
schema = {}
|
||||
upload_df: Optional[pd.DataFrame] = None
|
||||
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
@@ -204,6 +214,21 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
schema = upload_payload["schema"]
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
elif request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load data source: {e}")
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
@@ -216,11 +241,14 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
if connector and not schema:
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
@@ -291,19 +319,39 @@ Let's think step by step.
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
# ClickHouse returns list of tuples, we need column names
|
||||
# But execute_query in ClickHouseConnector just returns raw results from client.execute
|
||||
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
|
||||
# Actually, without column names it's hard to format as dict.
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
|
||||
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
|
||||
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
|
||||
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
|
||||
# Let's check what connector.execute_query returns.
|
||||
# PostgresConnector returns list of dicts.
|
||||
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
|
||||
# Let's handle the ClickHouse case explicitly if possible or make it generic.
|
||||
|
||||
# If results is list of tuples/lists, we need headers.
|
||||
# Postgres returns list of dicts, so we are good.
|
||||
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
|
||||
# it returns `(result_rows, column_types_list)`.
|
||||
# So `results` here would be a tuple, not a list.
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# Likely ClickHouse (rows, columns)
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
|
||||
Reference in New Issue
Block a user