refactor: convert to nl2sql skills
This commit is contained in:
+50
-17
@@ -191,14 +191,24 @@ def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[
|
||||
with _cache_lock:
|
||||
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||
|
||||
def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
async def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
cache_key = _build_schema_cache_key(source, connector)
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
cached = _connection_cache.get(cache_key)
|
||||
if cached and now < cached["expires_at"]:
|
||||
return bool(cached["ok"])
|
||||
ok = connector.test_connection()
|
||||
|
||||
# Run synchronous test_connection in a separate thread to avoid blocking event loop
|
||||
try:
|
||||
ok = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.test_connection),
|
||||
timeout=10.0
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Connection test failed or timed out: {e}")
|
||||
ok = False
|
||||
|
||||
with _cache_lock:
|
||||
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||
return ok
|
||||
@@ -224,7 +234,7 @@ async def process_nl2sql(
|
||||
elif request.source == "upload":
|
||||
try:
|
||||
upload_started = time.perf_counter()
|
||||
upload_payload = _get_upload_payload(request.file_url)
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
schema = upload_payload["schema"]
|
||||
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
|
||||
@@ -234,14 +244,21 @@ async def process_nl2sql(
|
||||
try:
|
||||
ds_started = time.perf_counter()
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _get_ds_connector():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return None
|
||||
return get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
connector = await asyncio.to_thread(_get_ds_connector)
|
||||
if not connector:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
|
||||
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
@@ -258,20 +275,35 @@ async def process_nl2sql(
|
||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||
else:
|
||||
conn_started = time.perf_counter()
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
|
||||
schema_started = time.perf_counter()
|
||||
schema = connector.get_schema()
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||
|
||||
if connector and not schema:
|
||||
retry_started = time.perf_counter()
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||
|
||||
@@ -282,7 +314,7 @@ async def process_nl2sql(
|
||||
if request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
mdl = MDLService.get_mdl(ds_id)
|
||||
mdl = await asyncio.to_thread(MDLService.get_mdl, ds_id)
|
||||
if mdl:
|
||||
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
|
||||
|
||||
@@ -392,7 +424,8 @@ Language: Chinese (Simplified)
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Awaitable, Dict, Optional
|
||||
|
||||
# The current session ID processing the request
|
||||
current_session_id: ContextVar[str] = ContextVar("current_session_id", default="")
|
||||
|
||||
# A callback to send progress updates to the frontend during tool execution
|
||||
current_progress_callback: ContextVar[Optional[Callable[[str], Awaitable[None]]]] = ContextVar("current_progress_callback", default=None)
|
||||
|
||||
# A payload dictionary to store visualization results generated by tools
|
||||
# This will be picked up by the stream handler and sent to the frontend
|
||||
current_viz_data: ContextVar[Optional[Dict[str, Any]]] = ContextVar("current_viz_data", default=None)
|
||||
|
||||
# Store the last queried data so the Visualization Tool can access it
|
||||
current_data: ContextVar[Optional[list]] = ContextVar("current_data", default=None)
|
||||
|
||||
# The data source requested by the user or bound to the session
|
||||
current_data_source: ContextVar[str] = ContextVar("current_data_source", default="postgres")
|
||||
|
||||
# Any file URL attached to the request
|
||||
current_file_url: ContextVar[Optional[str]] = ContextVar("current_file_url", default=None)
|
||||
@@ -85,6 +85,14 @@ class NanobotIntegration:
|
||||
channels_config=self.config.channels,
|
||||
)
|
||||
|
||||
self._register_custom_tools(self.agent)
|
||||
|
||||
def _register_custom_tools(self, agent: AgentLoop):
|
||||
from app.tools.nl2sql import NL2SQLTool
|
||||
from app.tools.visualization import VisualizationTool
|
||||
agent.tools.register(NL2SQLTool())
|
||||
agent.tools.register(VisualizationTool())
|
||||
|
||||
def _make_provider(self, config: Config):
|
||||
# Logic adapted from nanobot/cli/commands.py
|
||||
model = config.agents.defaults.model
|
||||
@@ -195,6 +203,7 @@ class NanobotIntegration:
|
||||
provider_name=target_config.get("provider"),
|
||||
)
|
||||
agent = self._build_agent_for_provider(provider)
|
||||
self._register_custom_tools(agent)
|
||||
self._model_agent_cache[model_id] = agent
|
||||
return agent
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
from app.context import current_progress_callback, current_viz_data, current_data_source, current_file_url, current_data
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
||||
chart = nl2sql_result.chart
|
||||
payload = {
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
}
|
||||
return jsonable_encoder(payload)
|
||||
|
||||
class NL2SQLTool(Tool):
|
||||
"""
|
||||
Tool for translating natural language queries into SQL, executing them,
|
||||
and optionally generating visualizations.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "nl2sql"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Query the connected database or data source using natural language. "
|
||||
"Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. "
|
||||
"Set generate_chart=True if the user also wants to visualize or plot the data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The natural language query describing what data to fetch or analyze.",
|
||||
},
|
||||
"generate_chart": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to automatically generate a visualization chart for the result. Default is False.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
query = kwargs.get("query", "")
|
||||
generate_chart = kwargs.get("generate_chart", False)
|
||||
|
||||
# Get context
|
||||
source = current_data_source.get()
|
||||
file_url = current_file_url.get()
|
||||
on_progress = current_progress_callback.get()
|
||||
|
||||
request = NL2SQLRequest(
|
||||
query=query,
|
||||
source=source,
|
||||
file_url=file_url,
|
||||
generate_chart=generate_chart,
|
||||
)
|
||||
|
||||
try:
|
||||
# Call the core logic
|
||||
result = await process_nl2sql(request, on_progress=on_progress)
|
||||
|
||||
if result.error:
|
||||
return f"Error executing query: {result.error}"
|
||||
|
||||
# Save the result data to context for potential later use by VisualizationTool
|
||||
if result.result:
|
||||
current_data.set(result.result)
|
||||
|
||||
# Save visualization payload to context so the chat stream can pick it up
|
||||
viz_payload = _build_sql_chart_viz(result)
|
||||
existing_viz = current_viz_data.get()
|
||||
if isinstance(existing_viz, dict):
|
||||
existing_viz.clear()
|
||||
existing_viz.update(viz_payload)
|
||||
current_viz_data.set(existing_viz)
|
||||
else:
|
||||
current_viz_data.set(viz_payload)
|
||||
|
||||
# Build a summary string for the Agent to read
|
||||
row_count = len(result.result) if result.result else 0
|
||||
|
||||
summary_parts = [f"Successfully executed SQL query."]
|
||||
summary_parts.append(f"SQL: {result.sql}")
|
||||
summary_parts.append(f"Rows returned: {row_count}")
|
||||
|
||||
if generate_chart:
|
||||
if result.chart and result.chart.can_visualize:
|
||||
summary_parts.append("Chart was successfully generated.")
|
||||
if result.chart.reasoning:
|
||||
summary_parts.append(f"Chart Reasoning: {result.chart.reasoning}")
|
||||
else:
|
||||
summary_parts.append("Requested a chart, but the data was not suitable for visualization.")
|
||||
|
||||
summary_parts.append("\nSample data (first 5 rows):")
|
||||
sample = result.result[:5] if result.result else []
|
||||
summary_parts.append(json.dumps(jsonable_encoder(sample), ensure_ascii=False))
|
||||
|
||||
return "\n".join(summary_parts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"NL2SQL Tool error: {e}", exc_info=True)
|
||||
return f"An unexpected error occurred during NL2SQL execution: {str(e)}"
|
||||
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.agent.chart import generate_chart
|
||||
from app.context import current_data, current_viz_data, current_progress_callback
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VisualizationTool(Tool):
|
||||
"""
|
||||
Tool for generating a visualization (chart) from existing data.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "visualization"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Generate a chart or visualization based on the most recently queried data. "
|
||||
"Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. "
|
||||
"Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The user's request describing how they want the data visualized (e.g., 'plot sales by month as a bar chart').",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
query = kwargs.get("query", "")
|
||||
data = current_data.get()
|
||||
on_progress = current_progress_callback.get()
|
||||
|
||||
if not data:
|
||||
return "Error: No data available to visualize. Please query the data first using the nl2sql tool."
|
||||
|
||||
try:
|
||||
if on_progress:
|
||||
await on_progress("正在分析数据特征并生成可视化方案...")
|
||||
|
||||
chart_response = await generate_chart(data, query)
|
||||
|
||||
if chart_response.can_visualize:
|
||||
# Build the viz payload (similar to NL2SQL, but without the SQL part)
|
||||
# We reuse the previous viz_data if it exists (to keep SQL), or create a new one
|
||||
existing_viz = current_viz_data.get() or {}
|
||||
|
||||
viz_payload = {
|
||||
"sql": existing_viz.get("sql", ""),
|
||||
"result": data,
|
||||
"chart": chart_response.model_dump(),
|
||||
"error": None,
|
||||
}
|
||||
encoded_viz = jsonable_encoder(viz_payload)
|
||||
if isinstance(existing_viz, dict):
|
||||
existing_viz.clear()
|
||||
existing_viz.update(encoded_viz)
|
||||
current_viz_data.set(existing_viz)
|
||||
else:
|
||||
current_viz_data.set(encoded_viz)
|
||||
|
||||
return f"Successfully generated a {chart_response.chart_type} chart.\nReasoning: {chart_response.reasoning}"
|
||||
else:
|
||||
return f"Could not generate a chart: {chart_response.reasoning}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Visualization Tool error: {e}", exc_info=True)
|
||||
return f"An error occurred while generating the visualization: {str(e)}"
|
||||
Reference in New Issue
Block a user