feat: add modelling layer

This commit is contained in:
qixinbo
2026-03-16 22:18:23 +08:00
parent a1a855a126
commit 720c30a893
16 changed files with 1115 additions and 106 deletions
+37 -4
View File
@@ -24,6 +24,7 @@ from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -116,11 +117,11 @@ def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
return pd.read_parquet(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]:
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}
schema = {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]}
conn.close()
return schema
@@ -167,7 +168,7 @@ def _build_schema_cache_key(source: str, connector: Any) -> str:
)
return source
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]:
key = _build_schema_cache_key(source, connector)
now = time.time()
with _cache_lock:
@@ -176,7 +177,7 @@ def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[s
return cached["schema"]
return None
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None:
key = _build_schema_cache_key(source, connector)
with _cache_lock:
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
@@ -247,6 +248,37 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema_str = json.dumps(schema, indent=2)
# Try to load MDL context
mdl_context = ""
if request.source.startswith("ds:"):
try:
ds_id = int(request.source.split(":")[1])
mdl = MDLService.get_mdl(ds_id)
if mdl:
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
mdl_lines.append("MODELS:")
for model in mdl.models:
table_ref = model.tableReference.table if model.tableReference else model.name
desc = f" - Description: {model.properties.get('description', '')}" if model.properties.get('description') else ""
mdl_lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}")
if model.columns:
mdl_lines.append(" Columns:")
for col in model.columns:
col_desc = f" ({col.properties.get('description')})" if col.properties.get('description') else ""
expr = f" [Calculated: {col.expression}]" if col.isCalculated else ""
mdl_lines.append(f" - {col.name} ({col.type}){col_desc}{expr}")
if mdl.relationships:
mdl_lines.append("\nRELATIONSHIPS:")
for rel in mdl.relationships:
mdl_lines.append(f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}")
mdl_context = "\n".join(mdl_lines)
except Exception as e:
print(f"Failed to load MDL: {e}")
# 2. Get the active LLM config
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
@@ -270,6 +302,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
user_prompt = f"""
### DATABASE SCHEMA ###
{schema_str}
{mdl_context}
### INPUTS ###
User's Question: {request.query}