feat: add langfuse
This commit is contained in:
+86
-42
@@ -30,6 +30,7 @@ from app.models.datasource import DataSource
|
||||
from app.core.files import resolve_upload_file_path
|
||||
from app.services.mdl import MDLService
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
from app.trace import trace_service
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -247,6 +248,12 @@ async def process_nl2sql(
|
||||
await on_progress(content)
|
||||
|
||||
total_started = time.perf_counter()
|
||||
trace_base_attributes = {
|
||||
"component": "nl2sql",
|
||||
"source": request.source,
|
||||
"session_id": request.session_id,
|
||||
"generate_chart": request.generate_chart,
|
||||
}
|
||||
# 1. Get the connector and schema
|
||||
connector = None
|
||||
schema = {}
|
||||
@@ -404,15 +411,25 @@ Language: Chinese (Simplified)
|
||||
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
with trace_service.start_span(
|
||||
"nl2sql.llm_generation",
|
||||
attributes={
|
||||
**trace_base_attributes,
|
||||
"exec_attempt": exec_attempt,
|
||||
"retry_attempt": attempt,
|
||||
"model": active_config.get("model"),
|
||||
},
|
||||
) as llm_span:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
llm_span.update(output={"finish_reason": getattr(response, "finish_reason", None)})
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
@@ -472,36 +489,42 @@ Language: Chinese (Simplified)
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
with trace_service.start_span(
|
||||
"nl2sql.sql_execution",
|
||||
attributes={
|
||||
**trace_base_attributes,
|
||||
"exec_attempt": exec_attempt,
|
||||
},
|
||||
input_payload={"sql": sql_query},
|
||||
) as sql_span:
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
formatted_results = []
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
formatted_results = []
|
||||
sql_span.set_attributes({"result_rows": len(formatted_results)})
|
||||
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
break # Execution succeeded, break the retry loop
|
||||
@@ -526,10 +549,21 @@ Language: Chinese (Simplified)
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
timeout_stage = "chart_generation"
|
||||
chart_response = await asyncio.wait_for(
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
with trace_service.start_span(
|
||||
"nl2sql.chart_generation",
|
||||
attributes=trace_base_attributes,
|
||||
input_payload={"query": request.query, "rows": len(formatted_results)},
|
||||
) as chart_span:
|
||||
chart_response = await asyncio.wait_for(
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
chart_span.set_attributes(
|
||||
{
|
||||
"chart.can_visualize": bool(getattr(chart_response, "can_visualize", False)),
|
||||
"chart.type": getattr(chart_response, "chart_type", ""),
|
||||
}
|
||||
)
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
except asyncio.TimeoutError:
|
||||
fallback_chart = ChartGenerationResponse(
|
||||
@@ -542,5 +576,15 @@ Language: Chinese (Simplified)
|
||||
except Exception as e:
|
||||
pass # Ignore chart generation errors, return data only
|
||||
|
||||
with trace_service.start_span(
|
||||
"nl2sql.completed",
|
||||
attributes={
|
||||
**trace_base_attributes,
|
||||
"total_seconds": round(time.perf_counter() - total_started, 4),
|
||||
"result_rows": len(formatted_results),
|
||||
"has_chart": bool(chart_response),
|
||||
},
|
||||
):
|
||||
pass
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
|
||||
Reference in New Issue
Block a user