feat: add modelling layer

This commit is contained in:
qixinbo
2026-03-16 22:18:23 +08:00
parent a1a855a126
commit 720c30a893
16 changed files with 1115 additions and 106 deletions
+37 -4
View File
@@ -24,6 +24,7 @@ from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -116,11 +117,11 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
return pd.read_parquet(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]:
    """Describe an uploaded DataFrame as a single-table schema.

    The frame is registered in a throwaway in-memory DuckDB database under the
    name ``uploaded_file`` and inspected with ``DESCRIBE``.

    Args:
        df: The uploaded data to describe.

    Returns:
        ``{"uploaded_file": [{"name": <column>, "type": <type>}, ...]}`` with
        one entry per row returned by ``DESCRIBE``.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        # Close even when register/DESCRIBE raises so the connection is not leaked.
        conn.close()
    # DESCRIBE rows: col[0] is the column name, col[1] is its type.
    return {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]}
@@ -167,7 +168,7 @@ def _build_schema_cache_key(source: str, connector: Any) -> str:
)
return source
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]:
key = _build_schema_cache_key(source, connector)
now = time.time()
with _cache_lock:
@@ -176,7 +177,7 @@ def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[s
return cached["schema"]
return None
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None:
key = _build_schema_cache_key(source, connector)
with _cache_lock:
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
@@ -247,6 +248,37 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema_str = json.dumps(schema, indent=2)
# Try to load MDL context
mdl_context = ""
if request.source.startswith("ds:"):
try:
ds_id = int(request.source.split(":")[1])
mdl = MDLService.get_mdl(ds_id)
if mdl:
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
mdl_lines.append("MODELS:")
for model in mdl.models:
table_ref = model.tableReference.table if model.tableReference else model.name
desc = f" - Description: {model.properties.get('description', '')}" if model.properties.get('description') else ""
mdl_lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}")
if model.columns:
mdl_lines.append(" Columns:")
for col in model.columns:
col_desc = f" ({col.properties.get('description')})" if col.properties.get('description') else ""
expr = f" [Calculated: {col.expression}]" if col.isCalculated else ""
mdl_lines.append(f" - {col.name} ({col.type}){col_desc}{expr}")
if mdl.relationships:
mdl_lines.append("\nRELATIONSHIPS:")
for rel in mdl.relationships:
mdl_lines.append(f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}")
mdl_context = "\n".join(mdl_lines)
except Exception as e:
print(f"Failed to load MDL: {e}")
# 2. Get the active LLM config
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
@@ -270,6 +302,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
user_prompt = f"""
### DATABASE SCHEMA ###
{schema_str}
{mdl_context}
### INPUTS ###
User's Question: {request.query}
+139
View File
@@ -0,0 +1,139 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from pydantic import BaseModel
from app.database import get_db
from app.models.datasource import DataSource
from app.schemas.mdl import MDLManifest
from app.services.mdl import MDLService
from app.connectors.factory import get_connector
router = APIRouter(tags=["semantic"])
class GenerateMDLRequest(BaseModel):
    """Optional request body for the MDL regeneration endpoint."""

    # Table names to include; None means no table filter is requested.
    selected_tables: Optional[List[str]] = None
    # Column names to include, keyed by table name; None means no column filter.
    selected_columns: Optional[Dict[str, List[str]]] = None
class ModelDetailResponse(BaseModel):
    """Response payload describing one model of a data source's MDL."""

    # The model definition serialized as a plain dict.
    model: Dict[str, Any]
    # Serialized relationships associated with the model.
    relationships: List[Dict[str, Any]]
    # Sample rows, one dict per row.
    preview_rows: List[Dict[str, Any]]
def _normalize_query_result(results: Any) -> List[Dict[str, Any]]:
if isinstance(results, list):
if results and isinstance(results[0], dict):
return results
if results and isinstance(results[0], (list, tuple)):
return [dict(enumerate(row)) for row in results]
return []
if isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
return [dict(zip(col_names, row)) for row in rows]
return []
@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]])
def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)):
    """Return the raw schema of a data source as provided by ``MDLService``.

    Raises:
        HTTPException: 404 when no data source has the given id; 500 (with the
            error text as detail) when fetching the schema raises.
    """
    lookup = db.query(DataSource).filter(DataSource.id == datasource_id)
    datasource = lookup.first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        raw_schema = MDLService.get_raw_schema(datasource)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return raw_schema
