feat: add modelling layer
This commit is contained in:
@@ -24,6 +24,7 @@ from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
from app.core.files import resolve_upload_file_path
|
||||
from app.services.mdl import MDLService
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -116,11 +117,11 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||
return pd.read_parquet(file_path)
|
||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||
|
||||
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]:
    """Describe an uploaded DataFrame as a single-table schema.

    The frame is registered as ``uploaded_file`` in a throwaway in-memory
    DuckDB connection and inspected with ``DESCRIBE``.

    Args:
        df: The uploaded data.

    Returns:
        ``{"uploaded_file": [{"name": ..., "type": ...}, ...]}`` with one
        entry per column, taken from the first two fields of each
        ``DESCRIBE`` row.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        # Close even when register/DESCRIBE raises so the connection is not leaked.
        conn.close()
    return {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]}
|
||||
|
||||
@@ -167,7 +168,7 @@ def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
)
|
||||
return source
|
||||
|
||||
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
|
||||
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]:
|
||||
key = _build_schema_cache_key(source, connector)
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
@@ -176,7 +177,7 @@ def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[s
|
||||
return cached["schema"]
|
||||
return None
|
||||
|
||||
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
|
||||
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None:
|
||||
key = _build_schema_cache_key(source, connector)
|
||||
with _cache_lock:
|
||||
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||
@@ -247,6 +248,37 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# Try to load MDL context
|
||||
mdl_context = ""
|
||||
if request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
mdl = MDLService.get_mdl(ds_id)
|
||||
if mdl:
|
||||
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
|
||||
|
||||
mdl_lines.append("MODELS:")
|
||||
for model in mdl.models:
|
||||
table_ref = model.tableReference.table if model.tableReference else model.name
|
||||
desc = f" - Description: {model.properties.get('description', '')}" if model.properties.get('description') else ""
|
||||
mdl_lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}")
|
||||
|
||||
if model.columns:
|
||||
mdl_lines.append(" Columns:")
|
||||
for col in model.columns:
|
||||
col_desc = f" ({col.properties.get('description')})" if col.properties.get('description') else ""
|
||||
expr = f" [Calculated: {col.expression}]" if col.isCalculated else ""
|
||||
mdl_lines.append(f" - {col.name} ({col.type}){col_desc}{expr}")
|
||||
|
||||
if mdl.relationships:
|
||||
mdl_lines.append("\nRELATIONSHIPS:")
|
||||
for rel in mdl.relationships:
|
||||
mdl_lines.append(f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}")
|
||||
|
||||
mdl_context = "\n".join(mdl_lines)
|
||||
except Exception as e:
|
||||
print(f"Failed to load MDL: {e}")
|
||||
|
||||
# 2. Get the active LLM config
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
@@ -270,6 +302,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
user_prompt = f"""
|
||||
### DATABASE SCHEMA ###
|
||||
{schema_str}
|
||||
{mdl_context}
|
||||
|
||||
### INPUTS ###
|
||||
User's Question: {request.query}
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.mdl import MDLManifest
|
||||
from app.services.mdl import MDLService
|
||||
from app.connectors.factory import get_connector
|
||||
|
||||
router = APIRouter(tags=["semantic"])
|
||||
|
||||
class GenerateMDLRequest(BaseModel):
    """Optional request body of the MDL regeneration endpoint.

    Both fields default to ``None``; ``regenerate_semantic_model`` forwards
    them unchanged to ``MDLService.generate_default_mdl``.
    """

    # Table names to keep; None applies no table filter.
    selected_tables: Optional[List[str]] = None
    # Per-table column names to keep, keyed by table name; None applies no column filter.
    selected_columns: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
class ModelDetailResponse(BaseModel):
    """Response payload of the model-detail endpoint."""

    # Serialised model: name, tableReference, primaryKey, properties, columns.
    model: Dict[str, Any]
    # Relationships whose ``models`` list contains the requested model name.
    relationships: List[Dict[str, Any]]
    # Sample rows read from the model's table; empty when the preview query fails.
    preview_rows: List[Dict[str, Any]]
|
||||
|
||||
def _normalize_query_result(results: Any) -> List[Dict[str, Any]]:
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
return results
|
||||
if results and isinstance(results[0], (list, tuple)):
|
||||
return [dict(enumerate(row)) for row in results]
|
||||
return []
|
||||
if isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
return [dict(zip(col_names, row)) for row in rows]
|
||||
return []
|
||||
|
||||
@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]])
def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)):
    """Return the raw table/column schema of a datasource.

    Raises:
        HTTPException: 404 when no ``DataSource`` row has the given id; 500
            (exception text as detail) when fetching the schema raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        raw_schema = MDLService.get_raw_schema(datasource)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return raw_schema
|
||||
|
||||
@router.get("/semantic/{datasource_id}", response_model=MDLManifest)
def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)):
    """Return the MDL manifest of a datasource, generating it on first access.

    Raises:
        HTTPException: 404 when the datasource does not exist; 500 (exception
            text as detail) when loading or generating the manifest raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    # Stored manifest if present, otherwise a freshly generated and saved one.
    try:
        return MDLService.get_or_create_mdl(datasource_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@router.put("/semantic/{datasource_id}", response_model=MDLManifest)
def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)):
    """Persist the submitted MDL manifest for a datasource and echo it back.

    Raises:
        HTTPException: 404 when the datasource does not exist; 500 (exception
            text as detail) when saving raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        MDLService.save_mdl(datasource_id, mdl)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return mdl
|
||||
|
||||
@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest)
def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)):
    """Regenerate the default MDL manifest from the datasource schema and save it.

    An absent request body means no table or column filter is applied.

    Raises:
        HTTPException: 404 when the datasource does not exist; 500 (exception
            text as detail) when generation or saving raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        filters = {
            "selected_tables": request.selected_tables if request else None,
            "selected_columns": request.selected_columns if request else None,
        }
        manifest = MDLService.generate_default_mdl(datasource, **filters)
        MDLService.save_mdl(datasource_id, manifest)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return manifest
