feat: add modelling layer

This commit is contained in:
qixinbo
2026-03-16 22:18:23 +08:00
parent a1a855a126
commit 720c30a893
16 changed files with 1115 additions and 106 deletions
+37 -4
View File
@@ -24,6 +24,7 @@ from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -116,11 +117,11 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
return pd.read_parquet(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]:
    """Describe an uploaded DataFrame as a one-table schema mapping.

    The frame is registered in a throwaway in-memory DuckDB database under the
    name ``uploaded_file`` and the output of ``DESCRIBE`` supplies the first
    (name) and second (type) field of every column.

    Args:
        df: The uploaded data.

    Returns:
        ``{"uploaded_file": [{"name": ..., "type": ...}, ...]}``.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        # Close even when registration or DESCRIBE raises so the connection
        # is never leaked.
        conn.close()
    return {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]}
@@ -167,7 +168,7 @@ def _build_schema_cache_key(source: str, connector: Any) -> str:
)
return source
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]:
key = _build_schema_cache_key(source, connector)
now = time.time()
with _cache_lock:
@@ -176,7 +177,7 @@ def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[s
return cached["schema"]
return None
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None:
    """Store ``schema`` in the schema cache for ``SCHEMA_CACHE_TTL_SECONDS``.

    The cache key is derived from ``source`` and ``connector`` by
    ``_build_schema_cache_key``; the write happens under ``_cache_lock``.
    """
    cache_key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        expires_at = time.time() + SCHEMA_CACHE_TTL_SECONDS
        _schema_cache[cache_key] = {"schema": schema, "expires_at": expires_at}
@@ -247,6 +248,37 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema_str = json.dumps(schema, indent=2)
# Try to load MDL context
mdl_context = ""
if request.source.startswith("ds:"):
try:
ds_id = int(request.source.split(":")[1])
mdl = MDLService.get_mdl(ds_id)
if mdl:
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
mdl_lines.append("MODELS:")
for model in mdl.models:
table_ref = model.tableReference.table if model.tableReference else model.name
desc = f" - Description: {model.properties.get('description', '')}" if model.properties.get('description') else ""
mdl_lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}")
if model.columns:
mdl_lines.append(" Columns:")
for col in model.columns:
col_desc = f" ({col.properties.get('description')})" if col.properties.get('description') else ""
expr = f" [Calculated: {col.expression}]" if col.isCalculated else ""
mdl_lines.append(f" - {col.name} ({col.type}){col_desc}{expr}")
if mdl.relationships:
mdl_lines.append("\nRELATIONSHIPS:")
for rel in mdl.relationships:
mdl_lines.append(f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}")
mdl_context = "\n".join(mdl_lines)
except Exception as e:
print(f"Failed to load MDL: {e}")
# 2. Get the active LLM config
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
@@ -270,6 +302,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
user_prompt = f"""
### DATABASE SCHEMA ###
{schema_str}
{mdl_context}
### INPUTS ###
User's Question: {request.query}
+139
View File
@@ -0,0 +1,139 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, List, Optional
from pydantic import BaseModel
from app.database import get_db
from app.models.datasource import DataSource
from app.schemas.mdl import MDLManifest
from app.services.mdl import MDLService
from app.connectors.factory import get_connector
router = APIRouter(tags=["semantic"])
class GenerateMDLRequest(BaseModel):
    """Optional request body of the ``/semantic/{datasource_id}/generate`` endpoint."""

    # Forwarded as ``selected_tables`` to MDLService.generate_default_mdl.
    selected_tables: Optional[List[str]] = None
    # Forwarded as ``selected_columns`` to MDLService.generate_default_mdl;
    # keyed by table name.
    selected_columns: Optional[Dict[str, List[str]]] = None
class ModelDetailResponse(BaseModel):
    """Response body of the model-detail endpoint."""

    # Plain-dict view of the requested model (name, tableReference,
    # primaryKey, properties, columns).
    model: Dict[str, Any]
    # Relationships whose ``models`` list contains the requested model name.
    relationships: List[Dict[str, Any]]
    # Sample rows from the backing table; left empty when the preview query fails.
    preview_rows: List[Dict[str, Any]]
def _normalize_query_result(results: Any) -> List[Dict[str, Any]]:
if isinstance(results, list):
if results and isinstance(results[0], dict):
return results
if results and isinstance(results[0], (list, tuple)):
return [dict(enumerate(row)) for row in results]
return []
if isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
return [dict(zip(col_names, row)) for row in rows]
return []
@router.get("/semantic/{datasource_id}/schema", response_model=Dict[str, List[Dict[str, str]]])
def get_semantic_schema(datasource_id: int, db: Session = Depends(get_db)):
    """Return the raw table/column schema of a datasource.

    Responds with 404 when the datasource id is unknown and with 500 (carrying
    the exception text) when fetching the schema raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        raw_schema = MDLService.get_raw_schema(datasource)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return raw_schema
@router.get("/semantic/{datasource_id}", response_model=MDLManifest)
def get_semantic_model(datasource_id: int, db: Session = Depends(get_db)):
    """Return the stored MDL manifest of a datasource, generating one if none exists.

    Responds with 404 when the datasource id is unknown and with 500 (carrying
    the exception text) when loading or generating the manifest raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        return MDLService.get_or_create_mdl(datasource_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@router.put("/semantic/{datasource_id}", response_model=MDLManifest)
def update_semantic_model(datasource_id: int, mdl: MDLManifest, db: Session = Depends(get_db)):
    """Persist the submitted MDL manifest for a datasource and echo it back.

    Responds with 404 when the datasource id is unknown and with 500 (carrying
    the exception text) when saving raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        MDLService.save_mdl(datasource_id, mdl)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return mdl
@router.post("/semantic/{datasource_id}/generate", response_model=MDLManifest)
def regenerate_semantic_model(datasource_id: int, request: Optional[GenerateMDLRequest] = None, db: Session = Depends(get_db)):
    """Regenerate the default MDL manifest of a datasource and save it.

    The optional body restricts generation to the given tables and columns;
    without a body both filters are ``None``. Responds with 404 when the
    datasource id is unknown and with 500 (carrying the exception text) when
    generation or saving raises.
    """
    datasource = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not datasource:
        raise HTTPException(status_code=404, detail="DataSource not found")

    try:
        if request:
            tables = request.selected_tables
            columns = request.selected_columns
        else:
            tables = None
            columns = None
        generated = MDLService.generate_default_mdl(
            datasource,
            selected_tables=tables,
            selected_columns=columns,
        )
        MDLService.save_mdl(datasource_id, generated)
        return generated
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/semantic/{datasource_id}/models/{model_name}", response_model=ModelDetailResponse)
def get_model_detail(datasource_id: int, model_name: str, limit: int = 10, db: Session = Depends(get_db)):
    """Return one model of a datasource's MDL with its relationships and sample rows.

    Args:
        datasource_id: Id of the datasource whose manifest is inspected.
        model_name: Name of the model inside the manifest.
        limit: Requested number of preview rows; clamped to the range 1..100.
        db: Database session used to look up the datasource.

    Returns:
        A ``ModelDetailResponse``. ``preview_rows`` is empty when the preview
        query cannot be executed; that failure is reported on stdout only.

    Raises:
        HTTPException: 404 when the datasource or the model does not exist.
    """
    ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
    if not ds:
        raise HTTPException(status_code=404, detail="DataSource not found")

    mdl = MDLService.get_or_create_mdl(datasource_id)
    model = next((m for m in mdl.models if m.name == model_name), None)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    relationships = [
        {
            "name": rel.name,
            "models": rel.models,
            "joinType": rel.joinType,
            "condition": rel.condition,
            "properties": rel.properties,
        }
        for rel in mdl.relationships
        if model_name in rel.models
    ]

    preview_rows: List[Dict[str, Any]] = []
    try:
        connector = get_connector(ds)
        table_name = model.tableReference.table if model.tableReference else model.name
        # The table name comes from a manifest that clients can overwrite via
        # PUT, so embedded double quotes are doubled to keep the value inside
        # the quoted identifier.
        quoted_table = table_name.replace('"', '""')
        row_limit = max(1, min(limit, 100))
        query = f'SELECT * FROM "{quoted_table}" LIMIT {row_limit}'
        raw = connector.execute_query(query)
        preview_rows = _normalize_query_result(raw)
    except Exception as exc:
        # The preview is best-effort: report the failure but still return the
        # model metadata.
        print(f"Failed to load preview rows for model {model_name}: {exc}")
        preview_rows = []

    model_payload = {
        "name": model.name,
        "tableReference": model.tableReference.model_dump(by_alias=True) if model.tableReference else None,
        "primaryKey": model.primaryKey,
        "properties": model.properties,
        "columns": [c.model_dump(by_alias=True) for c in model.columns],
    }
    return ModelDetailResponse(
        model=model_payload,
        relationships=relationships,
        preview_rows=preview_rows,
    )
+1 -1
View File
@@ -33,7 +33,7 @@ class ClickHouseConnector:
table = row[0]
if table not in schema:
schema[table] = []
schema[table].append(f"{row[1]} ({row[2]})")
schema[table].append({"name": row[1], "type": row[2]})
return schema
except Exception as e:
print(f"Error getting schema: {e}")
+68
View File
@@ -0,0 +1,68 @@
import duckdb
import pandas as pd
from typing import List, Dict, Any
import os
from app.core.files import resolve_upload_file_path
class CSVConnector:
    """Query a single CSV file through an in-memory DuckDB view.

    Every operation opens a fresh in-memory DuckDB connection, exposes the
    file as a view named after the file (see ``_get_table_name``) and closes
    the connection when done.
    """

    def __init__(self, file_path: str):
        """Remember ``file_path``.

        Raises:
            FileNotFoundError: If nothing exists at ``file_path``.
        """
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError(f"CSV file not found: {self.file_path}")

    @staticmethod
    def _sql_string_literal(value: str) -> str:
        """Return ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so a path such as ``o'brien.csv``
        cannot terminate the literal early.
        """
        return "'" + value.replace("'", "''") + "'"

    def _source_sql(self) -> str:
        """Return the ``read_csv_auto(...)`` expression for this file."""
        return f"read_csv_auto({self._sql_string_literal(self.file_path)})"

    def _create_view(self, conn: Any, table_name: str) -> None:
        """Expose the CSV file on ``conn`` as a view called ``table_name``."""
        conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM {self._source_sql()}")

    def _get_table_name(self) -> str:
        """Derive the view name from the file name.

        The extension is dropped, every non-alphanumeric character becomes an
        underscore, and a ``t_`` prefix is added when the result would start
        with a digit.
        """
        base = os.path.splitext(os.path.basename(self.file_path))[0]
        safe_name = "".join(c if c.isalnum() else "_" for c in base)
        if safe_name and safe_name[0].isdigit():
            safe_name = f"t_{safe_name}"
        return safe_name

    def execute_query(self, query: str) -> List[Dict[str, Any]]:
        """Run ``query`` against the CSV view and return the rows as dicts.

        The query is expected to reference the table name reported by
        ``get_schema``. Errors are printed and re-raised unchanged.
        """
        conn = duckdb.connect(":memory:")
        try:
            self._create_view(conn, self._get_table_name())
            df = conn.execute(query).df()
            return df.to_dict(orient="records")
        except Exception as e:
            print(f"CSV Query Error: {e}")
            raise
        finally:
            conn.close()

    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
        """Return ``{table_name: [{"name": ..., "type": ...}, ...]}`` for the file.

        Returns an empty dict (after printing the error) when the file cannot
        be described.
        """
        conn = duckdb.connect(":memory:")
        table_name = self._get_table_name()
        try:
            self._create_view(conn, table_name)
            # DESCRIBE rows: field 0 is the column name, field 1 its type.
            columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
            return {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
        except Exception as e:
            print(f"Error getting schema: {e}")
            return {}
        finally:
            conn.close()

    def test_connection(self) -> bool:
        """Return True when DuckDB can read at least one row from the file."""
        try:
            conn = duckdb.connect(":memory:")
            try:
                conn.execute(f"SELECT * FROM {self._source_sql()} LIMIT 1")
            finally:
                # Close on failure as well, so a bad file does not leak the connection.
                conn.close()
            return True
        except Exception as e:
            print(f"CSV Connection Error: {e}")
            return False
+5
View File
@@ -4,6 +4,7 @@ import functools
from app.connectors.postgres import PostgresConnector
from app.connectors.clickhouse import ClickHouseConnector
from app.connectors.parquet import ParquetConnector
from app.connectors.csv import CSVConnector
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
@@ -37,6 +38,10 @@ def _get_cached_connector(ds_type: str, config_json: str):
elif ds_type == "parquet":
file_path = str(resolve_upload_file_path(config.get("file_path")))
return ParquetConnector(file_path=file_path)
elif ds_type == "csv":
file_path = str(resolve_upload_file_path(config.get("file_path")))
return CSVConnector(file_path=file_path)
else:
raise ValueError(f"Unsupported data source type: {ds_type}")
+2 -2
View File
@@ -31,7 +31,7 @@ class ParquetConnector:
finally:
conn.close()
def get_schema(self) -> Dict[str, List[str]]:
def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
conn = duckdb.connect(":memory:")
table_name = os.path.splitext(os.path.basename(self.file_path))[0]
conn.execute(f"CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM read_parquet('{self.file_path}')")
@@ -39,7 +39,7 @@ class ParquetConnector:
try:
# Get columns
columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
schema = {table_name: [f"{col[0]} ({col[1]})" for col in columns]}
schema = {table_name: [{"name": col[0], "type": col[1]} for col in columns]}
return schema
except Exception as e:
print(f"Error getting schema: {e}")
+19 -1
View File
@@ -22,6 +22,9 @@ class PostgresConnector:
return [dict(row._mapping) for row in result]
def get_schema(self):
if self.engine.dialect.name == "sqlite":
return self._get_sqlite_schema()
query = """
SELECT table_name, column_name, data_type
FROM information_schema.columns
@@ -35,12 +38,27 @@ class PostgresConnector:
table = row['table_name']
if table not in schema:
schema[table] = []
schema[table].append(f"{row['column_name']} ({row['data_type']})")
schema[table].append({"name": row['column_name'], "type": row['data_type']})
return schema
except Exception as e:
print(f"Error getting schema: {e}")
return {}
def _get_sqlite_schema(self):
try:
from sqlalchemy import inspect
inspector = inspect(self.engine)
schema = {}
for table_name in inspector.get_table_names():
columns = []
for col in inspector.get_columns(table_name):
columns.append({"name": col['name'], "type": str(col['type'])})
schema[table_name] = columns
return schema
except Exception as e:
print(f"Error getting SQLite schema: {e}")
return {}
def test_connection(self) -> bool:
try:
with self.engine.connect() as connection:
+115
View File
@@ -0,0 +1,115 @@
from typing import List, Optional, Dict, Any, Union, Literal
from pydantic import BaseModel, Field
# Common Types
AccessControlOperator = Literal[
"EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN",
"GREATER_THAN_OR_EQUALS", "LESS_THAN_OR_EQUALS"
]
JoinType = Literal["ONE_TO_ONE", "ONE_TO_MANY", "MANY_TO_ONE", "MANY_TO_MANY"]
# Column Definitions
class SessionProperty(BaseModel):
    """A named session property referenced by access-control rules."""

    name: str
    required: bool
    # Optional default expression. NOTE(review): presumably applied when the
    # property is not supplied -- confirm against the MDL specification.
    defaultExpr: Optional[str] = None
class AccessControlThreshold(BaseModel):
    """Threshold value used by a column-level access-control rule."""

    # The value is always carried as a string; ``dataType`` states whether it
    # is NUMERIC or STRING.
    value: str
    dataType: Literal["NUMERIC", "STRING"]
class ColumnAccessControl(BaseModel):
    """Column-level access-control rule attached to a ``Column``."""

    name: str
    # One of the comparison operators listed in ``AccessControlOperator``.
    operator: AccessControlOperator
    requiredProperties: List[SessionProperty]
    threshold: Optional[AccessControlThreshold] = None
class Column(BaseModel):
    """A column entry; used by ``Model.columns`` and by ``Metric.dimension`` / ``Metric.measure``."""

    name: str
    type: str
    # Optional relationship name. NOTE(review): presumably refers to a
    # ``Relationship.name`` in the same manifest -- confirm.
    relationship: Optional[str] = None
    isCalculated: bool = False
    notNull: bool = False
    # Expression text; the NL2SQL prompt builder prints it for columns with
    # ``isCalculated`` set.
    expression: Optional[str] = None
    isHidden: bool = False
    columnLevelAccessControl: Optional[ColumnAccessControl] = None
    # Free-form string metadata; the prompt builder reads the "description" key.
    properties: Dict[str, str] = Field(default_factory=dict)
# Model Definitions
class TableReference(BaseModel):
    """Location of the physical table backing a model."""

    catalog: Optional[str] = None
    # Stored as ``schema_`` and exposed under the alias "schema".
    schema_: Optional[str] = Field(None, alias="schema")
    table: str
class RowLevelAccessControl(BaseModel):
    """Row-level access-control rule attached to a ``Model``."""

    name: str
    requiredProperties: List[SessionProperty]
    # Condition text. NOTE(review): presumably a SQL predicate -- confirm
    # against the MDL specification.
    condition: str
class Model(BaseModel):
    """A model of the manifest: a named set of columns with an optional backing table."""

    name: str
    # When absent, callers in this code base fall back to ``name`` as the table name.
    tableReference: Optional[TableReference] = None
    refSql: Optional[str] = None
    baseObject: Optional[str] = None
    columns: List[Column] = Field(default_factory=list)
    primaryKey: Optional[str] = None
    cached: bool = False
    refreshTime: Optional[str] = None
    rowLevelAccessControls: List[RowLevelAccessControl] = Field(default_factory=list)
    # Free-form metadata; the prompt builder reads the "description" key.
    properties: Dict[str, Any] = Field(default_factory=dict)
# Relationship Definitions
class Relationship(BaseModel):
    """A named join between models."""

    name: str
    # Intended to hold exactly two model names; the length is NOT validated here.
    models: List[str] # minItems: 2, maxItems: 2
    joinType: JoinType
    # Join condition text, printed verbatim after "ON" in the NL2SQL prompt.
    condition: str
    properties: Dict[str, Any] = Field(default_factory=dict)
# Metric Definitions
class MetricTimeGrain(BaseModel):
    """Time-grain entry of a ``Metric``."""

    name: str
    # NOTE(review): presumably the name of the time column the grain is based
    # on -- confirm against the MDL specification.
    refColumn: str
    dateParts: List[str]
class Metric(BaseModel):
    """A metric definition made of dimension and measure columns over a base object."""

    name: str
    # NOTE(review): presumably the name of the model or metric this metric is
    # built on -- confirm.
    baseObject: str
    dimension: List[Column] = Field(default_factory=list)
    measure: List[Column] = Field(default_factory=list)
    timeGrain: List[MetricTimeGrain] = Field(default_factory=list)
    cached: bool = False
    refreshTime: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
# View Definitions
class View(BaseModel):
    """A named view defined by a statement."""

    name: str
    # NOTE(review): presumably the SQL text of the view -- confirm.
    statement: str
    properties: Dict[str, Any] = Field(default_factory=dict)
# Enum Definitions
class EnumValue(BaseModel):
    """One member of an ``EnumDefinition``."""

    name: str
    # Optional value associated with the member name.
    value: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)
class EnumDefinition(BaseModel):
    """A named enumeration and its members."""

    name: str
    values: List[EnumValue]
    properties: Dict[str, Any] = Field(default_factory=dict)
# Main Manifest
class MDLManifest(BaseModel):
    """Top-level MDL document: catalog/schema plus models, relationships, metrics, views and enums."""

    catalog: str
    schema_: str = Field(..., alias="schema")  # stored as `schema_`, exposed as "schema" (presumably to avoid clashing with pydantic's own `BaseModel.schema`)
    dataSource: Optional[str] = None
    models: List[Model] = Field(default_factory=list)
    relationships: List[Relationship] = Field(default_factory=list)
    metrics: List[Metric] = Field(default_factory=list)
    views: List[View] = Field(default_factory=list)
    enumDefinitions: List[EnumDefinition] = Field(default_factory=list)

    class Config:
        # Accept the field name (`schema_`) as well as its alias (`schema`)
        # when constructing or validating a manifest.
        populate_by_name = True
+122
View File
@@ -0,0 +1,122 @@
import json
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
from app.models.datasource import DataSource
from app.schemas.mdl import MDLManifest, Model, Column, TableReference
from app.connectors.factory import get_connector
from app.database import SessionLocal
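# Relative path: it is resolved against the process's current working directory,
# so the server is assumed to be started from the backend/ directory.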
MDL_STORAGE_PATH = Path("data/mdl")
class MDLService:
    """Load, generate and persist MDL manifests, one JSON file per datasource.

    Manifests are stored under ``MDL_STORAGE_PATH`` as ``<datasource_id>.json``.
    """

    @staticmethod
    def _get_mdl_path(datasource_id: int) -> Path:
        """Return the manifest path for a datasource, creating the storage directory if needed."""
        MDL_STORAGE_PATH.mkdir(parents=True, exist_ok=True)
        return MDL_STORAGE_PATH / f"{datasource_id}.json"

    @staticmethod
    def get_raw_schema(datasource: DataSource) -> Dict[str, List[Dict[str, str]]]:
        """Return the connector's schema for ``datasource``.

        A failure inside ``get_schema`` is printed and turned into an empty
        dict; a failure to build the connector itself propagates.
        """
        connector = get_connector(datasource)
        try:
            return connector.get_schema()
        except Exception as e:
            print(f"Error fetching schema for DS {datasource.id}: {e}")
            return {}

    @staticmethod
    def _column_from_info(col_info: Any) -> Column:
        """Convert one raw schema entry into a ``Column``.

        Accepts the current ``{"name": ..., "type": ...}`` dict format and the
        legacy ``"name (type)"`` string format; anything that cannot be
        parsed gets the type ``"UNKNOWN"``.
        """
        if isinstance(col_info, dict):
            return Column(
                name=col_info.get("name", "UNKNOWN"),
                type=col_info.get("type", "UNKNOWN"),
            )
        if isinstance(col_info, str):
            if col_info.endswith(")"):
                # Split on the LAST " (" so names containing parentheses survive.
                name, separator, type_part = col_info.rpartition(" (")
                if separator:
                    return Column(name=name, type=type_part[:-1])
            return Column(name=col_info, type="UNKNOWN")
        return Column(name=str(col_info), type="UNKNOWN")

    @staticmethod
    def generate_default_mdl(
        datasource: DataSource,
        selected_tables: Optional[List[str]] = None,
        selected_columns: Optional[Dict[str, List[str]]] = None,
    ) -> MDLManifest:
        """Build a manifest with one model per table of the datasource's raw schema.

        Args:
            datasource: Datasource whose schema is read.
            selected_tables: When given, only these tables become models.
            selected_columns: Optional per-table column whitelist. A table
                that is missing from the mapping, or mapped to an empty list,
                keeps all of its columns.

        Returns:
            A manifest with catalog ``"default"`` and schema ``"public"``.
            Tables left without any column are omitted.
        """
        raw_schema = MDLService.get_raw_schema(datasource)

        models = []
        for table_name, columns in raw_schema.items():
            if selected_tables is not None and table_name not in selected_tables:
                continue

            # The whitelist only depends on the table, so look it up once.
            allowed = selected_columns.get(table_name, []) if selected_columns is not None else []

            model_cols = []
            for col_info in columns:
                column = MDLService._column_from_info(col_info)
                if allowed and column.name not in allowed:
                    continue
                model_cols.append(column)

            if not model_cols:
                continue

            models.append(Model(
                name=table_name,
                tableReference=TableReference(table=table_name),
                columns=model_cols,
            ))

        return MDLManifest(
            catalog="default",
            schema="public",  # Default schema, might need adjustment based on datasource config
            dataSource=datasource.type.upper(),
            models=models,
        )

    @staticmethod
    def get_mdl(datasource_id: int) -> Optional[MDLManifest]:
        """Load the stored manifest, or return ``None`` when it is missing or unreadable.

        Read and validation errors are printed, not raised.
        """
        path = MDLService._get_mdl_path(datasource_id)
        if not path.exists():
            return None
        try:
            # Explicit UTF-8 so non-ASCII descriptions load the same on every platform.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return MDLManifest.model_validate(data)
        except Exception as e:
            print(f"Error loading MDL for {datasource_id}: {e}")
            return None

    @staticmethod
    def save_mdl(datasource_id: int, mdl: MDLManifest):
        """Write ``mdl`` as indented, alias-keyed JSON to the datasource's manifest file.

        The JSON is written to a sibling ``.tmp`` file first and then moved
        into place, so a failed write never leaves a truncated manifest
        behind (``get_mdl`` would treat that as missing and the caller would
        regenerate it, discarding edits).
        """
        path = MDLService._get_mdl_path(datasource_id)
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(mdl.model_dump_json(indent=2, by_alias=True))
        os.replace(tmp_path, path)

    @staticmethod
    def get_or_create_mdl(datasource_id: int) -> MDLManifest:
        """Return the stored manifest, generating and saving a default one when none can be loaded.

        Raises:
            ValueError: If no datasource with ``datasource_id`` exists.
        """
        mdl = MDLService.get_mdl(datasource_id)
        if mdl is not None:
            return mdl

        # Nothing stored (or it could not be loaded): generate a fresh manifest.
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == datasource_id).first()
            if not ds:
                raise ValueError(f"DataSource {datasource_id} not found")

            mdl = MDLService.generate_default_mdl(ds)
            MDLService.save_mdl(datasource_id, mdl)
            return mdl
        finally:
            db.close()
View File
Binary file not shown.
+2 -1
View File
@@ -7,7 +7,7 @@ import asyncio
import json
from datetime import datetime
from app.api import upload, llm, skills, users, datasources, projects
from app.api import upload, llm, skills, users, datasources, projects, semantic
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.core.nanobot import nanobot_service
@@ -38,6 +38,7 @@ app.include_router(skills.router, prefix="/api/v1")
app.include_router(users.router, prefix="/api/v1")
app.include_router(projects.router, prefix="/api/v1")
app.include_router(datasources.router, prefix="/api/v1")
app.include_router(semantic.router, prefix="/api/v1")
@app.on_event("startup")
async def startup_event():