@router.get("/semantic/{datasource_id}", response_model=MDLManifest)
def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)):
    """Return the MDL manifest of a data source via ``MDLService.get_or_create_mdl``.

    Raises:
        HTTPException: 404 when no data source has the given id; 500 (with the
            error text as detail) when loading or generating the MDL raises.
    """
    lookup = db.query(DataSource).filter(DataSource.id == datasource_id)
    datasource = lookup.first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        manifest = MDLService.get_or_create_mdl(datasource_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return manifest
@router.put("/semantic/{datasource_id}", response_model=MDLManifest)
def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)):
    """Persist the submitted MDL manifest for a data source and echo it back.

    Raises:
        HTTPException: 404 when no data source has the given id; 500 (with the
            error text as detail) when saving raises.
    """
    lookup = db.query(DataSource).filter(DataSource.id == datasource_id)
    datasource = lookup.first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        MDLService.save_mdl(datasource_id, mdl)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return mdl
@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest)
def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)):
    """Generate a default MDL for a data source, save it, and return it.

    The optional request body supplies table/column selections that are passed
    through to ``MDLService.generate_default_mdl``; without a body both
    selections are ``None``.

    Raises:
        HTTPException: 404 when no data source has the given id; 500 (with the
            error text as detail) when generation or saving raises.
    """
    lookup = db.query(DataSource).filter(DataSource.id == datasource_id)
    datasource = lookup.first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        if request:
            tables = request.selected_tables
            columns = request.selected_columns
        else:
            tables = None
            columns = None
        manifest = MDLService.generate_default_mdl(
            datasource,
            selected_tables=tables,
            selected_columns=columns,
        )
        MDLService.save_mdl(datasource_id, manifest)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return manifest
@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse)
def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)):
    """Return one model of a data source's MDL with its relationships and sample rows.

    Args:
        datasource_id: Id of the data source whose MDL is inspected.
        model_name: Name of the model to return.
        limit: Requested number of preview rows; clamped to the range 1..100.
        db: Database session (injected).

    Returns:
        A ``ModelDetailResponse`` holding the serialized model, every
        relationship whose ``models`` list contains ``model_name``, and the
        preview rows. The preview is best effort: if the query fails for any
        reason, ``preview_rows`` is empty.

    Raises:
        HTTPException: 404 when the data source or the model does not exist.
    """
    ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not ds:
        raise HTTPException(status_code=404, detail="DataSource not found")

    mdl = MDLService.get_or_create_mdl(datasource_id)
    model = next((m for m in mdl.models if m.name == model_name), None)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    relationships = [
        {
            "name": rel.name,
            "models": rel.models,
            "joinType": rel.joinType,
            "condition": rel.condition,
            "properties": rel.properties,
        }
        for rel in mdl.relationships
        if model_name in rel.models
    ]

    preview_rows: List[Dict[str, Any]] = []
    try:
        connector = get_connector(ds)
        table_name = model.tableReference.table if model.tableReference else model.name
        # The table name comes from a stored, user-editable MDL: double any
        # embedded quote so it cannot terminate the quoted identifier early.
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        row_limit = max(1, min(limit, 100))
        raw = connector.execute_query(f"SELECT * FROM {quoted_table} LIMIT {row_limit}")
        preview_rows = _normalize_query_result(raw)
    except Exception:
        # Deliberate best effort: a failed preview must not fail the whole response.
        preview_rows = []

    model_payload = {
        "name": model.name,
        "tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None,
        "primaryKey": model.primaryKey,
        "properties": model.properties,
        "columns": [c.model_dump(by_alias=True) for c in model.columns],
    }
    return ModelDetailResponse(
        model=model_payload,
        relationships=relationships,
        preview_rows=preview_rows,
    )
