diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index e90750e..941b03f 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -24,6 +24,7 @@ from app.agent.chart import generate_chart from app.database import SessionLocal from app.models.datasource import DataSource from app.core.files import resolve_upload_file_path +from app.services.mdl import MDLService SCHEMA_CACHE_TTL_SECONDS = 300 CONNECTION_CACHE_TTL_SECONDS = 30 @@ -116,11 +117,11 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame: return pd.read_parquet(file_path) raise ValueError(f"Unsupported uploaded file type: {suffix}") -def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]: +def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]: conn = duckdb.connect(":memory:") conn.register("uploaded_file", df) columns = conn.execute("DESCRIBE uploaded_file").fetchall() - schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]} + schema = {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]} conn.close() return schema @@ -167,7 +168,7 @@ def _build_schema_cache_key(source: str, connector: Any) -> str: ) return source -def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]: +def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]: key = _build_schema_cache_key(source, connector) now = time.time() with _cache_lock: @@ -176,7 +177,7 @@ def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[s return cached["schema"] return None -def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None: +def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None: key = _build_schema_cache_key(source, connector) with _cache_lock: _schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS} @@ -247,6 +248,37 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: schema_str = json.dumps(schema, indent=2) + # Try to load MDL context + mdl_context = "" + if request.source.startswith("ds:"): + try: + ds_id = int(request.source.split(":")[1]) + mdl = MDLService.get_mdl(ds_id) + if mdl: + mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"] + + mdl_lines.append("MODELS:") + for model in mdl.models: + table_ref = model.tableReference.table if model.tableReference else model.name + desc = f" - Description: {model.properties.get('description', '')}" if model.properties.get('description') else "" + mdl_lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}") + + if model.columns: + mdl_lines.append(" Columns:") + for col in model.columns: + col_desc = f" ({col.properties.get('description')})" if col.properties.get('description') else "" + expr = f" [Calculated: {col.expression}]" if col.isCalculated else "" + mdl_lines.append(f" - {col.name} ({col.type}){col_desc}{expr}") + + if mdl.relationships: + mdl_lines.append("\nRELATIONSHIPS:") + for rel in mdl.relationships: + mdl_lines.append(f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}") + + mdl_context = "\n".join(mdl_lines) + except Exception as e: + print(f"Failed to load MDL: {e}") + # 2. Get the active LLM config llm_configs = load_llm_config() active_config = next((c for c in llm_configs if c.get("is_active")), None) @@ -270,6 +302,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: user_prompt = f""" ### DATABASE SCHEMA ### {schema_str} +{mdl_context} ### INPUTS ### User's Question: {request.query} diff --git a/backend/app/api/semantic.py b/backend/app/api/semantic.py new file mode 100644 index 0000000..3abebab --- /dev/null +++ b/backend/app/api/semantic.py @@ -0,0 +1,139 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session +from typing import Dict, Any, List, Optional +from pydantic import BaseModel + +from app.database import get_db +from app.models.datasource import DataSource +from app.schemas.mdl import MDLManifest +from app.services.mdl import MDLService +from app.connectors.factory import get_connector + +router = APIRouter(tags=["semantic"]) + +class GenerateMDLRequest(BaseModel): + selected_tables: Optional[List[str]] = None + selected_columns: Optional[Dict[str, List[str]]] = None + +class ModelDetailResponse(BaseModel): + model: Dict[str, Any] + relationships: List[Dict[str, Any]] + preview_rows: List[Dict[str, Any]] + +def _normalize_query_result(results: Any) -> List[Dict[str, Any]]: + if isinstance(results, list): + if results and isinstance(results[0], dict): + return results + if results and isinstance(results[0], (list, tuple)): + return [dict(enumerate(row)) for row in results] + return [] + if isinstance(results, tuple) and len(results) == 2: + rows, cols = results + col_names = [c[0] for c in cols] + return [dict(zip(col_names, row)) for row in rows] + return [] + +@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]]) +def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)): + # Check if datasource exists + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise HTTPException(status_code=404, detail="DataSource not found") + + try: + return MDLService.get_raw_schema(ds) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/semantic/{datasource_id}", response_model=MDLManifest) +def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)): + # Check if datasource exists + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise HTTPException(status_code=404, detail="DataSource not found") + + # Get or generate MDL + try: + mdl = MDLService.get_or_create_mdl(datasource_id) + return mdl + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.put("/semantic/{datasource_id}", response_model=MDLManifest) +def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)): + # Check if datasource exists + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise HTTPException(status_code=404, detail="DataSource not found") + + try: + MDLService.save_mdl(datasource_id, mdl) + return mdl + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest) +def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)): + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise HTTPException(status_code=404, detail="DataSource not found") + + try: + selected_tables = request.selected_tables if request else None + selected_columns = request.selected_columns if request else None + mdl = MDLService.generate_default_mdl( + ds, + selected_tables=selected_tables, + selected_columns=selected_columns, + ) + MDLService.save_mdl(datasource_id, mdl) + return mdl + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse) +def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)): + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise HTTPException(status_code=404, detail="DataSource not found") + + mdl = MDLService.get_or_create_mdl(datasource_id) + model = next((m for m in mdl.models if m.name == model_name), None) + if not model: + raise HTTPException(status_code=404, detail="Model not found") + + relationships = [ + { + "name": rel.name, + "models": rel.models, + "joinType": rel.joinType, + "condition": rel.condition, + "properties": rel.properties, + } + for rel in mdl.relationships + if model_name in rel.models + ] + + preview_rows: List[Dict[str, Any]] = [] + try: + connector = get_connector(ds) + table_name = model.tableReference.table if model.tableReference else model.name + query = f'SELECT * FROM "{table_name}" LIMIT {max(1, min(limit, 100))}' + raw = connector.execute_query(query) + preview_rows = _normalize_query_result(raw) + except Exception: + preview_rows = [] + + model_payload = { + "name": model.name, + "tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None, + "primaryKey": model.primaryKey, + "properties": model.properties, + "columns": [c.model_dump(by_alias=True) for c in model.columns], + } + + return ModelDetailResponse( + model=model_payload, + relationships=relationships, + preview_rows=preview_rows, + ) diff --git a/backend/app/connectors/clickhouse.py b/backend/app/connectors/clickhouse.py index a4cac85..608663e 100644 --- a/backend/app/connectors/clickhouse.py +++ b/backend/app/connectors/clickhouse.py @@ -33,7 +33,7 @@ class ClickHouseConnector: table = row[0] if table not in schema: schema[table] = [] - schema[table].append(f"{row[1]} ({row[2]})") + schema[table].append({"name": row[1], "type": row[2]}) return schema except Exception as e: print(f"Error getting schema: {e}") diff --git a/backend/app/connectors/csv.py b/backend/app/connectors/csv.py new file mode 100644 index 0000000..7c7b642 --- /dev/null +++ b/backend/app/connectors/csv.py @@ -0,0 +1,68 @@ +import duckdb +import pandas as pd +from typing import List, Dict, Any +import os +from app.core.files import resolve_upload_file_path + +class CSVConnector: + def __init__(self, file_path: str): + self.file_path = file_path + if not os.path.exists(self.file_path): + raise FileNotFoundError(f"CSV file not found: {self.file_path}") + + def _get_table_name(self) -> str: + # Normalize table name to be SQL safe-ish + base = os.path.splitext(os.path.basename(self.file_path))[0] + # Replace non-alphanumeric chars with underscore + safe_name = "".join([c if c.isalnum() else "_" for c in base]) + # Ensure it doesn't start with a number + if safe_name and safe_name[0].isdigit(): + safe_name = f"t_{safe_name}" + return safe_name + + def execute_query(self, query: str) -> List[Dict[str, Any]]: + conn = duckdb.connect(":memory:") + table_name = self._get_table_name() + + # Register the csv file as a view + # read_csv_auto is powerful + try: + conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_csv_auto('{self.file_path}')") + + # Execute the user query + # The query should rely on the table name provided in schema + df = conn.execute(query).df() + return df.to_dict(orient="records") + except Exception as e: + print(f"CSV Query Error: {e}") + raise e + finally: + conn.close() + + def get_schema(self) -> Dict[str, List[Dict[str, str]]]: + conn = duckdb.connect(":memory:") + table_name = self._get_table_name() + + try: + conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_csv_auto('{self.file_path}')") + + # Get columns + columns = conn.execute(f"DESCRIBE {table_name}").fetchall() + # col[0] is name, col[1] is type + schema = {table_name: [{"name": col[0], "type": col[1]} for col in columns]} + return schema + except Exception as e: + print(f"Error getting schema: {e}") + return {} + finally: + conn.close() + + def test_connection(self) -> bool: + try: + conn = duckdb.connect(":memory:") + conn.execute(f"SELECT * FROM read_csv_auto('{self.file_path}') LIMIT 1") + conn.close() + return True + except Exception as e: + print(f"CSV Connection Error: {e}") + return False diff --git a/backend/app/connectors/factory.py b/backend/app/connectors/factory.py index 1023a6c..87da6da 100644 --- a/backend/app/connectors/factory.py +++ b/backend/app/connectors/factory.py @@ -4,6 +4,7 @@ import functools from app.connectors.postgres import PostgresConnector from app.connectors.clickhouse import ClickHouseConnector from app.connectors.parquet import ParquetConnector +from app.connectors.csv import CSVConnector from app.models.datasource import DataSource from app.core.files import resolve_upload_file_path @@ -37,6 +38,10 @@ def _get_cached_connector(ds_type: str, config_json: str): elif ds_type == "parquet": file_path = str(resolve_upload_file_path(config.get("file_path"))) return ParquetConnector(file_path=file_path) + + elif ds_type == "csv": + file_path = str(resolve_upload_file_path(config.get("file_path"))) + return CSVConnector(file_path=file_path) else: raise ValueError(f"Unsupported data source type: {ds_type}") diff --git a/backend/app/connectors/parquet.py b/backend/app/connectors/parquet.py index 11a6f57..c21adf8 100644 --- a/backend/app/connectors/parquet.py +++ b/backend/app/connectors/parquet.py @@ -31,7 +31,7 @@ class ParquetConnector: finally: conn.close() - def get_schema(self) -> Dict[str, List[str]]: + def get_schema(self) -> Dict[str, List[Dict[str, str]]]: conn = duckdb.connect(":memory:") table_name = os.path.splitext(os.path.basename(self.file_path))[0] conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')") @@ -39,7 +39,7 @@ class ParquetConnector: try: # Get columns columns = conn.execute(f"DESCRIBE {table_name}").fetchall() - schema = {table_name: [f"{col[0]} ({col[1]})" for col in columns]} + schema = {table_name: [{"name": col[0], "type": col[1]} for col in columns]} return schema except Exception as e: print(f"Error getting schema: {e}") diff --git a/backend/app/connectors/postgres.py b/backend/app/connectors/postgres.py index 49a96b9..d406cdd 100644 --- a/backend/app/connectors/postgres.py +++ b/backend/app/connectors/postgres.py @@ -22,6 +22,9 @@ class PostgresConnector: return [dict(row._mapping) for row in result] def get_schema(self): + if self.engine.dialect.name == "sqlite": + return self._get_sqlite_schema() + query = """ SELECT table_name, column_name, data_type FROM information_schema.columns @@ -35,12 +38,27 @@ class PostgresConnector: table = row['table_name'] if table not in schema: schema[table] = [] - schema[table].append(f"{row['column_name']} ({row['data_type']})") + schema[table].append({"name": row['column_name'], "type": row['data_type']}) return schema except Exception as e: print(f"Error getting schema: {e}") return {} + def _get_sqlite_schema(self): + try: + from sqlalchemy import inspect + inspector = inspect(self.engine) + schema = {} + for table_name in inspector.get_table_names(): + columns = [] + for col in inspector.get_columns(table_name): + columns.append({"name": col['name'], "type": str(col['type'])}) + schema[table_name] = columns + return schema + except Exception as e: + print(f"Error getting SQLite schema: {e}") + return {} + def test_connection(self) -> bool: try: with self.engine.connect() as connection: diff --git a/backend/app/schemas/mdl.py b/backend/app/schemas/mdl.py new file mode 100644 index 0000000..1f9dbe5 --- /dev/null +++ b/backend/app/schemas/mdl.py @@ -0,0 +1,115 @@ +from typing import List, Optional, Dict, Any, Union, Literal +from pydantic import BaseModel, Field + +# Common Types +AccessControlOperator = Literal[ + "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN", + "GREATER_THAN_OR_EQUALS", "LESS_THAN_OR_EQUALS" +] + +JoinType = Literal["ONE_TO_ONE", "ONE_TO_MANY", "MANY_TO_ONE", "MANY_TO_MANY"] + +# Column Definitions +class SessionProperty(BaseModel): + name: str + required: bool + defaultExpr: Optional[str] = None + +class AccessControlThreshold(BaseModel): + value: str + dataType: Literal["NUMERIC", "STRING"] + +class ColumnAccessControl(BaseModel): + name: str + operator: AccessControlOperator + requiredProperties: List[SessionProperty] + threshold: Optional[AccessControlThreshold] = None + +class Column(BaseModel): + name: str + type: str + relationship: Optional[str] = None + isCalculated: bool = False + notNull: bool = False + expression: Optional[str] = None + isHidden: bool = False + columnLevelAccessControl: Optional[ColumnAccessControl] = None + properties: Dict[str, str] = Field(default_factory=dict) + +# Model Definitions +class TableReference(BaseModel): + catalog: Optional[str] = None + schema_: Optional[str] = Field(None, alias="schema") + table: str + +class RowLevelAccessControl(BaseModel): + name: str + requiredProperties: List[SessionProperty] + condition: str + +class Model(BaseModel): + name: str + tableReference: Optional[TableReference] = None + refSql: Optional[str] = None + baseObject: Optional[str] = None + columns: List[Column] = Field(default_factory=list) + primaryKey: Optional[str] = None + cached: bool = False + refreshTime: Optional[str] = None + rowLevelAccessControls: List[RowLevelAccessControl] = Field(default_factory=list) + properties: Dict[str, Any] = Field(default_factory=dict) + +# Relationship Definitions +class Relationship(BaseModel): + name: str + models: List[str] # minItems: 2, maxItems: 2 + joinType: JoinType + condition: str + properties: Dict[str, Any] = Field(default_factory=dict) + +# Metric Definitions +class MetricTimeGrain(BaseModel): + name: str + refColumn: str + dateParts: List[str] + +class Metric(BaseModel): + name: str + baseObject: str + dimension: List[Column] = Field(default_factory=list) + measure: List[Column] = Field(default_factory=list) + timeGrain: List[MetricTimeGrain] = Field(default_factory=list) + cached: bool = False + refreshTime: Optional[str] = None + properties: Dict[str, Any] = Field(default_factory=dict) + +# View Definitions +class View(BaseModel): + name: str + statement: str + properties: Dict[str, Any] = Field(default_factory=dict) + +# Enum Definitions +class EnumValue(BaseModel): + name: str + value: Optional[str] = None + properties: Dict[str, Any] = Field(default_factory=dict) + +class EnumDefinition(BaseModel): + name: str + values: List[EnumValue] + properties: Dict[str, Any] = Field(default_factory=dict) + +# Main Manifest +class MDLManifest(BaseModel): + catalog: str + schema_: str = Field(..., alias="schema") # 'schema' is a reserved word in Pydantic v1/Python, aliasing + dataSource: Optional[str] = None + models: List[Model] = Field(default_factory=list) + relationships: List[Relationship] = Field(default_factory=list) + metrics: List[Metric] = Field(default_factory=list) + views: List[View] = Field(default_factory=list) + enumDefinitions: List[EnumDefinition] = Field(default_factory=list) + + class Config: + populate_by_name = True diff --git a/backend/app/services/mdl.py b/backend/app/services/mdl.py new file mode 100644 index 0000000..e43d5c0 --- /dev/null +++ b/backend/app/services/mdl.py @@ -0,0 +1,122 @@ +import json +import os +from pathlib import Path +from typing import Optional, Dict, Any, List +from app.models.datasource import DataSource +from app.schemas.mdl import MDLManifest, Model, Column, TableReference +from app.connectors.factory import get_connector +from app.database import SessionLocal + +# Assuming running from backend/ directory +MDL_STORAGE_PATH = Path("data/mdl") + +class MDLService: + @staticmethod + def _get_mdl_path(datasource_id: int) -> Path: + MDL_STORAGE_PATH.mkdir(parents=True, exist_ok=True) + return MDL_STORAGE_PATH / f"{datasource_id}.json" + + @staticmethod + def get_raw_schema(datasource: DataSource) -> Dict[str, List[Dict[str, str]]]: + connector = get_connector(datasource) + try: + return connector.get_schema() + except Exception as e: + print(f"Error fetching schema for DS {datasource.id}: {e}") + return {} + + @staticmethod + def generate_default_mdl( + datasource: DataSource, + selected_tables: Optional[List[str]] = None, + selected_columns: Optional[Dict[str, List[str]]] = None, + ) -> MDLManifest: + raw_schema = MDLService.get_raw_schema(datasource) + + models = [] + for table_name, columns in raw_schema.items(): + if selected_tables is not None and table_name not in selected_tables: + continue + + model_cols = [] + for col_info in columns: + if isinstance(col_info, dict): + name = col_info.get("name", "UNKNOWN") + type_ = col_info.get("type", "UNKNOWN") + elif isinstance(col_info, str): + # Fallback for old string format "name (type)" + if "(" in col_info and col_info.endswith(")"): + parts = col_info.rsplit(" (", 1) + if len(parts) == 2: + name = parts[0] + type_ = parts[1][:-1] + else: + name = col_info + type_ = "UNKNOWN" + else: + name = col_info + type_ = "UNKNOWN" + else: + name = str(col_info) + type_ = "UNKNOWN" + + if selected_columns is not None: + allowed = selected_columns.get(table_name, []) + if allowed and name not in allowed: + continue + + model_cols.append(Column(name=name, type=type_)) + + if not model_cols: + continue + + models.append(Model( + name=table_name, + tableReference=TableReference(table=table_name), + columns=model_cols + )) + + return MDLManifest( + catalog="default", + schema="public", # Default schema, might need adjustment based on datasource config + dataSource=datasource.type.upper(), + models=models + ) + + @staticmethod + def get_mdl(datasource_id: int) -> Optional[MDLManifest]: + path = MDLService._get_mdl_path(datasource_id) + if path.exists(): + try: + with open(path, "r") as f: + data = json.load(f) + # Pydantic v2 compatible + return MDLManifest.model_validate(data) + except Exception as e: + print(f"Error loading MDL for {datasource_id}: {e}") + return None + return None + + @staticmethod + def save_mdl(datasource_id: int, mdl: MDLManifest): + path = MDLService._get_mdl_path(datasource_id) + with open(path, "w") as f: + f.write(mdl.model_dump_json(indent=2, by_alias=True)) + + @staticmethod + def get_or_create_mdl(datasource_id: int) -> MDLManifest: + mdl = MDLService.get_mdl(datasource_id) + if mdl: + return mdl + + # Generate new + db = SessionLocal() + try: + ds = db.query(DataSource).filter(DataSource.id == datasource_id).first() + if not ds: + raise ValueError(f"DataSource {datasource_id} not found") + mdl = MDLService.generate_default_mdl(ds) + MDLService.save_mdl(datasource_id, mdl) + return mdl + finally: + db.close() diff --git a/backend/backend/dataclaw.db b/backend/backend/dataclaw.db new file mode 100644 index 0000000..e69de29 diff --git a/backend/dataclaw.db b/backend/dataclaw.db index d2c3199..3a2895f 100644 Binary files a/backend/dataclaw.db and b/backend/dataclaw.db differ diff --git a/backend/main.py b/backend/main.py index c8c4a66..919d0b3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,7 +7,7 @@ import asyncio import json from datetime import datetime -from app.api import upload, llm, skills, users, datasources, projects +from app.api import upload, llm, skills, users, datasources, projects, semantic from app.connectors.postgres import postgres_connector from app.connectors.clickhouse import clickhouse_connector from app.core.nanobot import nanobot_service @@ -38,6 +38,7 @@ app.include_router(skills.router, prefix="/api/v1") app.include_router(users.router, prefix="/api/v1") app.include_router(projects.router, prefix="/api/v1") app.include_router(datasources.router, prefix="/api/v1") +app.include_router(semantic.router, prefix="/api/v1") @app.on_event("startup") async def startup_event(): diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6aa0987..99f90d1 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -10,6 +10,7 @@ import { Projects } from "./pages/Projects"; import { Login } from "./pages/Login"; import { ModelConfigs } from "./pages/ModelConfigs"; import { DataSources } from "./pages/DataSources"; +import { Modeling } from "./pages/Modeling"; import { useAuthStore } from "./store/authStore"; // Protected Route Component @@ -115,6 +116,14 @@ function App() { } /> + + + + + + + } /> ); diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index c14801d..8bd0af1 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -2,7 +2,7 @@ import { useState, useRef, useEffect } from "react"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { ScrollArea } from "@/components/ui/scroll-area"; -import { User, Loader2, Sparkles, ArrowUp, ChevronDown, Paperclip, Check, X, File as FileIcon, Square, Plus, Database, Wand2, Search, Zap, LayoutGrid, CheckCircle2, Table, XCircle } from "lucide-react"; +import { User, Loader2, Sparkles, ArrowUp, ChevronDown, Paperclip, Check, X, Square, Plus, Database, Wand2, Search, Zap, LayoutGrid, CheckCircle2, Table, XCircle } from "lucide-react"; import { api } from "@/lib/api"; import { type ChartSpec } from "@/store/visualizationStore"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; @@ -70,7 +70,7 @@ interface SessionData { export function ChatInterface() { const [messages, setMessages] = useState([]); const [input, setInput] = useState(""); - const [selectedDataSource, setSelectedDataSource] = useState("postgres-main"); + const [selectedDataSource, setSelectedDataSource] = useState(""); const [availableSkills, setAvailableSkills] = useState([]); const [selectedSkillIds, setSelectedSkillIds] = useState([]); const [isMenuOpen, setIsMenuOpen] = useState(false); @@ -104,6 +104,7 @@ export function ChatInterface() { useEffect(() => { if (currentProject) { + setSelectedDataSource(""); fetchDataSources(); } }, [currentProject]); @@ -114,14 +115,8 @@ export function ChatInterface() { const data = await api.get>(`/api/v1/datasources?project_id=${currentProject.id}`); const projectSources = data.map(d => ({ id: `ds:${d.id}`, name: d.name })); setAvailableDataSources(projectSources); - - // Default select the first one if current selection is not in the list - if (projectSources.length > 0) { - if (!selectedDataSource.startsWith("ds:") || !projectSources.find(ds => ds.id === selectedDataSource)) { - setSelectedDataSource(projectSources[0].id); - } - } else { - setSelectedDataSource("upload"); // Default to upload if no data sources + if (selectedDataSource && !projectSources.find(ds => ds.id === selectedDataSource)) { + setSelectedDataSource(""); } } catch (e) { console.error("Failed to fetch data sources", e); @@ -141,6 +136,8 @@ export function ChatInterface() { useEffect(() => { const fetchSessionData = async () => { setIsLoading(true); + setSelectedDataSource(""); + setSelectedSkillIds([]); try { const data = await api.get(`/nanobot/sessions/${activeSessionKey}`); if (data.messages && data.messages.length > 0) { @@ -157,11 +154,6 @@ export function ChatInterface() { const restoredFile = data.metadata?.active_data_file || null; setActiveDataFile(restoredFile); setAttachedFile(null); - if (restoredFile) { - setSelectedDataSource("upload-main"); - } else if (selectedDataSource.startsWith("upload")) { - setSelectedDataSource("postgres-main"); - } } catch (e) { console.error("Failed to fetch session messages", e); setMessages([]); @@ -245,7 +237,7 @@ export function ChatInterface() { }; setAttachedFile(uploadedFile); setActiveDataFile(uploadedFile); - setSelectedDataSource("upload-main"); + setSelectedDataSource(""); await syncSessionFileContext(uploadedFile); } catch (error) { console.error("File upload error:", error); @@ -261,12 +253,37 @@ export function ChatInterface() { const handleRemoveFile = async () => { setAttachedFile(null); setActiveDataFile(null); - if (selectedDataSource.startsWith("upload")) { - setSelectedDataSource("postgres-main"); - } await syncSessionFileContext(null); }; + const selectedDataSourceName = availableDataSources.find(ds => ds.id === selectedDataSource)?.name || ""; + const selectedSkills = availableSkills.filter(skill => selectedSkillIds.includes(skill.id)); + + const renderActiveSelections = () => { + if (!selectedDataSource && selectedSkills.length === 0) return null; + return ( +
+
+ {selectedDataSource ? ( +
+ + {`数据源:${selectedDataSourceName}`} +
+ ) : null} + {selectedSkills.map((skill) => ( +
+ + {`Skill:${skill.name}`} +
+ ))} +
+
+ ); + }; + const renderFileCard = () => { const file = attachedFile || activeDataFile; if (!file) return null; @@ -328,7 +345,7 @@ export function ChatInterface() { }; const handleSend = async () => { - if (!input.trim() || isLoading) return; + if (!input.trim() || isLoading || !selectedDataSource) return; const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input }; setMessages(prev => [...prev, newMessage]); @@ -357,16 +374,9 @@ export function ChatInterface() { const token = localStorage.getItem("token"); const effectiveModelId = selectedModelId || currentModel?.id || ""; - // Correctly parse source from selectedDataSource (could be 'ds:ID', 'upload', or legacy 'postgres-main') - let source = selectedDataSource; - if (selectedDataSource.includes("-")) { - source = selectedDataSource.split("-")[0]; - } + let source = selectedDataSource; - const useUploadSource = Boolean( - currentAttachedFile?.url?.startsWith("local://") || - (source === "upload" && activeDataFile?.url?.startsWith("local://")) - ); + const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://")); if (useUploadSource) { source = "upload"; } @@ -580,6 +590,7 @@ export function ChatInterface() {
{renderFileCard()} + {renderActiveSelections()}
@@ -615,26 +626,16 @@ export function ChatInterface() { {selectedDataSource === ds.id && } ))} - -
- {(selectedDataSource === 'upload' || selectedDataSource === 'upload-main') && } - + )}
@@ -652,26 +653,21 @@ export function ChatInterface() { @@ -713,7 +709,7 @@ export function ChatInterface() {
))} - -
- {(selectedDataSource === 'upload' || selectedDataSource === 'upload-main') && } - + )}
@@ -870,26 +857,21 @@ export function ChatInterface() { @@ -931,7 +913,7 @@ export function ChatInterface() {
- + + diff --git a/frontend/src/pages/Modeling.tsx b/frontend/src/pages/Modeling.tsx new file mode 100644 index 0000000..03c3637 --- /dev/null +++ b/frontend/src/pages/Modeling.tsx @@ -0,0 +1,514 @@ +import { useEffect, useState } from "react"; +import { useParams, useNavigate } from "react-router-dom"; +import { api } from "../lib/api"; +import { Button } from "../components/ui/button"; +import { Card, CardContent, CardHeader, CardTitle } from "../components/ui/card"; +import { Label } from "../components/ui/label"; +import { ScrollArea } from "../components/ui/scroll-area"; +import { Dialog, DialogContent, DialogHeader, DialogTitle } from "../components/ui/dialog"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "../components/ui/table"; +import { ArrowLeft, Table as TableIcon, Network } from "lucide-react"; + +interface RawSchema { + [table: string]: { name: string; type: string }[]; +} + +interface Column { + name: string; + type: string; + isCalculated: boolean; + relationship?: string; + expression?: string; + properties?: Record; +} + +interface Model { + name: string; + columns: Column[]; + primaryKey?: string; + properties?: Record; +} + +interface Relationship { + name: string; + models: string[]; + joinType: string; + condition: string; +} + +interface MDLManifest { + catalog: string; + schema: string; + dataSource: string; + models: Model[]; + relationships: Relationship[]; +} + +interface ModelDetailResponse { + model: { + name: string; + tableReference?: { + table: string; + schema?: string; + catalog?: string; + } | null; + primaryKey?: string; + properties?: Record; + columns: Column[]; + }; + relationships: { + name: string; + models: string[]; + joinType: string; + condition: string; + properties?: Record; + }[]; + preview_rows: Record[]; +} + +export function Modeling() { + const { id } = useParams<{ id: string }>(); + const navigate = useNavigate(); + const [loading, setLoading] = useState(true); + const [schema, setSchema] = useState(null); + const [mdl, setMdl] = useState(null); + const [selectedTables, setSelectedTables] = useState([]); + const [selectedColumns, setSelectedColumns] = useState>({}); + const [expandedTables, setExpandedTables] = useState>({}); + const [step, setStep] = useState<"select" | "view">("select"); + const [detailOpen, setDetailOpen] = useState(false); + const [detailLoading, setDetailLoading] = useState(false); + const [modelDetail, setModelDetail] = useState(null); + + useEffect(() => { + fetchInitialData(); + }, [id]); + + const initSelectionFromSchema = (schemaRes: RawSchema) => { + const tableNames = Object.keys(schemaRes); + const columnsMap: Record = {}; + const expanded: Record = {}; + for (const tableName of tableNames) { + columnsMap[tableName] = schemaRes[tableName].map((c) => c.name); + expanded[tableName] = true; + } + setSchema(schemaRes); + setSelectedTables(tableNames); + setSelectedColumns(columnsMap); + setExpandedTables(expanded); + }; + + const fetchSchemaOnly = async () => { + const schemaRes = await api.get(`/api/v1/semantic/${id}/schema`) as RawSchema; + initSelectionFromSchema(schemaRes); + setStep("select"); + }; + + const fetchInitialData = async () => { + try { + setLoading(true); + const mdlRes = await api.get(`/api/v1/semantic/${id}`) as any; + if (mdlRes && mdlRes.models && mdlRes.models.length > 0) { + setMdl(mdlRes as MDLManifest); + setStep("view"); + } else { + await fetchSchemaOnly(); + } + } catch (error) { + console.error("Failed to fetch modeling data:", error); + try { + await fetchSchemaOnly(); + } catch (e) { + console.error("Failed to fetch schema:", e); + } + } finally { + setLoading(false); + } + }; + + const handleGenerate = async () => { + try { + setLoading(true); + const res = await api.post(`/api/v1/semantic/${id}/generate`, { + selected_tables: selectedTables, + selected_columns: Object.fromEntries( + selectedTables.map((table) => [table, selectedColumns[table] ?? []]) + ), + }) as MDLManifest; + setMdl(res); + setStep("view"); + } catch (error) { + console.error("Failed to generate MDL:", error); + } finally { + setLoading(false); + } + }; + + const toggleTable = (table: string) => { + setSelectedTables((prev) => + prev.includes(table) ? prev.filter((t) => t !== table) : [...prev, table] + ); + if (!schema) return; + if (!selectedTables.includes(table) && (!selectedColumns[table] || selectedColumns[table].length === 0)) { + setSelectedColumns((prev) => ({ + ...prev, + [table]: schema[table].map((c) => c.name), + })); + } + }; + + const toggleColumn = (table: string, column: string) => { + setSelectedColumns((prev) => { + const current = prev[table] ?? []; + const has = current.includes(column); + const next = has ? current.filter((c) => c !== column) : [...current, column]; + return { ...prev, [table]: next }; + }); + setSelectedTables((prev) => { + const exists = prev.includes(table); + const current = selectedColumns[table] ?? []; + const has = current.includes(column); + const nextLen = has ? current.length - 1 : current.length + 1; + if (nextLen <= 0) { + return prev.filter((t) => t !== table); + } + if (!exists) { + return [...prev, table]; + } + return prev; + }); + }; + + const toggleExpandTable = (table: string) => { + setExpandedTables((prev) => ({ ...prev, [table]: !prev[table] })); + }; + + const handleSelectAll = () => { + if (!schema) return; + const tableNames = Object.keys(schema); + setSelectedTables(tableNames); + setSelectedColumns( + Object.fromEntries( + tableNames.map((table) => [table, schema[table].map((c) => c.name)]) + ) + ); + }; + + const handleClearAll = () => { + setSelectedTables([]); + setSelectedColumns({}); + }; + + const handleReselectTables = async () => { + try { + setLoading(true); + await fetchSchemaOnly(); + } finally { + setLoading(false); + } + }; + + const openModelDetail = async (modelName: string) => { + try { + setDetailOpen(true); + setDetailLoading(true); + const detail = await api.get( + `/api/v1/semantic/${id}/models/${encodeURIComponent(modelName)}?limit=10` + ); + setModelDetail(detail); + } catch (error) { + console.error("Failed to fetch model detail:", error); + setModelDetail(null); + } finally { + setDetailLoading(false); + } + }; + + if (loading) { + return
Loading modeling data...
; + } + + return ( +
+ {/* Header */} +
+
+ +
+

Data Modeling

+

+ DataSource ID: {id} • {step === "select" ? "Select Tables" : "Entity Relationship Diagram"} +

+
+
+ {step === "view" && ( +
+ +
+ )} +
+ + {/* Content */} +
+ {step === "select" ? ( +
+ + + Select tables to create data models +

+ Choose the tables you want to include in your semantic model. +

+
+ +
+
+ {selectedTables.length} / {schema ? Object.keys(schema).length : 0} selected +
+
+ + +
+
+ + +
+ {schema && Object.keys(schema).map((table) => ( +
+
+
+ toggleTable(table)} + /> + +
+ +
+ {expandedTables[table] && ( +
+ {schema[table].map((col) => ( + + ))} +
+ )} +
+ ))} +
+
+ +
+ +
+
+
+
+ ) : ( +
+ {/* Sidebar List */} + + + Models ({mdl?.models.length}) + + +
+ {mdl?.models.map((model) => ( +
openModelDetail(model.name)} + > + + {model.name} +
+ ))} +
+
+
+ + {/* Canvas Area (Simulated) */} +
+
+ +
+ {mdl?.models.map((model) => ( + openModelDetail(model.name)}> + +
+ + {model.name} +
+
+ +
+ {model.columns.map((col) => ( +
+ {col.name} + {col.type} +
+ ))} +
+ {/* Show Relationships if any */} + {mdl.relationships.filter(r => r.models.includes(model.name)).length > 0 && ( +
+
+ Relationships +
+ {mdl.relationships + .filter(r => r.models.includes(model.name)) + .map(r => { + const other = r.models.find(m => m !== model.name); + return ( +
+ ⟷ {other} +
+ ); + }) + } +
+ )} +
+
+ ))} +
+
+
+ )} +
+ + + + {modelDetail?.model?.name ?? "Model Detail"} + + {detailLoading ? ( +
Loading model detail...
+ ) : !modelDetail ? ( +
No metadata available.
+ ) : ( +
+
+
Columns Metadata
+ + + + Name + Type + Description + + + + {modelDetail.model.columns.map((col) => ( + + {col.name} + {col.type} + {String(col.properties?.description ?? "-")} + + ))} + +
+
+
+
Relationships ({modelDetail.relationships.length})
+ + + + Name + Models + Type + Condition + + + + {modelDetail.relationships.map((rel) => ( + + {rel.name} + {rel.models.join(" ↔ ")} + {rel.joinType} + {rel.condition} + + ))} + +
+
+
+
Data Preview (Top 10)
+ {modelDetail.preview_rows.length === 0 ? ( +
No preview data.
+ ) : ( + + + + {Object.keys(modelDetail.preview_rows[0]).map((key) => ( + {key} + ))} + + + + {modelDetail.preview_rows.map((row, idx) => ( + + {Object.keys(modelDetail.preview_rows[0]).map((key) => ( + {String(row[key] ?? "")} + ))} + + ))} + +
+ )} +
+
+ )} +
+
+
+ ); +}