feat: add Langfuse tracing

This commit is contained in:
qixinbo
2026-03-31 00:18:32 +08:00
parent ed0075c910
commit 01524aaff5
11 changed files with 1034 additions and 330 deletions
+86 -42
View File
@@ -30,6 +30,7 @@ from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
from app.services.llm_cache import get_active_llm_config
from app.trace import trace_service
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -247,6 +248,12 @@ async def process_nl2sql(
await on_progress(content)
total_started = time.perf_counter()
trace_base_attributes = {
"component": "nl2sql",
"source": request.source,
"session_id": request.session_id,
"generate_chart": request.generate_chart,
}
# 1. Get the connector and schema
connector = None
schema = {}
@@ -404,15 +411,25 @@ Language: Chinese (Simplified)
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
with trace_service.start_span(
"nl2sql.llm_generation",
attributes={
**trace_base_attributes,
"exec_attempt": exec_attempt,
"retry_attempt": attempt,
"model": active_config.get("model"),
},
) as llm_span:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
llm_span.update(output={"finish_reason": getattr(response, "finish_reason", None)})
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
@@ -472,36 +489,42 @@ Language: Chinese (Simplified)
timeout_stage = "sql_execution"
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
# Format results
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
with trace_service.start_span(
"nl2sql.sql_execution",
attributes={
**trace_base_attributes,
"exec_attempt": exec_attempt,
},
input_payload={"sql": sql_query},
) as sql_span:
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
formatted_results = []
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
else:
formatted_results = []
sql_span.set_attributes({"result_rows": len(formatted_results)})
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
break # Execution succeeded, break the retry loop
@@ -526,10 +549,21 @@ Language: Chinese (Simplified)
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
timeout_stage = "chart_generation"
chart_response = await asyncio.wait_for(
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
with trace_service.start_span(
"nl2sql.chart_generation",
attributes=trace_base_attributes,
input_payload={"query": request.query, "rows": len(formatted_results)},
) as chart_span:
chart_response = await asyncio.wait_for(
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
chart_span.set_attributes(
{
"chart.can_visualize": bool(getattr(chart_response, "can_visualize", False)),
"chart.type": getattr(chart_response, "chart_type", ""),
}
)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
except asyncio.TimeoutError:
fallback_chart = ChartGenerationResponse(
@@ -542,5 +576,15 @@ Language: Chinese (Simplified)
except Exception as e:
pass # Ignore chart generation errors, return data only
with trace_service.start_span(
"nl2sql.completed",
attributes={
**trace_base_attributes,
"total_seconds": round(time.perf_counter() - total_started, 4),
"result_rows": len(formatted_results),
"has_chart": bool(chart_response),
},
):
pass
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)