+1 -1
View File
@@ -33,7 +33,7 @@ class ClickHouseConnector:
table = row[0]
if table not in schema:
schema[table] = []
schema[table].append(f"{row[1]} ({row[2]})")
schema[table].append({"name": row[1], "type": row[2]})
return schema
except Exception as e:
print(f"Error getting schema: {e}")
+68
View File
@@ -0,0 +1,68 @@
import duckdb
import pandas as pd
from typing import List, Dict, Any
import os
from app.core.files import resolve_upload_file_path
class CSVConnector:
    """Query a single CSV file through an in-memory DuckDB database.

    Each operation opens a fresh in-memory connection, exposes the file as a
    view named after the (sanitized) file name, and closes the connection
    afterwards.
    """

    def __init__(self, file_path: str):
        """Remember the CSV path.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"CSV file not found: {self.file_path}")

    def _get_table_name(self) -> str:
        """Derive the view name from the file name (without extension).

        Every non-alphanumeric character becomes ``_`` and a ``t_`` prefix is
        added when the result would start with a digit.
        """
        base = os.path.splitext(os.path.basename(self.file_path))[0]
        safe_name = "".join(c if c.isalnum() else "_" for c in base)
        if safe_name and safe_name[0].isdigit():
            safe_name = f"t_{safe_name}"
        return safe_name

    def _quoted_path(self) -> str:
        """Return the file path as a SQL string literal.

        Single quotes are doubled so a path containing ``'`` cannot terminate
        the literal early.
        """
        escaped = self.file_path.replace("'", "''")
        return f"'{escaped}'"

    def _create_view(self, conn, table_name: str) -> None:
        """Expose the CSV file on ``conn`` as a view called ``table_name``."""
        conn.execute(
            f"CREATE OR REPLACE VIEW {table_name} AS "
            f"SELECT * FROM read_csv_auto({self._quoted_path()})"
        )

    def execute_query(self, query: str) -> List[Dict[str, Any]]:
        """Run ``query`` against the CSV view and return the rows as dicts.

        The query is expected to reference the table name reported by
        ``get_schema``.

        Raises:
            Exception: Any error from creating the view or running the query
                is logged and re-raised unchanged.
        """
        conn = duckdb.connect(":memory:")
        try:
            self._create_view(conn, self._get_table_name())
            df = conn.execute(query).df()
            return df.to_dict(orient="records")
        except Exception as e:
            print(f"CSV Query Error: {e}")
            raise
        finally:
            conn.close()

    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
        """Return ``{table_name: [{"name": ..., "type": ...}, ...]}`` for the file.

        Returns an empty dict (after logging) when the file cannot be described.
        """
        conn = duckdb.connect(":memory:")
        table_name = self._get_table_name()
        try:
            self._create_view(conn, table_name)
            columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
            # DESCRIBE rows: col[0] is the column name, col[1] is its type.
            return {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
        except Exception as e:
            print(f"Error getting schema: {e}")
            return {}
        finally:
            conn.close()

    def test_connection(self) -> bool:
        """Return True when one row can be read from the CSV file, else False."""
        try:
            conn = duckdb.connect(":memory:")
            try:
                conn.execute(f"SELECT * FROM read_csv_auto({self._quoted_path()}) LIMIT 1")
            finally:
                # Close on failure too, so a bad file does not leak the connection.
                conn.close()
            return True
        except Exception as e:
            print(f"CSV Connection Error: {e}")
            return False
+5
View File
@@ -4,6 +4,7 @@ import functools
from app.connectors.postgres import PostgresConnector
from app.connectors.clickhouse import ClickHouseConnector
from app.connectors.parquet import ParquetConnector
from app.connectors.csv import CSVConnector
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
@@ -37,6 +38,10 @@ def _get_cached_connector(ds_type: str, config_json: str):
elif ds_type == "parquet":
file_path = str(resolve_upload_file_path(config.get("file_path")))
return ParquetConnector(file_path=file_path)
elif ds_type == "csv":
file_path = str(resolve_upload_file_path(config.get("file_path")))
return CSVConnector(file_path=file_path)
else:
raise ValueError(f"Unsupported data source type: {ds_type}")
+2 -2
View File
@@ -31,7 +31,7 @@ class ParquetConnector:
finally:
conn.close()
def get_schema(self) -> Dict[str, List[str]]:
def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
conn = duckdb.connect(":memory:")
table_name = os.path.splitext(os.path.basename(self.file_path))[0]
conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')")
@@ -39,7 +39,7 @@ class ParquetConnector:
try:
# Get columns
columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
schema = {table_name: [f"{col[0]} ({col[1]})" for col in columns]}
schema = {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
return schema
except Exception as e:
print(f"Error getting schema: {e}")
+19 -1
View File
@@ -22,6 +22,9 @@ class PostgresConnector:
return [dict(row._mapping) for row in result]
def get_schema(self):
if self.engine.dialect.name == "sqlite":
return self._get_sqlite_schema()
query = """
SELECT table_name, column_name, data_type
FROM information_schema.columns
@@ -35,12 +38,27 @@ class PostgresConnector:
table = row['table_name']
if table not in schema:
schema[table] = []
schema[table].append(f"{row['column_name']} ({row['data_type']})")
schema[table].append({"name": row['column_name'], "type": row['data_type']})
return schema
except Exception as e:
print(f"Error getting schema: {e}")
return {}
def _get_sqlite_schema(self):
try:
from sqlalchemy import inspect
inspector = inspect(self.engine)
schema = {}
for table_name in inspector.get_table_names():
columns = []
for col in inspector.get_columns(table_name):
columns.append({"name": col['name'], "type": str(col['type'])})
schema[table_name] = columns
return schema
except Exception as e:
print(f"Error getting SQLite schema: {e}")
return {}
def test_connection(self) -> bool:
try:
with self.engine.connect() as connection:
+115
View File
@@ -0,0 +1,115 @@
from typing import List, Optional, Dict, Any, Union, Literal