|
||||
|
||||
@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse)
def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)):
    """Return one model of a datasource's MDL with its relationships and a data preview.

    Args:
        datasource_id: Id of the ``DataSource`` row.
        model_name: Name of the model inside the manifest.
        limit: Requested number of preview rows; clamped to the range 1..100.
        db: Database session.

    Returns:
        A ``ModelDetailResponse``. ``preview_rows`` is empty when the preview
        query fails; the preview is best-effort and never fails the request.

    Raises:
        HTTPException: 404 when the datasource or the model is not found; 500
            (exception text as detail) when the manifest cannot be loaded.
    """
    ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not ds:
        raise HTTPException(status_code=404, detail="DataSource not found")

    # Map manifest failures to a 500 with detail, like the sibling endpoints do.
    try:
        mdl = MDLService.get_or_create_mdl(datasource_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    model = next((m for m in mdl.models if m.name == model_name), None)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    relationships = [
        {
            "name": rel.name,
            "models": rel.models,
            "joinType": rel.joinType,
            "condition": rel.condition,
            "properties": rel.properties,
        }
        for rel in mdl.relationships
        if model_name in rel.models
    ]

    preview_rows: List[Dict[str, Any]] = []
    try:
        connector = get_connector(ds)
        table_name = model.tableReference.table if model.tableReference else model.name
        # The table name comes from the editable manifest: double any embedded
        # quote so it cannot terminate the quoted identifier early.
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        row_limit = max(1, min(limit, 100))
        raw = connector.execute_query(f"SELECT * FROM {quoted_table} LIMIT {row_limit}")
        preview_rows = _normalize_query_result(raw)
    except Exception as e:
        # Best-effort preview: report the failure instead of hiding it, then
        # fall back to an empty preview.
        print(f"Failed to load preview rows for model {model_name}: {e}")
        preview_rows = []

    model_payload = {
        "name": model.name,
        "tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None,
        "primaryKey": model.primaryKey,
        "properties": model.properties,
        "columns": [c.model_dump(by_alias=True) for c in model.columns],
    }

    return ModelDetailResponse(
        model=model_payload,
        relationships=relationships,
        preview_rows=preview_rows,
    )
|
||||
@@ -33,7 +33,7 @@ class ClickHouseConnector:
|
||||
table = row[0]
|
||||
if table not in schema:
|
||||
schema[table] = []
|
||||
schema[table].append(f"{row[1]} ({row[2]})")
|
||||
schema[table].append({"name": row[1], "type": row[2]})
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any
|
||||
import os
|
||||
from app.core.files import resolve_upload_file_path
|
||||
|
||||
class CSVConnector:
    """Query a single CSV file through DuckDB.

    Every operation opens a fresh in-memory DuckDB connection, exposes the
    file as a view named after the sanitised file stem, and closes the
    connection afterwards.
    """

    def __init__(self, file_path: str):
        """Store *file_path*.

        Raises:
            FileNotFoundError: If the path does not exist.
        """
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"CSV file not found: {self.file_path}")

    def _get_table_name(self) -> str:
        """Derive the view name from the file stem.

        Non-alphanumeric characters become underscores, and a name starting
        with a digit is prefixed with ``t_``.
        """
        base = os.path.splitext(os.path.basename(self.file_path))[0]
        safe_name = "".join(c if c.isalnum() else "_" for c in base)
        if safe_name and safe_name[0].isdigit():
            safe_name = f"t_{safe_name}"
        return safe_name

    def _read_csv_sql(self) -> str:
        """Return the ``read_csv_auto(...)`` call for this file.

        Single quotes in the path are doubled so the path cannot terminate
        the SQL string literal early.
        """
        escaped_path = self.file_path.replace("'", "''")
        return f"read_csv_auto('{escaped_path}')"

    def _create_view(self, conn, table_name: str) -> None:
        """Expose the CSV file on *conn* as a view called *table_name*."""
        # The identifier is quoted so a file stem that is a SQL keyword still
        # yields a valid statement; _get_table_name never emits a quote.
        conn.execute(
            f'CREATE OR REPLACE VIEW "{table_name}" AS SELECT * FROM {self._read_csv_sql()}'
        )

    def execute_query(self, query: str) -> List[Dict[str, Any]]:
        """Run *query* against the CSV view and return the rows as dicts.

        The query is expected to reference the table name reported by
        ``get_schema``.

        Raises:
            Exception: Any error raised while creating the view or running
                the query is printed and re-raised.
        """
        conn = duckdb.connect(":memory:")
        try:
            self._create_view(conn, self._get_table_name())
            df = conn.execute(query).df()
            return df.to_dict(orient="records")
        except Exception as e:
            print(f"CSV Query Error: {e}")
            raise
        finally:
            conn.close()

    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
        """Return ``{table_name: [{"name": ..., "type": ...}, ...]}``.

        Returns an empty dict (after printing the error) when the file cannot
        be described.
        """
        conn = duckdb.connect(":memory:")
        table_name = self._get_table_name()
        try:
            self._create_view(conn, table_name)
            # Each DESCRIBE row starts with the column name followed by its type.
            columns = conn.execute(f'DESCRIBE "{table_name}"').fetchall()
            return {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
        except Exception as e:
            print(f"Error getting schema: {e}")
            return {}
        finally:
            conn.close()

    def test_connection(self) -> bool:
        """Return True when DuckDB can read at least one row of the file."""
        try:
            conn = duckdb.connect(":memory:")
            try:
                conn.execute(f"SELECT * FROM {self._read_csv_sql()} LIMIT 1")
            finally:
                # Close on failure too; previously the connection leaked.
                conn.close()
            return True
        except Exception as e:
            print(f"CSV Connection Error: {e}")
            return False
|
||||
@@ -4,6 +4,7 @@ import functools
|
||||
from app.connectors.postgres import PostgresConnector
|
||||
from app.connectors.clickhouse import ClickHouseConnector
|
||||
from app.connectors.parquet import ParquetConnector
|
||||
from app.connectors.csv import CSVConnector
|
||||
from app.models.datasource import DataSource
|
||||
from app.core.files import resolve_upload_file_path
|
||||
|
||||
@@ -37,6 +38,10 @@ def _get_cached_connector(ds_type: str, config_json: str):
|
||||
elif ds_type == "parquet":
|
||||
file_path = str(resolve_upload_file_path(config.get("file_path")))
|
||||
return ParquetConnector(file_path=file_path)
|
||||
|
||||
elif ds_type == "csv":
|
||||
file_path = str(resolve_upload_file_path(config.get("file_path")))
|
||||
return CSVConnector(file_path=file_path)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source type: {ds_type}")
|
||||
|
||||
@@ -31,7 +31,7 @@ class ParquetConnector:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_schema(self) -> Dict[str, List[str]]:
|
||||
def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
|
||||
conn = duckdb.connect(":memory:")
|
||||
table_name = os.path.splitext(os.path.basename(self.file_path))[0]
|
||||
conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')")
|
||||
@@ -39,7 +39,7 @@ class ParquetConnector:
|
||||
try:
|
||||
# Get columns
|
||||
columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
|
||||
schema = {table_name: [f"{col[0]} ({col[1]})" for col in columns]}
|
||||
schema = {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
|
||||
@@ -22,6 +22,9 @@ class PostgresConnector:
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
def get_schema(self):
|
||||
if self.engine.dialect.name == "sqlite":
|
||||
return self._get_sqlite_schema()
|
||||
|
||||
query = """
|
||||
SELECT table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
@@ -35,12 +38,27 @@ class PostgresConnector:
|
||||
table = row['table_name']
|
||||
if table not in schema:
|
||||
schema[table] = []
|
||||
schema[table].append(f"{row['column_name']} ({row['data_type']})")
|
||||
schema[table].append({"name": row['column_name'], "type": row['data_type']})
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
return {}
|
||||
|
||||
def _get_sqlite_schema(self):
|
||||
try:
|
||||
from sqlalchemy import inspect
|
||||
inspector = inspect(self.engine)
|
||||
schema = {}
|
||||
for table_name in inspector.get_table_names():
|
||||
columns = []
|
||||
for col in inspector.get_columns(table_name):
|
||||
columns.append({"name": col['name'], "type": str(col['type'])})
|
||||
schema[table_name] = columns
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting SQLite schema: {e}")
|
||||
return {}
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
with self.engine.connect() as connection:
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import List, Optional, Dict, Any, Union, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Common Types
|
||||
# Comparison operators accepted by ColumnAccessControl.operator.
AccessControlOperator = Literal[
    "EQUALS",
    "NOT_EQUALS",
    "GREATER_THAN",
    "LESS_THAN",
    "GREATER_THAN_OR_EQUALS",
    "LESS_THAN_OR_EQUALS",
]

# Values accepted by Relationship.joinType.
JoinType = Literal[
    "ONE_TO_ONE",
    "ONE_TO_MANY",
    "MANY_TO_ONE",
    "MANY_TO_MANY",
]
|
||||
|
||||
# Column Definitions
|
||||
class SessionProperty(BaseModel):
    """A named session property referenced by access-control rules."""

    name: str
    # Whether the property must be supplied.
    required: bool
    # Optional default expression; None when not given.
    defaultExpr: Optional[str] = None
|
||||
|
||||
class AccessControlThreshold(BaseModel):
    """Threshold value used by a column access-control rule."""

    # Always carried as a string; ``dataType`` says how to interpret it.
    value: str
    dataType: Literal["NUMERIC", "STRING"]
|
||||
|
||||
class ColumnAccessControl(BaseModel):
    """Column-level access-control rule attached to a ``Column``."""

    name: str
    # One of the AccessControlOperator literals.
    operator: AccessControlOperator
    # Session properties the rule refers to.
    requiredProperties: List[SessionProperty]
    # Optional comparison threshold; None when not given.
    threshold: Optional[AccessControlThreshold] = None
|
||||
|
||||
class Column(BaseModel):
    """A column of a model, or a dimension/measure of a metric."""

    name: str
    # Column type name, stored as a free-form string.
    type: str
    # Optional relationship name; None when not given.
    relationship: Optional[str] = None
    # Marks a calculated column; its ``expression`` is shown in the NL2SQL prompt.
    isCalculated: bool = False
    notNull: bool = False
    expression: Optional[str] = None
    isHidden: bool = False
    columnLevelAccessControl: Optional[ColumnAccessControl] = None
    # String-to-string metadata; ``description`` is read when building the NL2SQL prompt.
    properties: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
# Model Definitions
|
||||
class TableReference(BaseModel):
    """Physical table a ``Model`` points at: optional catalog and schema plus the table name."""

    # Accept both the alias ("schema") and the field name ("schema_") on input,
    # matching MDLManifest. Without this, TableReference(schema_=...) would
    # silently lose the value.
    model_config = {"populate_by_name": True}

    catalog: Optional[str] = None
    # Stored as ``schema_`` and exposed under the alias "schema".
    schema_: Optional[str] = Field(None, alias="schema")
    table: str
|
||||
|
||||
class RowLevelAccessControl(BaseModel):
    """Row-level access-control rule attached to a ``Model``."""

    name: str
    # Session properties the rule refers to.
    requiredProperties: List[SessionProperty]
    # Rule condition, stored as a string.
    condition: str
|
||||
|
||||
class Model(BaseModel):
    """A semantic model inside an ``MDLManifest``.

    ``MDLService.generate_default_mdl`` creates one per source table, with
    ``tableReference`` pointing at that table.
    """

    name: str
    # Backing table; consumers fall back to ``name`` when this is None.
    tableReference: Optional[TableReference] = None
    refSql: Optional[str] = None
    baseObject: Optional[str] = None
    columns: List[Column] = Field(default_factory=list)
    primaryKey: Optional[str] = None
    cached: bool = False
    refreshTime: Optional[str] = None
    rowLevelAccessControls: List[RowLevelAccessControl] = Field(default_factory=list)
    # Free-form metadata; ``description`` is read when building the NL2SQL prompt.
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Relationship Definitions
|
||||
class Relationship(BaseModel):
    """A join relationship between models."""

    name: str
    # Names of the related models. Intended to hold exactly two entries
    # (minItems: 2, maxItems: 2); the length is not validated here.
    models: List[str]
    # One of the JoinType literals.
    joinType: JoinType
    # Join condition, stored as a string.
    condition: str
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Metric Definitions
|
||||
class MetricTimeGrain(BaseModel):
    """Time-grain entry of a ``Metric``."""

    name: str
    # Name of the column the grain refers to.
    refColumn: str
    # Date parts, stored as free-form strings.
    dateParts: List[str]
|
||||
|
||||
class Metric(BaseModel):
    """A metric defined on top of a base object, with dimensions, measures and time grains."""

    name: str
    # Name of the object the metric is built on.
    baseObject: str
    dimension: List[Column] = Field(default_factory=list)
    measure: List[Column] = Field(default_factory=list)
    timeGrain: List[MetricTimeGrain] = Field(default_factory=list)
    cached: bool = False
    refreshTime: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# View Definitions
|
||||
class View(BaseModel):
    """A named view defined by a statement string."""

    name: str
    # The view's defining statement.
    statement: str
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Enum Definitions
|
||||
class EnumValue(BaseModel):
    """One member of an ``EnumDefinition``."""

    name: str
    # Optional value; None when not given.
    value: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class EnumDefinition(BaseModel):
    """A named enumeration and its members."""

    name: str
    # Members of the enumeration.
    values: List[EnumValue]
    properties: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Main Manifest
|
||||
class MDLManifest(BaseModel):
    """Top-level MDL document stored per datasource.

    Collects the models, relationships, metrics, views and enum definitions.
    It is written with ``by_alias=True``, so the schema name is stored under
    the key ``schema``.
    """

    # Pydantic v2 configuration, replacing the deprecated nested
    # ``class Config``. Input may use either the alias ("schema") or the
    # field name ("schema_").
    model_config = {"populate_by_name": True}

    catalog: str
    # Named ``schema_`` because a field called ``schema`` would shadow the
    # ``BaseModel.schema`` attribute; it is exposed under the alias "schema".
    schema_: str = Field(..., alias="schema")
    dataSource: Optional[str] = None
    models: List[Model] = Field(default_factory=list)
    relationships: List[Relationship] = Field(default_factory=list)
    metrics: List[Metric] = Field(default_factory=list)
    views: List[View] = Field(default_factory=list)
    enumDefinitions: List[EnumDefinition] = Field(default_factory=list)
|
||||
@@ -0,0 +1,122 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from app.models.datasource import DataSource
|
||||
from app.schemas.mdl import MDLManifest, Model, Column, TableReference
|
||||
from app.connectors.factory import get_connector
|
||||
from app.database import SessionLocal
|
||||
|
||||
# Relative path, resolved against the current working directory; this assumes
# the server process is started from the backend/ directory.
|
||||
MDL_STORAGE_PATH = Path("data/mdl")
|
||||
|
||||
class MDLService:
    """File-backed storage and default generation of MDL manifests.

    Each datasource's manifest is stored as ``<datasource_id>.json`` under
    ``MDL_STORAGE_PATH``.
    """

    @staticmethod
    def _get_mdl_path(datasource_id: int) -> Path:
        """Return the manifest path for *datasource_id*, creating the storage directory if needed."""
        MDL_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
        return MDL_STORAGE_PATH / f"{datasource_id}.json"

    @staticmethod
    def get_raw_schema(datasource: DataSource) -> Dict[str, List[Dict[str, str]]]:
        """Fetch the connector schema for *datasource*.

        Returns:
            The mapping returned by the connector's ``get_schema``, or an
            empty dict (after printing the error) when that call raises.
            Errors from ``get_connector`` itself propagate.
        """
        connector = get_connector(datasource)
        try:
            return connector.get_schema()
        except Exception as e:
            print(f"Error fetching schema for DS {datasource.id}: {e}")
            return {}

    @staticmethod
    def _parse_column_info(col_info: Any) -> tuple:
        """Return ``(name, type)`` for one raw schema entry.

        Handles the current ``{"name": ..., "type": ...}`` dict form and the
        legacy ``"name (type)"`` string form. Anything else is stringified
        and given the type ``"UNKNOWN"``.
        """
        if isinstance(col_info, dict):
            return col_info.get("name", "UNKNOWN"), col_info.get("type", "UNKNOWN")
        if isinstance(col_info, str):
            # Split on the last " (" so names that contain parentheses survive.
            if col_info.endswith(")") and " (" in col_info:
                name, type_part = col_info.rsplit(" (", 1)
                return name, type_part[:-1]
            return col_info, "UNKNOWN"
        return str(col_info), "UNKNOWN"

    @staticmethod
    def generate_default_mdl(
        datasource: DataSource,
        selected_tables: Optional[List[str]] = None,
        selected_columns: Optional[Dict[str, List[str]]] = None,
    ) -> MDLManifest:
        """Build a manifest with one model per table of the datasource schema.

        Args:
            datasource: Datasource whose schema is read.
            selected_tables: When not None, only these tables are kept.
            selected_columns: Optional per-table column filter. A table that
                is absent from the mapping, or mapped to an empty list, keeps
                all of its columns.

        Returns:
            An ``MDLManifest`` with catalog ``"default"``, schema
            ``"public"`` and ``dataSource`` set to the upper-cased datasource
            type. Tables left with no columns are skipped.
        """
        raw_schema = MDLService.get_raw_schema(datasource)

        models = []
        for table_name, columns in raw_schema.items():
            if selected_tables is not None and table_name not in selected_tables:
                continue

            # Resolve the column filter once per table rather than per column.
            allowed = set(selected_columns.get(table_name, [])) if selected_columns is not None else set()

            model_cols = []
            for col_info in columns:
                name, type_ = MDLService._parse_column_info(col_info)
                if allowed and name not in allowed:
                    continue
                model_cols.append(Column(name=name, type=type_))

            if not model_cols:
                continue

            models.append(Model(
                name=table_name,
                tableReference=TableReference(table=table_name),
                columns=model_cols,
            ))

        return MDLManifest(
            catalog="default",
            schema="public",  # Default schema, might need adjustment based on datasource config
            dataSource=datasource.type.upper(),
            models=models,
        )

    @staticmethod
    def get_mdl(datasource_id: int) -> Optional[MDLManifest]:
        """Load the stored manifest for *datasource_id*.

        Returns:
            The validated manifest, or None when no file exists or the file
            cannot be read or validated (the error is printed).
        """
        path = MDLService._get_mdl_path(datasource_id)
        if not path.exists():
            return None
        try:
            # Read as UTF-8 explicitly so non-ASCII text does not depend on
            # the platform's locale encoding.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Pydantic v2 compatible
            return MDLManifest.model_validate(data)
        except Exception as e:
            print(f"Error loading MDL for {datasource_id}: {e}")
            return None

    @staticmethod
    def save_mdl(datasource_id: int, mdl: MDLManifest):
        """Write *mdl* as the manifest of *datasource_id*.

        The JSON is written to a temporary sibling file and then moved over
        the final path, so a concurrent ``get_mdl`` does not read a
        half-written file.
        """
        path = MDLService._get_mdl_path(datasource_id)
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(mdl.model_dump_json(indent=2, by_alias=True))
        tmp_path.replace(path)

    @staticmethod
    def get_or_create_mdl(datasource_id: int) -> MDLManifest:
        """Return the stored manifest, generating and saving a default one if none loads.

        Raises:
            ValueError: If a manifest must be generated and no ``DataSource``
                row has the given id.
        """
        mdl = MDLService.get_mdl(datasource_id)
        if mdl:
            return mdl

        # Nothing stored (or it failed to load): build one from the live schema.
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
            if not ds:
                raise ValueError(f"DataSource {datasource_id} not found")
            mdl = MDLService.generate_default_mdl(ds)
            MDLService.save_mdl(datasource_id, mdl)
            return mdl
        finally:
            db.close()
|
||||
Binary file not shown.
+2
-1
@@ -7,7 +7,7 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users, datasources, projects
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.nanobot import nanobot_service
|
||||
@@ -38,6 +38,7 @@ app.include_router(skills.router, prefix="/api/v1")
|
||||
app.include_router(users.router, prefix="/api/v1")
|
||||
app.include_router(projects.router, prefix="/api/v1")
|
||||
app.include_router(datasources.router, prefix="/api/v1")
|
||||
app.include_router(semantic.router, prefix="/api/v1")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
|
||||
@@ -10,6 +10,7 @@ import { Projects } from "./pages/Projects";
|
||||
import { Login } from "./pages/Login";
|
||||
import { ModelConfigs } from "./pages/ModelConfigs";
|
||||
import { DataSources } from "./pages/DataSources";
|
||||
import { Modeling } from "./pages/Modeling";
|
||||
import { useAuthStore } from "./store/authStore";
|
||||
|
||||
// Protected Route Component
|
||||
@@ -115,6 +116,14 @@ function App() {
|
||||
</MainLayout>
|
||||
</ProtectedRoute>
|
||||
} />
|
||||
|
||||
<Route path="/modeling/:id" element={
|
||||
<ProtectedRoute requireAdmin={true}>
|
||||
<MainLayout>
|
||||
<Modeling />
|
||||
</MainLayout>
|
||||
</ProtectedRoute>
|
||||
} />
|
||||
</Routes>
|
||||
</BrowserRouter>
|
||||
);
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useRef, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { User, Loader2, Sparkles, ArrowUp, ChevronDown, Paperclip, Check, X, File as FileIcon, Square, Plus, Database, Wand2, Search, Zap, LayoutGrid, CheckCircle2, Table, XCircle } from "lucide-react";
|
||||
import { User, Loader2, Sparkles, ArrowUp, ChevronDown, Paperclip, Check, X, Square, Plus, Database, Wand2, Search, Zap, LayoutGrid, CheckCircle2, Table, XCircle } from "lucide-react";
|
||||
import { api } from "@/lib/api";
|
||||
import { type ChartSpec } from "@/store/visualizationStore";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||
@@ -70,7 +70,7 @@ interface SessionData {
|
||||
export function ChatInterface() {
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [input, setInput] = useState("");
|
||||
const [selectedDataSource, setSelectedDataSource] = useState<string>("postgres-main");
|
||||
const [selectedDataSource, setSelectedDataSource] = useState<string>("");
|
||||
const [availableSkills, setAvailableSkills] = useState<Skill[]>([]);
|
||||
const [selectedSkillIds, setSelectedSkillIds] = useState<string[]>([]);
|
||||
const [isMenuOpen, setIsMenuOpen] = useState(false);
|
||||
@@ -104,6 +104,7 @@ export function ChatInterface() {
|
||||
|
||||
useEffect(() => {
|
||||
if (currentProject) {
|
||||
setSelectedDataSource("");
|
||||
fetchDataSources();
|
||||
}
|
||||
}, [currentProject]);
|
||||
@@ -114,14 +115,8 @@ export function ChatInterface() {
|
||||
const data = await api.get<Array<{id: number, name: string}>>(`/api/v1/datasources?project_id=${currentProject.id}`);
|
||||
const projectSources = data.map(d => ({ id: `ds:${d.id}`, name: d.name }));
|
||||
setAvailableDataSources(projectSources);
|
||||
|
||||
// Default select the first one if current selection is not in the list
|
||||
if (projectSources.length > 0) {
|
||||
if (!selectedDataSource.startsWith("ds:") || !projectSources.find(ds => ds.id === selectedDataSource)) {
|
||||
setSelectedDataSource(projectSources[0].id);
|
||||
}
|
||||
} else {
|
||||
setSelectedDataSource("upload"); // Default to upload if no data sources
|
||||
if (selectedDataSource && !projectSources.find(ds => ds.id === selectedDataSource)) {
|
||||
setSelectedDataSource("");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to fetch data sources", e);
|
||||
@@ -141,6 +136,8 @@ export function ChatInterface() {
|
||||
useEffect(() => {
|
||||
const fetchSessionData = async () => {
|
||||
setIsLoading(true);
|
||||
setSelectedDataSource("");
|
||||
setSelectedSkillIds([]);
|
||||
try {
|
||||
const data = await api.get<SessionData>(`/nanobot/sessions/${activeSessionKey}`);
|
||||
if (data.messages && data.messages.length > 0) {
|
||||
@@ -157,11 +154,6 @@ export function ChatInterface() {
|
||||
const restoredFile = data.metadata?.active_data_file || null;
|
||||
setActiveDataFile(restoredFile);
|
||||
setAttachedFile(null);
|
||||
if (restoredFile) {
|
||||
setSelectedDataSource("upload-main");
|
||||
} else if (selectedDataSource.startsWith("upload")) {
|
||||
setSelectedDataSource("postgres-main");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to fetch session messages", e);
|
||||
setMessages([]);
|
||||
@@ -245,7 +237,7 @@ export function ChatInterface() {
|
||||
};
|
||||
setAttachedFile(uploadedFile);
|
||||
setActiveDataFile(uploadedFile);
|
||||
setSelectedDataSource("upload-main");
|
||||
setSelectedDataSource("");
|
||||
await syncSessionFileContext(uploadedFile);
|
||||
} catch (error) {
|
||||
console.error("File upload error:", error);
|
||||
@@ -261,12 +253,37 @@ export function ChatInterface() {
|
||||
const handleRemoveFile = async () => {
|
||||
setAttachedFile(null);
|
||||
setActiveDataFile(null);
|
||||
if (selectedDataSource.startsWith("upload")) {
|
||||
setSelectedDataSource("postgres-main");
|
||||
}
|
||||
await syncSessionFileContext(null);
|
||||
};
|
||||
|
||||
const selectedDataSourceName = availableDataSources.find(ds => ds.id === selectedDataSource)?.name || "";
|
||||
const selectedSkills = availableSkills.filter(skill => selectedSkillIds.includes(skill.id));
|
||||
|
||||
// Renders the chip row for the currently selected data source and skills;
// returns null when neither is selected.
const renderActiveSelections = () => {
  const hasDataSource = Boolean(selectedDataSource);
  const hasSkills = selectedSkills.length > 0;
  if (!hasDataSource && !hasSkills) {
    return null;
  }

  // Single chip for the selected data source, if any.
  const dataSourceChip = hasDataSource ? (
    <div className="px-3 py-1.5 rounded-full text-xs border flex items-center gap-1.5 bg-blue-50 text-blue-700 border-blue-200">
      <Database className="h-3.5 w-3.5" />
      {`数据源:${selectedDataSourceName}`}
    </div>
  ) : null;

  // One chip per selected skill.
  const skillChips = selectedSkills.map((skill) => (
    <div
      key={skill.id}
      className="px-3 py-1.5 rounded-full text-xs border flex items-center gap-1.5 bg-orange-50 text-orange-700 border-orange-200"
    >
      <Wand2 className="h-3.5 w-3.5" />
      {`Skill:${skill.name}`}
    </div>
  ));

  return (
    <div className="px-2 pt-2">
      <div className="flex flex-wrap gap-2">
        {dataSourceChip}
        {skillChips}
      </div>
    </div>
  );
};
|
||||
|
||||
const renderFileCard = () => {
|
||||
const file = attachedFile || activeDataFile;
|
||||
if (!file) return null;
|
||||
@@ -328,7 +345,7 @@ export function ChatInterface() {
|
||||
};
|
||||
|
||||
const handleSend = async () => {
|
||||
if (!input.trim() || isLoading) return;
|
||||
if (!input.trim() || isLoading || !selectedDataSource) return;
|
||||
|
||||
const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input };
|
||||
setMessages(prev => [...prev, newMessage]);
|
||||
@@ -357,16 +374,9 @@ export function ChatInterface() {
|
||||
const token = localStorage.getItem("token");
|
||||
const effectiveModelId = selectedModelId || currentModel?.id || "";
|
||||
|
||||
// Correctly parse source from selectedDataSource (could be 'ds:ID', 'upload', or legacy 'postgres-main')
|
||||
let source = selectedDataSource;
|
||||
if (selectedDataSource.includes("-")) {
|
||||
source = selectedDataSource.split("-")[0];
|
||||
}
|
||||
let source = selectedDataSource;
|
||||
|
||||
const useUploadSource = Boolean(
|
||||
currentAttachedFile?.url?.startsWith("local://") ||
|
||||
(source === "upload" && activeDataFile?.url?.startsWith("local://"))
|
||||
);
|
||||
const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://"));
|
||||
if (useUploadSource) {
|
||||
source = "upload";
|
||||
}
|
||||
@@ -580,6 +590,7 @@ export function ChatInterface() {
|
||||
<div className="relative group">
|
||||
<div className="flex flex-col bg-white rounded-[26px] border border-zinc-200 shadow-[0_2px_12px_rgba(0,0,0,0.04)] transition-all duration-200">
|
||||
{renderFileCard()}
|
||||
{renderActiveSelections()}
|
||||
<div className="flex items-center pl-2 pr-2 py-2">
|
||||
<div className="flex items-center">
|
||||
<Popover open={isMenuOpen} onOpenChange={setIsMenuOpen}>
|
||||
@@ -615,26 +626,16 @@ export function ChatInterface() {
|
||||
{selectedDataSource === ds.id && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
))}
|
||||
|
||||
<button
|
||||
onClick={() => {
|
||||
setSelectedDataSource('upload');
|
||||
fileInputRef.current?.click();
|
||||
setIsMenuOpen(false);
|
||||
}}
|
||||
className={cn(
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200",
|
||||
selectedDataSource === 'upload' || selectedDataSource === 'upload-main'
|
||||
? "bg-white text-zinc-900 shadow-sm ring-1 ring-zinc-200"
|
||||
: "text-zinc-600 hover:bg-white hover:shadow-sm"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2.5">
|
||||
<FileIcon className={cn("h-4 w-4", (selectedDataSource === 'upload' || selectedDataSource === 'upload-main') ? "text-blue-500" : "text-zinc-400")} />
|
||||
<span className="font-medium">本地文件上传</span>
|
||||
{selectedDataSource && (
|
||||
<div className="mt-2 pt-2 border-t border-zinc-100">
|
||||
<button
|
||||
onClick={() => setSelectedDataSource("")}
|
||||
className="w-full py-1.5 text-[11px] text-zinc-400 hover:text-zinc-600 transition-colors flex items-center justify-center gap-1"
|
||||
>
|
||||
清除已选
|
||||
</button>
|
||||
</div>
|
||||
{(selectedDataSource === 'upload' || selectedDataSource === 'upload-main') && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -652,26 +653,21 @@ export function ChatInterface() {
|
||||
<button
|
||||
key={skill.id}
|
||||
onClick={() => {
|
||||
setSelectedSkillIds(prev =>
|
||||
isSelected
|
||||
? prev.filter(id => id !== skill.id)
|
||||
setSelectedSkillIds((prev) =>
|
||||
isSelected
|
||||
? prev.filter((id) => id !== skill.id)
|
||||
: [...prev, skill.id]
|
||||
);
|
||||
}}
|
||||
className={cn(
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200 group/item",
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200",
|
||||
isSelected
|
||||
? "bg-zinc-50 text-zinc-900 ring-1 ring-zinc-100"
|
||||
: "text-zinc-600 hover:bg-zinc-50"
|
||||
? "bg-white text-zinc-900 shadow-sm ring-1 ring-zinc-200"
|
||||
: "text-zinc-600 hover:bg-white hover:shadow-sm"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-start gap-0.5">
|
||||
<div className="flex items-center text-left">
|
||||
<span className="font-medium">{skill.name}</span>
|
||||
{skill.description && (
|
||||
<span className="text-[11px] text-zinc-400 line-clamp-1 group-hover/item:text-zinc-500">
|
||||
{skill.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{isSelected && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
@@ -713,7 +709,7 @@ export function ChatInterface() {
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
onClick={handleSend}
|
||||
disabled={isLoading || (!input.trim() && !attachedFile && !activeDataFile)}
|
||||
disabled={isLoading || !selectedDataSource || !input.trim()}
|
||||
className={cn(
|
||||
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
|
||||
(input.trim() || attachedFile || activeDataFile) && !isLoading
|
||||
@@ -798,6 +794,7 @@ export function ChatInterface() {
|
||||
<div className="relative group max-w-4xl mx-auto">
|
||||
<div className="flex flex-col bg-white rounded-[26px] border border-zinc-200 shadow-[0_2px_12px_rgba(0,0,0,0.04)] transition-all duration-200">
|
||||
{renderFileCard()}
|
||||
{renderActiveSelections()}
|
||||
<div className="flex items-center pl-2 pr-2 py-2">
|
||||
<div className="flex items-center">
|
||||
<Popover open={isMenuOpen} onOpenChange={setIsMenuOpen}>
|
||||
@@ -833,26 +830,16 @@ export function ChatInterface() {
|
||||
{selectedDataSource === ds.id && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
))}
|
||||
|
||||
<button
|
||||
onClick={() => {
|
||||
setSelectedDataSource('upload');
|
||||
fileInputRef.current?.click();
|
||||
setIsMenuOpen(false);
|
||||
}}
|
||||
className={cn(
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200",
|
||||
selectedDataSource === 'upload' || selectedDataSource === 'upload-main'
|
||||
? "bg-white text-zinc-900 shadow-sm ring-1 ring-zinc-200"
|
||||
: "text-zinc-600 hover:bg-white hover:shadow-sm"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-2.5">
|
||||
<FileIcon className={cn("h-4 w-4", (selectedDataSource === 'upload' || selectedDataSource === 'upload-main') ? "text-blue-500" : "text-zinc-400")} />
|
||||
<span className="font-medium">本地文件上传</span>
|
||||
{selectedDataSource && (
|
||||
<div className="mt-2 pt-2 border-t border-zinc-100">
|
||||
<button
|
||||
onClick={() => setSelectedDataSource("")}
|
||||
className="w-full py-1.5 text-[11px] text-zinc-400 hover:text-zinc-600 transition-colors flex items-center justify-center gap-1"
|
||||
>
|
||||
清除已选
|
||||
</button>
|
||||
</div>
|
||||
{(selectedDataSource === 'upload' || selectedDataSource === 'upload-main') && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -870,26 +857,21 @@ export function ChatInterface() {
|
||||
<button
|
||||
key={skill.id}
|
||||
onClick={() => {
|
||||
setSelectedSkillIds(prev =>
|
||||
isSelected
|
||||
? prev.filter(id => id !== skill.id)
|
||||
setSelectedSkillIds((prev) =>
|
||||
isSelected
|
||||
? prev.filter((id) => id !== skill.id)
|
||||
: [...prev, skill.id]
|
||||
);
|
||||
}}
|
||||
className={cn(
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200 group/item",
|
||||
"w-full flex items-center justify-between px-3 py-2.5 rounded-xl text-sm transition-all duration-200",
|
||||
isSelected
|
||||
? "bg-zinc-50 text-zinc-900 ring-1 ring-zinc-100"
|
||||
: "text-zinc-600 hover:bg-zinc-50"
|
||||
? "bg-white text-zinc-900 shadow-sm ring-1 ring-zinc-200"
|
||||
: "text-zinc-600 hover:bg-white hover:shadow-sm"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-start gap-0.5">
|
||||
<div className="flex items-center text-left">
|
||||
<span className="font-medium">{skill.name}</span>
|
||||
{skill.description && (
|
||||
<span className="text-[11px] text-zinc-400 line-clamp-1 group-hover/item:text-zinc-500">
|
||||
{skill.description}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{isSelected && <CheckCircle2 className="h-4 w-4 text-blue-500" />}
|
||||
</button>
|
||||
@@ -931,7 +913,7 @@ export function ChatInterface() {
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
onClick={isLoading ? handleForceStop : handleSend}
|
||||
disabled={isLoading ? false : !input.trim()}
|
||||
disabled={isLoading ? false : !selectedDataSource || !input.trim()}
|
||||
className={cn(
|
||||
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
|
||||
(input.trim() || isLoading)
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useEffect } from "react";
|
||||
import { api } from "@/lib/api";
|
||||
import { DataSourceForm, type DataSourceConfig } from "@/components/DataSourceForm";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Plus, Database, Pencil, Trash2, Loader2, FolderOpen, Info, ChevronLeft, FileText, Search } from "lucide-react";
|
||||
import { Plus, Database, Pencil, Trash2, Loader2, FolderOpen, Info, ChevronLeft, FileText, Search, Network } from "lucide-react";
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle } from "@/components/ui/dialog";
|
||||
import { useAuthStore } from "@/store/authStore";
|
||||
import { useProjectStore } from "@/store/projectStore";
|
||||
@@ -216,9 +216,12 @@ export function DataSources() {
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-zinc-600" onClick={() => handleEdit(ds)}>
|
||||
<Pencil className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-blue-600" onClick={() => navigate(`/modeling/${ds.id}`)} title="Data Modeling">
|
||||
<Network className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-zinc-600" onClick={() => handleEdit(ds)}>
|
||||
<Pencil className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-zinc-400 hover:text-red-600 hover:bg-red-50" onClick={() => handleDelete(ds.id!)}>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
@@ -0,0 +1,514 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useParams, useNavigate } from "react-router-dom";
|
||||
import { api } from "../lib/api";
|
||||
import { Button } from "../components/ui/button";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "../components/ui/card";
|
||||
import { Label } from "../components/ui/label";
|
||||
import { ScrollArea } from "../components/ui/scroll-area";
|
||||
import { Dialog, DialogContent, DialogHeader, DialogTitle } from "../components/ui/dialog";
|
||||
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "../components/ui/table";
|
||||
import { ArrowLeft, Table as TableIcon, Network } from "lucide-react";
|
||||
|
||||
/**
 * Raw physical schema as returned by the `/schema` endpoint:
 * each key is a table name, each value the list of that table's columns.
 */
interface RawSchema {
  [table: string]: Array<{ name: string; type: string }>;
}
|
||||
|
||||
/** A single column of a semantic model. */
interface Column {
  /** Column name; used as the React key and shown in the model cards. */
  name: string;
  /** Column type, rendered verbatim next to the name. */
  type: string;
  // NOTE(review): presumably marks columns derived from `expression` rather
  // than read from the table — confirm against the backend contract.
  isCalculated: boolean;
  // NOTE(review): presumably the name of a Relationship this column follows — confirm.
  relationship?: string;
  expression?: string;
  /** Free-form metadata; `properties.description` is displayed in the detail dialog. */
  properties?: Record<string, unknown>;
}
|
||||
|
||||
/** A semantic model (one card in the diagram view). */
interface Model {
  /** Model name; used as the React key and passed to the model-detail endpoint. */
  name: string;
  /** Columns listed inside the model card. */
  columns: Column[];
  primaryKey?: string;
  /**
   * Free-form metadata. Typed as `unknown` (not `any`) so that it matches the
   * `properties` fields of the sibling interfaces and forces callers to narrow
   * a value before using it.
   */
  properties?: Record<string, unknown>;
}
|
||||
|
||||
/** A join between models, as listed in the manifest. */
interface Relationship {
  /** Relationship name; used as the React key when rendered. */
  name: string;
  /** Names of the participating models (matched against `Model.name`). */
  models: Array<string>;
  /** Join type label, shown in the relationship tooltip. */
  joinType: string;
  /** Join condition text. */
  condition: string;
}
|
||||
|
||||
/**
 * Modeling manifest for one data source. Returned by
 * `GET /api/v1/semantic/{id}` and by `POST /api/v1/semantic/{id}/generate`.
 */
interface MDLManifest {
  catalog: string;
  schema: string;
  dataSource: string;
  /** Models rendered in the sidebar list and on the canvas. */
  models: Array<Model>;
  /** Relationships; filtered per model when rendering each card. */
  relationships: Array<Relationship>;
}
|
||||
|
||||
/**
 * Payload of `GET /api/v1/semantic/{id}/models/{name}?limit=N`,
 * rendered by the model detail dialog.
 */
interface ModelDetailResponse {
  /** The model's own metadata. */
  model: {
    name: string;
    // NOTE(review): presumably the physical table backing the model — confirm.
    tableReference?: {
      table: string;
      schema?: string;
      catalog?: string;
    } | null;
    primaryKey?: string;
    properties?: Record<string, unknown>;
    /** Columns shown in the "Columns Metadata" table. */
    columns: Array<Column>;
  };
  /** Relationships shown in the "Relationships" table. */
  relationships: Array<{
    name: string;
    models: Array<string>;
    joinType: string;
    condition: string;
    properties?: Record<string, unknown>;
  }>;
  /**
   * Sample rows for the preview table. The keys of the first row are used as
   * the header for every row.
   */
  preview_rows: Array<Record<string, unknown>>;
}
|
||||
|
||||
export function Modeling() {
|
||||
const { id } = useParams<{ id: string }>();
|
||||
const navigate = useNavigate();
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [schema, setSchema] = useState<RawSchema | null>(null);
|
||||
const [mdl, setMdl] = useState<MDLManifest | null>(null);
|
||||
const [selectedTables, setSelectedTables] = useState<string[]>([]);
|
||||
const [selectedColumns, setSelectedColumns] = useState<Record<string, string[]>>({});
|
||||
const [expandedTables, setExpandedTables] = useState<Record<string, boolean>>({});
|
||||
const [step, setStep] = useState<"select" | "view">("select");
|
||||
const [detailOpen, setDetailOpen] = useState(false);
|
||||
const [detailLoading, setDetailLoading] = useState(false);
|
||||
const [modelDetail, setModelDetail] = useState<ModelDetailResponse | null>(null);
|
||||
|
||||
// (Re)load the manifest or raw schema whenever the route's data source id changes.
useEffect(() => {
  // The returned promise is intentionally not awaited; the loader handles its own errors.
  void fetchInitialData();
}, [id]);
|
||||
|
||||
/**
 * Store a freshly fetched schema and reset the selection UI so that every
 * table is selected, every column is selected, and every table is expanded.
 */
const initSelectionFromSchema = (schemaRes: RawSchema) => {
  const allTables = Object.keys(schemaRes);

  // table name -> names of all of its columns
  const allColumnsByTable: Record<string, string[]> = Object.fromEntries(
    allTables.map((name) => [name, schemaRes[name].map((col) => col.name)])
  );

  // table name -> expanded flag (all start expanded)
  const allExpanded: Record<string, boolean> = Object.fromEntries<boolean>(
    allTables.map((name) => [name, true])
  );

  setSchema(schemaRes);
  setSelectedTables(allTables);
  setSelectedColumns(allColumnsByTable);
  setExpandedTables(allExpanded);
};
|
||||
|
||||
/**
 * Fetch the raw schema for the current data source, reset the selection to
 * "everything selected", and switch to the table-selection step.
 * Rejects if the request fails; callers handle the error.
 */
const fetchSchemaOnly = async () => {
  const rawSchema = await api.get<RawSchema>(`/api/v1/semantic/${id}/schema`);
  initSelectionFromSchema(rawSchema);
  setStep("select");
};
|
||||
|
||||
/**
 * Initial load for the page.
 *
 * Tries to fetch an existing manifest. If it contains at least one model the
 * page goes straight to the "view" step; otherwise (or if the manifest request
 * fails) it falls back to loading the raw schema for table selection.
 * The loading flag is always cleared when the function settles.
 */
const fetchInitialData = async () => {
  setLoading(true);
  try {
    // Typed via the generic (as in openModelDetail) instead of `as any`, so the
    // checks below are verified by the compiler. The response may be partial or
    // empty when no model has been generated yet.
    const mdlRes = await api.get<Partial<MDLManifest>>(`/api/v1/semantic/${id}`);
    if (mdlRes?.models?.length) {
      setMdl(mdlRes as MDLManifest);
      setStep("view");
      return;
    }
    await fetchSchemaOnly();
  } catch (error) {
    console.error("Failed to fetch modeling data:", error);
    // Best-effort fallback: still try to show the table-selection step.
    try {
      await fetchSchemaOnly();
    } catch (e) {
      console.error("Failed to fetch schema:", e);
    }
  } finally {
    setLoading(false);
  }
};
|
||||
|
||||
/**
 * Ask the backend to generate a manifest from the current selection and, on
 * success, switch to the "view" step. Failures are logged to the console.
 */
const handleGenerate = async () => {
  setLoading(true);
  try {
    // Only send column lists for tables that are actually selected;
    // a selected table without an entry is sent with an empty list.
    const columnsByTable: Record<string, string[]> = {};
    for (const table of selectedTables) {
      columnsByTable[table] = selectedColumns[table] ?? [];
    }

    const manifest = (await api.post(`/api/v1/semantic/${id}/generate`, {
      selected_tables: selectedTables,
      selected_columns: columnsByTable,
    })) as MDLManifest;

    setMdl(manifest);
    setStep("view");
  } catch (error) {
    console.error("Failed to generate MDL:", error);
  } finally {
    setLoading(false);
  }
};
|
||||
|
||||
/**
 * Toggle a table in the selection. When a table is being selected and has no
 * columns chosen yet, all of its columns are selected as well.
 */
const toggleTable = (table: string) => {
  setSelectedTables((current) =>
    current.includes(table)
      ? current.filter((name) => name !== table)
      : [...current, table]
  );

  if (!schema) return;

  // Both checks read the state captured by this render (i.e. before the toggle).
  const wasSelected = selectedTables.includes(table);
  const chosen = selectedColumns[table];
  const hasChosenColumns = Boolean(chosen && chosen.length > 0);
  if (wasSelected || hasChosenColumns) return;

  const everyColumn = schema[table].map((col) => col.name);
  setSelectedColumns((current) => ({ ...current, [table]: everyColumn }));
};
|
||||
|
||||
/**
 * Toggle one column of a table, keeping the table selection in sync:
 * a table with no remaining columns is deselected, and a table that gains a
 * column is selected.
 */
const toggleColumn = (table: string, column: string) => {
  setSelectedColumns((prev) => {
    const existing = prev[table] ?? [];
    const updated = existing.includes(column)
      ? existing.filter((name) => name !== column)
      : [...existing, column];
    return { ...prev, [table]: updated };
  });

  // NOTE: the resulting column count is derived from the state captured by
  // this render, not from the updater's `prev` above.
  const snapshot = selectedColumns[table] ?? [];
  const remaining = snapshot.includes(column)
    ? snapshot.length - 1
    : snapshot.length + 1;

  setSelectedTables((prev) => {
    if (remaining <= 0) {
      return prev.filter((name) => name !== table);
    }
    return prev.includes(table) ? prev : [...prev, table];
  });
};
|
||||
|
||||
/** Flip the expanded/collapsed state of a table's column list. */
const toggleExpandTable = (table: string) => {
  setExpandedTables((prev) => {
    const isOpen = prev[table];
    return { ...prev, [table]: !isOpen };
  });
};
|
||||
|
||||
/** Select every table and every column of the loaded schema (no-op without a schema). */
const handleSelectAll = () => {
  if (!schema) return;

  const everyTable = Object.keys(schema);
  const everyColumn: Record<string, string[]> = {};
  for (const name of everyTable) {
    everyColumn[name] = schema[name].map((col) => col.name);
  }

  setSelectedTables(everyTable);
  setSelectedColumns(everyColumn);
};
|
||||
|
||||
/** Deselect all tables and forget every per-table column choice. */
const handleClearAll = () => {
  const noTables: string[] = [];
  const noColumns: Record<string, string[]> = {};
  setSelectedTables(noTables);
  setSelectedColumns(noColumns);
};
|
||||
|
||||
/**
 * Return from the "view" step to table selection by reloading the raw schema.
 *
 * This runs as a click handler, so a rejected schema request must be handled
 * here; otherwise it would surface as an unhandled promise rejection. The
 * failure is logged in the same way as the other handlers on this page.
 */
const handleReselectTables = async () => {
  setLoading(true);
  try {
    await fetchSchemaOnly();
  } catch (error) {
    console.error("Failed to fetch schema:", error);
  } finally {
    setLoading(false);
  }
};
|
||||
|
||||
/**
 * Open the detail dialog for a model and load its metadata, relationships and
 * a 10-row data preview. On failure the dialog shows its empty state and the
 * error is logged.
 */
const openModelDetail = async (modelName: string) => {
  setDetailOpen(true);
  setDetailLoading(true);
  // Drop the previously viewed model so the dialog title does not show the
  // old model's name while the new one is loading.
  setModelDetail(null);
  try {
    const detail = await api.get<ModelDetailResponse>(
      `/api/v1/semantic/${id}/models/${encodeURIComponent(modelName)}?limit=10`
    );
    setModelDetail(detail);
  } catch (error) {
    console.error("Failed to fetch model detail:", error);
    setModelDetail(null);
  } finally {
    setDetailLoading(false);
  }
};
|
||||
|
||||
if (loading) {
|
||||
return <div className="p-8 text-center">Loading modeling data...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full bg-gray-50">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-6 py-4 bg-white border-b">
|
||||
<div className="flex items-center gap-4">
|
||||
<Button variant="ghost" size="icon" onClick={() => navigate("/datasources")}>
|
||||
<ArrowLeft className="w-5 h-5" />
|
||||
</Button>
|
||||
<div>
|
||||
<h1 className="text-xl font-semibold">Data Modeling</h1>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
DataSource ID: {id} • {step === "select" ? "Select Tables" : "Entity Relationship Diagram"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{step === "view" && (
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" onClick={handleReselectTables}>
|
||||
Reselect Tables
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 overflow-hidden p-6">
|
||||
{step === "select" ? (
|
||||
<div className="max-w-4xl mx-auto h-full flex flex-col">
|
||||
<Card className="flex-1 flex flex-col overflow-hidden">
|
||||
<CardHeader>
|
||||
<CardTitle>Select tables to create data models</CardTitle>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Choose the tables you want to include in your semantic model.
|
||||
</p>
|
||||
</CardHeader>
|
||||
<CardContent className="flex-1 overflow-hidden flex flex-col">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{selectedTables.length} / {schema ? Object.keys(schema).length : 0} selected
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleSelectAll}
|
||||
>
|
||||
Select All
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleClearAll}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ScrollArea className="flex-1 border rounded-md p-4">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{schema && Object.keys(schema).map((table) => (
|
||||
<div
|
||||
key={table}
|
||||
className={`p-3 rounded-lg border transition-colors ${
|
||||
selectedTables.includes(table)
|
||||
? "bg-primary/5 border-primary"
|
||||
: "bg-white hover:bg-gray-50"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<div className="flex items-center space-x-3">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="h-4 w-4 rounded border-gray-300 text-primary focus:ring-primary"
|
||||
checked={selectedTables.includes(table)}
|
||||
onChange={() => toggleTable(table)}
|
||||
/>
|
||||
<Label className="cursor-pointer font-medium flex items-center gap-2">
|
||||
<TableIcon className="w-4 h-4 text-muted-foreground" />
|
||||
{table}
|
||||
</Label>
|
||||
</div>
|
||||
<Button variant="ghost" size="sm" onClick={() => toggleExpandTable(table)}>
|
||||
{expandedTables[table] ? "Hide Columns" : "Show Columns"}
|
||||
</Button>
|
||||
</div>
|
||||
{expandedTables[table] && (
|
||||
<div className="mt-3 max-h-48 overflow-auto border rounded-md bg-white">
|
||||
{schema[table].map((col) => (
|
||||
<label
|
||||
key={`${table}:${col.name}`}
|
||||
className="flex items-center justify-between px-3 py-2 border-b last:border-b-0 cursor-pointer hover:bg-gray-50"
|
||||
>
|
||||
<div className="flex items-center gap-2 min-w-0">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="h-4 w-4 rounded border-gray-300 text-primary focus:ring-primary"
|
||||
checked={(selectedColumns[table] ?? []).includes(col.name)}
|
||||
onChange={() => toggleColumn(table, col.name)}
|
||||
/>
|
||||
<span className="text-sm truncate">{col.name}</span>
|
||||
</div>
|
||||
<span className="text-[10px] font-mono text-muted-foreground ml-2">{col.type}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
|
||||
<div className="pt-6 flex justify-end">
|
||||
<Button onClick={handleGenerate} disabled={selectedTables.length === 0}>
|
||||
Generate Model
|
||||
</Button>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-full flex gap-6">
|
||||
{/* Sidebar List */}
|
||||
<Card className="w-64 flex flex-col h-full">
|
||||
<CardHeader className="py-4 px-4 border-b">
|
||||
<CardTitle className="text-sm font-medium">Models ({mdl?.models.length})</CardTitle>
|
||||
</CardHeader>
|
||||
<ScrollArea className="flex-1">
|
||||
<div className="p-2 space-y-1">
|
||||
{mdl?.models.map((model) => (
|
||||
<div
|
||||
key={model.name}
|
||||
className="flex items-center gap-2 px-3 py-2 text-sm rounded-md hover:bg-gray-100 cursor-pointer"
|
||||
onClick={() => openModelDetail(model.name)}
|
||||
>
|
||||
<TableIcon className="w-4 h-4 text-muted-foreground" />
|
||||
<span className="truncate">{model.name}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</Card>
|
||||
|
||||
{/* Canvas Area (Simulated) */}
|
||||
<div className="flex-1 overflow-auto bg-slate-100 rounded-lg border p-8 relative">
|
||||
<div className="absolute inset-0 pointer-events-none"
|
||||
style={{
|
||||
backgroundImage: 'radial-gradient(#cbd5e1 1px, transparent 1px)',
|
||||
backgroundSize: '20px 20px'
|
||||
}}
|
||||
/>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 relative z-10">
|
||||
{mdl?.models.map((model) => (
|
||||
<Card key={model.name} className="shadow-md border-t-4 border-t-blue-500 min-w-[240px] cursor-pointer" onClick={() => openModelDetail(model.name)}>
|
||||
<CardHeader className="py-3 px-4 bg-gray-50 border-b flex flex-row items-center justify-between">
|
||||
<div className="font-semibold text-sm flex items-center gap-2">
|
||||
<TableIcon className="w-4 h-4 text-blue-500" />
|
||||
{model.name}
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="p-0">
|
||||
<div className="max-h-[300px] overflow-y-auto text-xs">
|
||||
{model.columns.map((col) => (
|
||||
<div key={col.name} className="flex justify-between py-2 px-4 border-b last:border-0 hover:bg-gray-50">
|
||||
<span className="font-medium">{col.name}</span>
|
||||
<span className="text-muted-foreground font-mono text-[10px]">{col.type}</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
{/* Show Relationships if any */}
|
||||
{mdl.relationships.filter(r => r.models.includes(model.name)).length > 0 && (
|
||||
<div className="bg-orange-50 p-2 border-t text-xs">
|
||||
<div className="font-semibold text-orange-700 mb-1 flex items-center gap-1">
|
||||
<Network className="w-3 h-3" /> Relationships
|
||||
</div>
|
||||
{mdl.relationships
|
||||
.filter(r => r.models.includes(model.name))
|
||||
.map(r => {
|
||||
const other = r.models.find(m => m !== model.name);
|
||||
return (
|
||||
<div key={r.name} className="text-orange-600 truncate" title={`${r.joinType} with ${other}`}>
|
||||
⟷ {other}
|
||||
</div>
|
||||
);
|
||||
})
|
||||
}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<Dialog open={detailOpen} onOpenChange={setDetailOpen}>
|
||||
<DialogContent className="sm:max-w-[1100px] max-h-[85vh] overflow-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle>{modelDetail?.model?.name ?? "Model Detail"}</DialogTitle>
|
||||
</DialogHeader>
|
||||
{detailLoading ? (
|
||||
<div className="py-8 text-center text-muted-foreground">Loading model detail...</div>
|
||||
) : !modelDetail ? (
|
||||
<div className="py-8 text-center text-muted-foreground">No metadata available.</div>
|
||||
) : (
|
||||
<div className="space-y-6">
|
||||
<div className="space-y-2">
|
||||
<div className="text-base font-semibold">Columns Metadata</div>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Type</TableHead>
|
||||
<TableHead>Description</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{modelDetail.model.columns.map((col) => (
|
||||
<TableRow key={col.name}>
|
||||
<TableCell>{col.name}</TableCell>
|
||||
<TableCell>{col.type}</TableCell>
|
||||
<TableCell>{String(col.properties?.description ?? "-")}</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-base font-semibold">Relationships ({modelDetail.relationships.length})</div>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Models</TableHead>
|
||||
<TableHead>Type</TableHead>
|
||||
<TableHead>Condition</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{modelDetail.relationships.map((rel) => (
|
||||
<TableRow key={rel.name}>
|
||||
<TableCell>{rel.name}</TableCell>
|
||||
<TableCell>{rel.models.join(" ↔ ")}</TableCell>
|
||||
<TableCell>{rel.joinType}</TableCell>
|
||||
<TableCell>{rel.condition}</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<div className="text-base font-semibold">Data Preview (Top 10)</div>
|
||||
{modelDetail.preview_rows.length === 0 ? (
|
||||
<div className="text-sm text-muted-foreground">No preview data.</div>
|
||||
) : (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
{Object.keys(modelDetail.preview_rows[0]).map((key) => (
|
||||
<TableHead key={key}>{key}</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{modelDetail.preview_rows.map((row, idx) => (
|
||||
<TableRow key={idx}>
|
||||
{Object.keys(modelDetail.preview_rows[0]).map((key) => (
|
||||
<TableCell key={key}>{String(row[key] ?? "")}</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user