feat: add langfuse

This commit is contained in:
qixinbo
2026-03-31 00:18:32 +08:00
parent ed0075c910
commit 01524aaff5
11 changed files with 1034 additions and 330 deletions
+40 -30
View File
@@ -14,6 +14,7 @@ if str(NANOBOT_ROOT) not in sys.path:
from app.core.llm_provider import build_llm_provider
from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config
from app.trace import build_error_attributes, trace_service
CHART_MAX_TOKENS = 700
CHART_TEMPERATURE = 0.2
@@ -140,6 +141,10 @@ CHART_EXAMPLES = """
"""
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
trace_attributes = {
"component": "chart_generation",
"rows": len(data),
}
active_config = get_active_llm_config()
if not active_config:
@@ -218,37 +223,42 @@ Language: Chinese (Simplified)
# 4. Call LLM
try:
response = await provider.chat(
messages=messages,
max_tokens=CHART_MAX_TOKENS,
temperature=CHART_TEMPERATURE,
reasoning_effort=CHART_REASONING_EFFORT,
)
content = response.content
# Clean up code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
result = json.loads(content)
# Post-process to fix common LLM hallucinations (translating field names)
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
encoding = result["chart_spec"].get("encoding", {})
for channel, enc_def in encoding.items():
if isinstance(enc_def, dict) and "field" in enc_def:
field = enc_def["field"]
# If field is not in columns, try to find a match or let it be (Vega will render empty)
# But if we can detect it was translated, we might not be able to fix it perfectly.
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
# but it's safer to just rely on the stricter prompt.
return ChartGenerationResponse(**result)
with trace_service.start_span(
"chart.generate",
attributes={
**trace_attributes,
"model": active_config.get("model"),
},
input_payload={"query": query, "columns": columns},
) as span:
response = await provider.chat(
messages=messages,
max_tokens=CHART_MAX_TOKENS,
temperature=CHART_TEMPERATURE,
reasoning_effort=CHART_REASONING_EFFORT,
)
content = response.content
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
result = json.loads(content)
chart_result = ChartGenerationResponse(**result)
span.set_attributes(
{
"chart.can_visualize": bool(chart_result.can_visualize),
"chart.type": chart_result.chart_type,
}
)
span.update(output={"chart_type": chart_result.chart_type})
return chart_result
except Exception as e:
with trace_service.start_span(
"chart.generate.error",
attributes={**trace_attributes, **build_error_attributes(e, stage="chart_generation")},
):
pass
return ChartGenerationResponse(
reasoning=f"Failed to generate chart configuration: {str(e)}",
can_visualize=False,
+86 -42
View File
@@ -30,6 +30,7 @@ from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
from app.services.llm_cache import get_active_llm_config
from app.trace import trace_service
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -247,6 +248,12 @@ async def process_nl2sql(
await on_progress(content)
total_started = time.perf_counter()
trace_base_attributes = {
"component": "nl2sql",
"source": request.source,
"session_id": request.session_id,
"generate_chart": request.generate_chart,
}
# 1. Get the connector and schema
connector = None
schema = {}
@@ -404,15 +411,25 @@ Language: Chinese (Simplified)
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
with trace_service.start_span(
"nl2sql.llm_generation",
attributes={
**trace_base_attributes,
"exec_attempt": exec_attempt,
"retry_attempt": attempt,
"model": active_config.get("model"),
},
) as llm_span:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
llm_span.update(output={"finish_reason": getattr(response, "finish_reason", None)})
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
@@ -472,36 +489,42 @@ Language: Chinese (Simplified)
timeout_stage = "sql_execution"
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
# Format results
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
with trace_service.start_span(
"nl2sql.sql_execution",
attributes={
**trace_base_attributes,
"exec_attempt": exec_attempt,
},
input_payload={"sql": sql_query},
) as sql_span:
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
formatted_results = []
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
else:
formatted_results = []
sql_span.set_attributes({"result_rows": len(formatted_results)})
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
break # Execution succeeded, break the retry loop
@@ -526,10 +549,21 @@ Language: Chinese (Simplified)
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
timeout_stage = "chart_generation"
chart_response = await asyncio.wait_for(
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
with trace_service.start_span(
"nl2sql.chart_generation",
attributes=trace_base_attributes,
input_payload={"query": request.query, "rows": len(formatted_results)},
) as chart_span:
chart_response = await asyncio.wait_for(
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
chart_span.set_attributes(
{
"chart.can_visualize": bool(getattr(chart_response, "can_visualize", False)),
"chart.type": getattr(chart_response, "chart_type", ""),
}
)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
except asyncio.TimeoutError:
fallback_chart = ChartGenerationResponse(
@@ -542,5 +576,15 @@ Language: Chinese (Simplified)
except Exception as e:
pass # Ignore chart generation errors, return data only
with trace_service.start_span(
"nl2sql.completed",
attributes={
**trace_base_attributes,
"total_seconds": round(time.perf_counter() - total_started, 4),
"result_rows": len(formatted_results),
"has_chart": bool(chart_response),
},
):
pass
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
+87 -54
View File
@@ -36,6 +36,7 @@ from app.services.llm_cache import get_llm_configs, get_active_llm_config
from app.services.web_search_config_store import get_web_search_config
from app.core.data_root import get_workspace_root
from app.trace import build_error_attributes, build_usage_attributes, trace_service
class NanobotIntegration:
def __init__(self):
@@ -385,68 +386,100 @@ class NanobotIntegration:
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
):
if not self.agent:
self.initialize()
if not self._started:
await self.start()
if project_id is None:
from app.core.session_alias_store import session_alias_store
alias_meta = session_alias_store.get_alias_meta(session_id)
if alias_meta and alias_meta.get("project_id") is not None:
project_id = alias_meta.get("project_id")
agent_to_use = self.agent
need_custom_agent = False
target_config = None
span_attributes = {
"session_id": session_id,
"project_id": project_id,
"model_id": model_id,
"component": "nanobot.process_message",
}
with trace_service.start_span(
"nanobot.process_message",
attributes=span_attributes,
input_payload={"message": message},
) as root_span:
try:
if not self.agent:
self.initialize()
if not self._started:
await self.start()
selected_model_id = self._normalize_model_id(model_id)
if selected_model_id:
llm_configs = get_llm_configs()
target_config = next(
(item for item in llm_configs if self._normalize_model_id(item.get("id")) == selected_model_id),
None,
)
if project_id is None:
from app.core.session_alias_store import session_alias_store
if target_config is None:
active_config = get_active_llm_config()
if active_config and active_config.get("id"):
selected_model_id = self._normalize_model_id(active_config.get("id"))
target_config = active_config
alias_meta = session_alias_store.get_alias_meta(session_id)
if alias_meta and alias_meta.get("project_id") is not None:
project_id = alias_meta.get("project_id")
root_span.set_attributes({"project_id": project_id})
if target_config and self._need_custom_agent_for_target(target_config):
need_custom_agent = True
agent_to_use = self.agent
need_custom_agent = False
target_config = None
if project_id is not None:
need_custom_agent = True
selected_model_id = self._normalize_model_id(model_id)
if selected_model_id:
llm_configs = get_llm_configs()
target_config = next(
(item for item in llm_configs if self._normalize_model_id(item.get("id")) == selected_model_id),
None,
)
if need_custom_agent:
agent_to_use = await self._get_or_create_model_agent(selected_model_id, target_config, project_id)
if target_config is None:
active_config = get_active_llm_config()
if active_config and active_config.get("id"):
selected_model_id = self._normalize_model_id(active_config.get("id"))
target_config = active_config
full_message = message
# We no longer inject the full skill content into the user's message here,
# because the skill is already available to the agent via its workspace/tools.
# The routing instructions (System Prompt) injected in main.py are sufficient
# to guide the agent to use the selected skills.
if target_config and self._need_custom_agent_for_target(target_config):
need_custom_agent = True
if project_id is not None:
need_custom_agent = True
session = agent_to_use.sessions.get_or_create(session_id)
normalized_messages = self._normalize_session_messages(session.messages)
if len(normalized_messages) != len(session.messages):
session.messages = normalized_messages
agent_to_use.sessions.save(session)
with trace_service.start_span(
"nanobot.resolve_agent",
attributes={
"session_id": session_id,
"project_id": project_id,
"selected_model_id": selected_model_id,
"custom_agent": need_custom_agent,
},
):
if need_custom_agent:
agent_to_use = await self._get_or_create_model_agent(selected_model_id, target_config, project_id)
response = await agent_to_use.process_direct(
full_message,
session_key=session_id,
channel="api",
chat_id=session_id,
on_progress=on_progress,
on_stream=on_stream,
)
usage = self._normalize_usage(getattr(agent_to_use, "_last_usage", None))
if usage:
self._last_usage_by_session[session_id] = usage
return self._extract_response_text(response)
session = agent_to_use.sessions.get_or_create(session_id)
normalized_messages = self._normalize_session_messages(session.messages)
if len(normalized_messages) != len(session.messages):
session.messages = normalized_messages
agent_to_use.sessions.save(session)
with trace_service.start_span(
"nanobot.process_direct",
attributes={
"session_id": session_id,
"model": getattr(agent_to_use, "model", None),
},
) as direct_span:
response = await agent_to_use.process_direct(
message,
session_key=session_id,
channel="api",
chat_id=session_id,
on_progress=on_progress,
on_stream=on_stream,
)
usage = self._normalize_usage(getattr(agent_to_use, "_last_usage", None))
if usage:
self._last_usage_by_session[session_id] = usage
direct_span.set_attributes(build_usage_attributes(usage))
root_span.set_attributes(build_usage_attributes(usage))
text = self._extract_response_text(response)
direct_span.update(output={"content": text})
root_span.update(output={"content": text})
return text
except Exception as exc:
root_span.set_attributes(build_error_attributes(exc, stage="nanobot_process_message"))
root_span.record_error(exc, stage="nanobot_process_message")
raise
def _normalize_session_messages(self, messages: List[Any]) -> List[dict[str, Any]]:
normalized: List[dict[str, Any]] = []
+15
View File
@@ -0,0 +1,15 @@
from app.trace.attributes import (
build_chat_trace_attributes,
build_error_attributes,
build_usage_attributes,
sanitize_attributes,
)
from app.trace.service import trace_service
# Names exported by this package: the shared trace_service instance plus the
# attribute helper functions imported above.
__all__ = [
"trace_service",
"sanitize_attributes",
"build_chat_trace_attributes",
"build_usage_attributes",
"build_error_attributes",
]
+65
View File
@@ -0,0 +1,65 @@
from __future__ import annotations
from typing import Any, Dict, Mapping, Optional
def sanitize_attributes(attributes: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    """Return a flat copy of *attributes* that is safe to attach to a span.

    Entries whose value is ``None`` are dropped, as are entries whose key is
    empty after ``str()`` conversion and whitespace stripping. ``str``,
    ``int``, ``float`` and ``bool`` values are kept unchanged; any other
    value is replaced by its ``str()`` representation. A missing or empty
    mapping yields an empty dict.
    """
    primitive_types = (str, int, float, bool)
    cleaned: Dict[str, Any] = {}
    for raw_key, raw_value in (attributes or {}).items():
        if raw_value is None:
            continue
        label = str(raw_key).strip()
        if label:
            cleaned[label] = raw_value if isinstance(raw_value, primitive_types) else str(raw_value)
    return cleaned
def build_chat_trace_attributes(
    *,
    session_id: str,
    project_id: Optional[int],
    model_id: Optional[str],
    route_mode: str,
    source: str,
    knowledge_base_id: Optional[str],
) -> Dict[str, Any]:
    """Build the sanitized attribute set for a chat-stream trace.

    All arguments are keyword-only and are stored under keys of the same
    name, together with a fixed ``component`` of ``"chat_stream"``. The
    result is passed through ``sanitize_attributes`` before being returned.
    """
    raw_attributes: Dict[str, Any] = {
        "session_id": session_id,
        "project_id": project_id,
        "model_id": model_id,
        "route_mode": route_mode,
        "source": source,
        "knowledge_base_id": knowledge_base_id,
        "component": "chat_stream",
    }
    return sanitize_attributes(raw_attributes)
def build_usage_attributes(usage: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    """Translate a token-usage mapping into ``usage.*`` span attributes.

    Reads ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` from
    *usage* with ``.get`` and hands them to ``sanitize_attributes``. A
    missing or empty *usage* yields an empty dict.
    """
    if usage:
        token_counts: Dict[str, Any] = {}
        token_counts["usage.prompt_tokens"] = usage.get("prompt_tokens")
        token_counts["usage.completion_tokens"] = usage.get("completion_tokens")
        token_counts["usage.total_tokens"] = usage.get("total_tokens")
        return sanitize_attributes(token_counts)
    return {}
def build_error_attributes(exc: Exception, *, stage: str) -> Dict[str, Any]:
    """Describe *exc* as sanitized span attributes.

    The result carries ``error`` (always ``True``), ``error.stage`` (the
    caller-supplied *stage*), ``error.type`` (the exception class name) and
    ``error.message`` (``str(exc)``), passed through ``sanitize_attributes``.
    """
    error_details: Dict[str, Any] = {"error": True, "error.stage": stage}
    error_details["error.type"] = type(exc).__name__
    error_details["error.message"] = str(exc)
    return sanitize_attributes(error_details)
+187
View File
@@ -0,0 +1,187 @@
from __future__ import annotations
import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Mapping, Optional
from app.trace.attributes import sanitize_attributes
logger = logging.getLogger(__name__)
class _NoopSpan:
    """Stand-in span whose methods accept the span-call arguments and do nothing.

    It exposes the same four methods as ``_SpanAdapter`` so calling code does
    not need to check whether tracing is active.
    """

    def set_attributes(self, _attributes: Optional[Mapping[str, Any]] = None) -> None:
        """Ignore the given attributes."""

    def update(self, **_kwargs: Any) -> None:
        """Ignore the given span fields."""

    def update_trace(self, **_kwargs: Any) -> None:
        """Ignore the given trace fields."""

    def record_error(self, _exc: Exception, *, stage: str = "unknown") -> None:
        """Ignore the given exception and stage."""
class _SpanAdapter:
    """Thin wrapper around the span object handed out by the tracing client.

    Every call looks up the matching method on the wrapped span with
    ``getattr`` and is silently skipped when the span does not provide it.
    """

    def __init__(self, raw_span: Any) -> None:
        self._raw_span = raw_span

    def _forward(self, method_name: str, arguments: Dict[str, Any]) -> None:
        """Call ``method_name`` on the wrapped span with *arguments*, if it is callable."""
        target = getattr(self._raw_span, method_name, None)
        if callable(target):
            target(**arguments)

    def set_attributes(self, attributes: Optional[Mapping[str, Any]] = None) -> None:
        """Attach sanitized *attributes* to the wrapped span.

        Uses the span's ``set_attribute(key, value)`` once per entry when it
        exists; otherwise falls back to a single ``update(metadata=...)``.
        Nothing is sent when sanitizing leaves no entries.
        """
        cleaned = sanitize_attributes(attributes)
        if not cleaned:
            return
        setter = getattr(self._raw_span, "set_attribute", None)
        if not callable(setter):
            self._forward("update", {"metadata": cleaned})
            return
        for attr_name in cleaned:
            setter(attr_name, cleaned[attr_name])

    def update(self, **kwargs: Any) -> None:
        """Forward *kwargs* to the wrapped span's ``update``."""
        self._forward("update", kwargs)

    def update_trace(self, **kwargs: Any) -> None:
        """Forward *kwargs* to the wrapped span's ``update_trace``."""
        self._forward("update_trace", kwargs)

    def record_error(self, exc: Exception, *, stage: str = "unknown") -> None:
        """Record *exc* on the span as error attributes plus an ERROR status."""
        error_details = {
            "error": True,
            "error.stage": stage,
            "error.type": exc.__class__.__name__,
            "error.message": str(exc),
        }
        self.set_attributes(error_details)
        self.update(level="ERROR", status_message=str(exc))
class TraceService:
    """Optional Langfuse tracing facade.

    Tracing stays disabled until :meth:`initialize` succeeds. While it is
    disabled, :meth:`start_span` hands out no-op spans so call sites can be
    written unconditionally.
    """

    def __init__(self) -> None:
        self._client: Any = None
        self._enabled = False
        self._initialized = False
        self._httpx_instrumented = False

    @property
    def enabled(self) -> bool:
        """Whether a Langfuse client was created successfully."""
        return self._enabled

    @property
    def initialized(self) -> bool:
        """Whether :meth:`initialize` has been attempted since the last shutdown."""
        return self._initialized

    def _read_config(self) -> Dict[str, Optional[str]]:
        """Read the Langfuse settings from the process environment."""
        return {
            "public_key": os.getenv("LANGFUSE_PUBLIC_KEY"),
            "secret_key": os.getenv("LANGFUSE_SECRET_KEY"),
            "base_url": os.getenv("LANGFUSE_BASE_URL", "http://localhost:3000"),
        }

    def initialize(self) -> bool:
        """Create the Langfuse client once and report whether tracing is enabled.

        Later calls return the cached outcome without retrying. Missing
        keys, a failed SDK import or a failed client construction leave
        tracing disabled and return ``False``. HTTPX OpenTelemetry
        instrumentation is attempted afterwards; its failure is logged but
        does not disable tracing.
        """
        if self._initialized:
            return self._enabled
        # Mark the attempt up front so a failed setup is not retried on every call.
        self._initialized = True
        cfg = self._read_config()
        if not cfg["public_key"] or not cfg["secret_key"]:
            logger.info("Langfuse tracing disabled: missing LANGFUSE_PUBLIC_KEY or LANGFUSE_SECRET_KEY")
            return False
        try:
            from langfuse import Langfuse
        except Exception as exc:
            logger.warning("Langfuse tracing disabled: SDK import failed: %s", exc)
            return False
        try:
            self._client = Langfuse(
                public_key=cfg["public_key"],
                secret_key=cfg["secret_key"],
                host=cfg["base_url"],
            )
            self._enabled = True
            logger.info("Langfuse tracing enabled, host=%s", cfg["base_url"])
        except Exception as exc:
            logger.warning("Langfuse tracing initialization failed, fallback to no-op: %s", exc)
            self._client = None
            self._enabled = False
            return False
        try:
            from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

            HTTPXClientInstrumentor().instrument()
            self._httpx_instrumented = True
        except Exception as exc:
            logger.warning("HTTPX OTEL instrumentation unavailable: %s", exc)
        return True

    def shutdown(self) -> None:
        """Flush and close the client, undo HTTPX instrumentation, reset state.

        Each step is best-effort: a failure is logged at debug level and the
        remaining steps still run. Afterwards the service is back in its
        uninitialized state, so :meth:`initialize` may be called again.
        """
        if self._enabled and self._client:
            for method_name in ("flush", "shutdown"):
                method = getattr(self._client, method_name, None)
                if not callable(method):
                    continue
                try:
                    method()
                except Exception as exc:
                    logger.debug("Langfuse client %s() failed during shutdown: %s", method_name, exc)
        if self._httpx_instrumented:
            try:
                from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

                HTTPXClientInstrumentor().uninstrument()
            except Exception as exc:
                logger.debug("HTTPX OTEL uninstrument failed during shutdown: %s", exc)
        self._client = None
        self._enabled = False
        self._initialized = False
        self._httpx_instrumented = False

    @contextmanager
    def start_span(
        self,
        name: str,
        *,
        attributes: Optional[Mapping[str, Any]] = None,
        input_payload: Optional[Any] = None,
    ) -> Iterator[_SpanAdapter | _NoopSpan]:
        """Open a span named *name* around the caller's ``with`` body.

        Args:
            name: Span name passed to the client.
            attributes: Optional attributes set on the span before it is yielded.
            input_payload: Optional value recorded as the span's ``input``.

        Yields:
            A ``_SpanAdapter`` around the client's span, or a ``_NoopSpan``
            when tracing is disabled or the span could not be opened or
            decorated.

        Exceptions raised by the caller's ``with`` body always propagate
        unchanged. Failures while opening or decorating the span are logged
        and degrade to a no-op span.
        """
        if not self._enabled or not self._client:
            yield _NoopSpan()
            return

        # Resolve the client's context manager without yielding inside the
        # try block: a generator-based context manager must yield exactly
        # once, and a yield guarded by ``except Exception`` would intercept
        # the caller's own exceptions.
        ctx = None
        try:
            start_observation = getattr(self._client, "start_as_current_observation", None)
            if callable(start_observation):
                ctx = start_observation(name=name, as_type="span")
            else:
                start_span = getattr(self._client, "start_as_current_span", None)
                if callable(start_span):
                    ctx = start_span(name=name)
        except Exception as exc:
            logger.warning("Langfuse span failure (%s): %s", name, exc)
            ctx = None

        if ctx is None:
            yield _NoopSpan()
            return

        handed_over = False
        try:
            with ctx as raw_span:
                span = _SpanAdapter(raw_span)
                if attributes:
                    span.set_attributes(attributes)
                if input_payload is not None:
                    span.update(input=input_payload)
                handed_over = True
                yield span
            return
        except Exception as exc:
            if handed_over:
                # Raised after the span was handed to the caller: it is the
                # caller's exception (already seen by the span context's
                # __exit__), so it must not be swallowed or replaced.
                raise
            logger.warning("Langfuse span failure (%s): %s", name, exc)

        # Span setup failed before anything was yielded; run the caller's
        # body untraced.
        yield _NoopSpan()
# Shared module-level instance. It starts disabled, so start_span() yields
# no-op spans until initialize() has created a client.
trace_service = TraceService()