from pydantic import BaseModel, ConfigDict, Field
# Common Types
# Comparison operators accepted by ColumnAccessControl.operator.
AccessControlOperator = Literal[
    "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN",
    "GREATER_THAN_OR_EQUALS", "LESS_THAN_OR_EQUALS"
]
# Cardinality values accepted by Relationship.joinType.
JoinType = Literal["ONE_TO_ONE", "ONE_TO_MANY", "MANY_TO_ONE", "MANY_TO_MANY"]
# Column Definitions
class SessionProperty(BaseModel):
    """A named session property listed in an access-control rule's requiredProperties."""

    name: str
    required: bool
    # Optional default expression; presumably used when the property is absent — TODO confirm.
    defaultExpr: Optional[str] = None
class AccessControlThreshold(BaseModel):
    """Threshold attached to a column access-control rule."""

    # The threshold is always carried as a string; dataType says how to read it.
    value: str
    dataType: Literal["NUMERIC", "STRING"]
class ColumnAccessControl(BaseModel):
    """Column-level access-control rule: an operator, required session properties and an optional threshold."""

    name: str
    operator: AccessControlOperator
    requiredProperties: List[SessionProperty]
    threshold: Optional[AccessControlThreshold] = None
class Column(BaseModel):
    """A column of a model (also used for metric dimensions and measures)."""

    name: str
    # Data type as a free-form string.
    type: str
    # Optional relationship reference; presumably a Relationship name — TODO confirm.
    relationship: Optional[str] = None
    isCalculated: bool = False
    notNull: bool = False
    # Optional expression string; presumably the definition of a calculated column — TODO confirm.
    expression: Optional[str] = None
    isHidden: bool = False
    columnLevelAccessControl: Optional[ColumnAccessControl] = None
    # Free-form string metadata; each instance gets its own dict.
    properties: Dict[str, str] = Field(default_factory=dict)
# Model Definitions
class TableReference(BaseModel):
    """Location of the physical table behind a model.

    Only ``table`` is required; ``catalog`` and ``schema`` are optional.
    """

    # Accept both the alias ("schema") and the field name ("schema_") as input,
    # matching MDLManifest; without this, a `schema_=` keyword is silently ignored.
    model_config = ConfigDict(populate_by_name=True)

    catalog: Optional[str] = None
    # Exposed as "schema" through the alias; the attribute itself is `schema_`.
    schema_: Optional[str] = Field(None, alias="schema")
    table: str
class RowLevelAccessControl(BaseModel):
    """Row-level access-control rule: required session properties plus a condition string."""

    name: str
    requiredProperties: List[SessionProperty]
    # Condition kept as an opaque string; it is not parsed or validated here.
    condition: str
class Model(BaseModel):
    """A model entry of the MDL manifest."""

    name: str
    # Optional source definitions; this class does not require or cross-check any of them.
    tableReference: Optional[TableReference] = None
    refSql: Optional[str] = None
    baseObject: Optional[str] = None
    columns: List[Column] = Field(default_factory=list)
    # Optional primary key; presumably a column name — TODO confirm.
    primaryKey: Optional[str] = None
    cached: bool = False
    refreshTime: Optional[str] = None
    rowLevelAccessControls: List[RowLevelAccessControl] = Field(default_factory=list)
    # Free-form metadata; each instance gets its own dict.
    properties: Dict[str, Any] = Field(default_factory=dict)
# Relationship Definitions
class Relationship(BaseModel):
    """A named join between models, described by a join type and a condition string."""

    name: str
    # NOTE(review): the two-item bound below is stated only in this comment;
    # the List[str] annotation does not enforce it.
    models: List[str]  # minItems: 2, maxItems: 2
    joinType: JoinType
    # Join condition kept as an opaque string; it is not parsed or validated here.
    condition: str
    properties: Dict[str, Any] = Field(default_factory=dict)
# Metric Definitions
class MetricTimeGrain(BaseModel):
    """Time-grain entry of a metric: a name, a referenced column and a list of date parts."""

    name: str
    # Presumably the name of the column the grain is derived from — TODO confirm.
    refColumn: str
    dateParts: List[str]
class Metric(BaseModel):
    """A metric entry of the MDL manifest, built on a named base object."""

    name: str
    # Required name of the object the metric is based on.
    baseObject: str
    # Dimensions and measures reuse the Column definition.
    dimension: List[Column] = Field(default_factory=list)
    measure: List[Column] = Field(default_factory=list)
    timeGrain: List[MetricTimeGrain] = Field(default_factory=list)
    cached: bool = False
    refreshTime: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
# View Definitions
class View(BaseModel):
    """A named view defined by a statement string."""

    name: str
    # Statement kept as an opaque string; it is not parsed or validated here.
    statement: str
    properties: Dict[str, Any] = Field(default_factory=dict)
# Enum Definitions
class EnumValue(BaseModel):
    """One member of an enum definition: a name with an optional string value."""

    name: str
    value: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
class EnumDefinition(BaseModel):
    """A named enum made of EnumValue members."""

    name: str
    # Required list; unlike the other collections in this module it has no default.
    values: List[EnumValue]
    properties: Dict[str, Any] = Field(default_factory=dict)
# Main Manifest
class MDLManifest(BaseModel):
    """Top-level MDL document: catalog/schema identification plus all definitions.

    ``catalog`` and ``schema`` are required; every collection defaults to an
    empty list. The schema is read and written under the key ``"schema"`` (use
    ``by_alias=True`` when dumping) and may also be supplied as ``schema_``.
    """

    # Pydantic v2 configuration (replaces the deprecated inner `class Config`):
    # allow populating aliased fields by their field name as well as the alias.
    model_config = ConfigDict(populate_by_name=True)

    catalog: str
    # Named `schema_` because a field called `schema` would shadow the
    # `BaseModel.schema` attribute; the alias keeps the external key "schema".
    schema_: str = Field(..., alias="schema")
    dataSource: Optional[str] = None
    models: List[Model] = Field(default_factory=list)
    relationships: List[Relationship] = Field(default_factory=list)
    metrics: List[Metric] = Field(default_factory=list)
    views: List[View] = Field(default_factory=list)
    enumDefinitions: List[EnumDefinition] = Field(default_factory=list)
+122
View File
@@ -0,0 +1,122 @@
import json
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from app.models.datasource import DataSource
from app.schemas.mdl import MDLManifest, Model, Column, TableReference
from app.connectors.factory import get_connector
from app.database import SessionLocal
# Assuming running from backend/ directory
MDL_STORAGE_PATH = Path("data/mdl")
class MDLService:
    """Generate, load and persist MDL manifests, stored as one JSON file per data source."""

    @staticmethod
    def _get_mdl_path(datasource_id: int) -> Path:
        """Return the JSON path for a data source, creating the storage directory if needed."""
        MDL_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
        return MDL_STORAGE_PATH / f"{datasource_id}.json"

    @staticmethod
    def get_raw_schema(datasource: DataSource) -> Dict[str, List[Dict[str, str]]]:
        """Return the connector's schema for ``datasource``.

        Errors raised by ``get_schema`` are logged and turned into an empty
        dict; errors from building the connector propagate.
        """
        connector = get_connector(datasource)
        try:
            return connector.get_schema()
        except Exception as e:
            print(f"Error fetching schema for DS {datasource.id}: {e}")
            return {}

    @staticmethod
    def _parse_column_info(col_info: Any) -> tuple:
        """Return ``(name, type)`` for one schema entry.

        Accepts the dict format ``{"name": ..., "type": ...}`` and, as a
        fallback, the old string format ``"name (type)"``. Anything that does
        not carry a type gets the type ``"UNKNOWN"``.
        """
        if isinstance(col_info, dict):
            return col_info.get("name", "UNKNOWN"), col_info.get("type", "UNKNOWN")
        if isinstance(col_info, str):
            # Split on the last " (" so names that themselves contain " (" survive.
            name, sep, rest = col_info.rpartition(" (")
            if sep and col_info.endswith(")"):
                return name, rest[:-1]
            return col_info, "UNKNOWN"
        return str(col_info), "UNKNOWN"

    @staticmethod
    def generate_default_mdl(
        datasource: DataSource,
        selected_tables: Optional[List[str]] = None,
        selected_columns: Optional[Dict[str, List[str]]] = None,
    ) -> MDLManifest:
        """Build a default manifest with one model per table of the raw schema.

        Args:
            datasource: Data source whose schema is read.
            selected_tables: When given, only these tables become models.
            selected_columns: When given, a non-empty list for a table keeps
                only those columns; a missing or empty list keeps all of them.

        Returns:
            A manifest with catalog ``"default"`` and schema ``"public"``.
            Tables left with no columns are omitted.
        """
        raw_schema = MDLService.get_raw_schema(datasource)
        # Build the lookup sets once instead of scanning lists per column.
        table_filter = set(selected_tables) if selected_tables is not None else None

        models = []
        for table_name, columns in raw_schema.items():
            if table_filter is not None and table_name not in table_filter:
                continue

            allowed = set()
            if selected_columns is not None:
                allowed = set(selected_columns.get(table_name) or ())

            model_cols = []
            for col_info in columns:
                name, type_ = MDLService._parse_column_info(col_info)
                # An empty `allowed` set means "no restriction for this table".
                if allowed and name not in allowed:
                    continue
                model_cols.append(Column(name=name, type=type_))

            if not model_cols:
                continue
            models.append(Model(
                name=table_name,
                tableReference=TableReference(table=table_name),
                columns=model_cols
            ))

        return MDLManifest(
            catalog="default",
            schema="public",  # Default schema, might need adjustment based on datasource config
            dataSource=datasource.type.upper(),
            models=models
        )

    @staticmethod
    def get_mdl(datasource_id: int) -> Optional[MDLManifest]:
        """Load the stored manifest, or return None when it is missing or unreadable."""
        path = MDLService._get_mdl_path(datasource_id)
        if not path.exists():
            return None
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Pydantic v2 compatible
            return MDLManifest.model_validate(data)
        except Exception as e:
            print(f"Error loading MDL for {datasource_id}: {e}")
            return None

    @staticmethod
    def save_mdl(datasource_id: int, mdl: MDLManifest):
        """Write the manifest as JSON (aliased keys) for the given data source.

        The content goes to a temporary sibling file that then replaces the
        target, so a reader never sees a half-written file. That matters
        because ``get_mdl`` treats an unparsable file as missing and
        ``get_or_create_mdl`` would then overwrite it with a generated default.
        """
        path = MDLService._get_mdl_path(datasource_id)
        payload = mdl.model_dump_json(indent=2, by_alias=True)
        tmp_path = path.with_name(f"{path.name}.tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        tmp_path.replace(path)

    @staticmethod
    def get_or_create_mdl(datasource_id: int) -> MDLManifest:
        """Return the stored manifest, generating and saving a default one if none loads.

        Raises:
            ValueError: If a manifest must be generated and the data source
                does not exist.
        """
        mdl = MDLService.get_mdl(datasource_id)
        if mdl:
            return mdl

        # Generate new
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
            if not ds:
                raise ValueError(f"DataSource {datasource_id} not found")
            mdl = MDLService.generate_default_mdl(ds)
            MDLService.save_mdl(datasource_id, mdl)
            return mdl
        finally:
            db.close()