Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,131 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.context import current_data_source, current_file_url, current_progress_callback
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
# Import schema logic from nl2sql
|
||||
from app.agent.nl2sql import (
|
||||
_get_cached_schema,
|
||||
_set_cached_schema,
|
||||
_check_connection_with_cache,
|
||||
_get_upload_payload
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GetDatabaseSchemaTool(Tool):
    """Tool that returns the schema of the active data source as plain text.

    The data source is read from the ``current_data_source`` context variable
    rather than from tool arguments, so the tool declares no parameters.
    """

    # Upper bound, in seconds, for a live ``connector.get_schema`` call.
    _SCHEMA_FETCH_TIMEOUT = 120.0

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the (empty) argument object."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    @staticmethod
    def _load_connector(ds_id: int) -> Any:
        """Build a connector for the stored data source ``ds_id``.

        Opens its own database session and always closes it. Returns ``None``
        when no ``DataSource`` row with that id exists.
        """
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
            if not ds:
                return None
            return get_connector(ds)
        finally:
            db.close()

    @staticmethod
    def _format_schema(schema: Any) -> str:
        """Render a schema mapping as compact text for the LLM.

        Two per-table shapes are handled: a list of ``{"name", "type"}``
        column dicts, and a dict with ``"columns"`` and ``"primary_keys"``
        entries. Tables in any other shape are skipped.
        """
        lines = []
        # ``schema`` may be None when a payload carries an explicit null.
        for table_name, table_info in (schema or {}).items():
            if isinstance(table_info, list):
                cols = ", ".join(f"{c['name']} ({c['type']})" for c in table_info)
                lines.append(f"Table: {table_name}\n Columns: {cols}")
            elif isinstance(table_info, dict):
                # ``or []`` also covers keys that are present but set to None.
                columns = table_info.get("columns") or []
                cols = ", ".join(f"{c['name']} ({c['type']})" for c in columns)
                pks = ", ".join(str(pk) for pk in (table_info.get("primary_keys") or []))
                lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}")

        return "\n\n".join(lines) if lines else "No tables found in schema."

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the schema of the current data source and return it as text.

        Supported sources are ``"postgres"``, ``"clickhouse"``, ``"upload"``
        and ``"ds:<id>"``. Failures are reported as a returned message rather
        than raised, and are additionally written to the module logger.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            # The progress callback is optional; do nothing when it is unset.
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        if not source:
            return "Error: No data source connected."

        connector = None
        schema: Any = {}

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            # Uploaded files carry their schema in the payload; no connector is used.
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema")
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])
                # The ORM lookup is synchronous, so run it off the event loop.
                connector = await asyncio.to_thread(self._load_connector, ds_id)
                if not connector:
                    return f"Data source not found: {source}"
            except Exception as e:
                logger.exception("Failed to load data source %s", source)
                return f"Failed to load data source: {e}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    return f"Failed to connect to {source}"

                try:
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=self._SCHEMA_FETCH_TIMEOUT,
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning("Schema fetch timed out for %s", source)
                    return (
                        "Failed to fetch schema: Timeout after "
                        f"{self._SCHEMA_FETCH_TIMEOUT:g} seconds."
                    )
                except Exception as e:
                    logger.exception("Failed to fetch schema for %s", source)
                    return f"Failed to fetch schema: {e}"

        return self._format_schema(schema)
|
||||
@@ -0,0 +1,59 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
from app.context import current_knowledge_base_id
|
||||
from app.services.knowledge_index import knowledge_index_service
|
||||
|
||||
|
||||
class KnowledgeBaseRetrieveTool(Tool):
    """Tool that searches a knowledge base and returns the hits as JSON."""

    # Bounds advertised in the parameter schema and enforced in ``execute``.
    _MIN_TOP_K = 1
    _MAX_TOP_K = 20

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "knowledge_retrieve"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return "Retrieve relevant context from the selected knowledge base to answer user questions."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "User question or retrieval query.",
                },
                "knowledge_base_id": {
                    "type": "string",
                    "description": "Optional knowledge base id, defaults to current session setting.",
                },
                "top_k": {
                    "type": "integer",
                    "description": "Maximum number of returned chunks.",
                    "minimum": self._MIN_TOP_K,
                    "maximum": self._MAX_TOP_K,
                },
            },
            "required": ["query"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Run a knowledge-base search and return a JSON string.

        The knowledge base id comes from the ``knowledge_base_id`` argument,
        falling back to the ``current_knowledge_base_id`` context variable.
        A ``ValueError`` raised by the search service is returned as its
        message; other exceptions propagate to the caller.
        """
        query = (kwargs.get("query") or "").strip()
        if not query:
            return "Query is required."

        # ``str()`` keeps a non-string id (e.g. an integer) from breaking ``.strip()``.
        kb_id = str(
            kwargs.get("knowledge_base_id") or current_knowledge_base_id.get() or ""
        ).strip()
        if not kb_id:
            return "No knowledge base is selected in this session."

        # Enforce the bounds declared in the schema; the caller is a model and
        # may send out-of-range or non-integer values. ``None`` is passed on
        # unchanged so the service applies its own default.
        top_k = kwargs.get("top_k")
        if top_k is not None:
            try:
                top_k = max(self._MIN_TOP_K, min(self._MAX_TOP_K, int(top_k)))
            except (TypeError, ValueError):
                top_k = None

        try:
            result = knowledge_index_service.search(kb_id=kb_id, query=query, top_k=top_k)
        except ValueError as exc:
            return str(exc)

        payload = {
            "knowledge_base_id": kb_id,
            "answer": result.get("answer", ""),
            "hits": result.get("hits", []),
        }
        return json.dumps(payload, ensure_ascii=False)
|
||||
@@ -0,0 +1,127 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
from app.context import current_progress_callback, current_viz_data, current_data_source, current_file_url, current_data
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _normalize_query(value: str) -> str:
|
||||
return re.sub(r"\s+", "", (value or "")).lower()
|
||||
|
||||
|
||||
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse, query: str) -> dict:
    """Assemble the JSON-encodable visualization payload for an NL2SQL result.

    The payload carries the SQL, the result rows, the dumped chart (or
    ``None`` when there is no chart), the raw and normalized query, a
    ``"nl2sql"`` origin marker and the result's error field.
    """
    chart_obj = nl2sql_result.chart
    if chart_obj:
        chart_spec = chart_obj.model_dump(by_alias=True, exclude_none=True)
    else:
        chart_spec = None

    return jsonable_encoder(
        dict(
            sql=nl2sql_result.sql,
            result=nl2sql_result.result,
            chart=chart_spec,
            chart_query=query,
            chart_query_normalized=_normalize_query(query),
            chart_generated_by="nl2sql",
            error=nl2sql_result.error,
        )
    )
|
||||
|
||||
class NL2SQLTool(Tool):
    """Tool that turns a natural-language question into SQL and runs it.

    Optionally asks the NL2SQL pipeline to produce a chart for the result.
    The data source and file URL are read from context variables, not from
    tool arguments.
    """

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "nl2sql"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return (
            "Query the connected database or data source using natural language. "
            "Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. "
            "Set generate_chart=True if the user also wants to visualize or plot the data. "
            "If generate_chart=True, do not call visualization again for the same request."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The natural language query describing what data to fetch or analyze.",
                },
                "generate_chart": {
                    "type": "boolean",
                    "description": "Whether to automatically generate a visualization chart for the result. Default is False.",
                },
            },
            "required": ["query"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Run the NL2SQL pipeline and return a text summary for the agent.

        Side effects: the visualization payload is written to
        ``current_viz_data`` (also when the query failed, so the generated
        SQL is still available), and non-empty result rows are written to
        ``current_data``. Errors are returned as messages, not raised.
        """
        query = kwargs.get("query") or ""
        if not isinstance(query, str) or not query.strip():
            return "Error: query is required."
        # An explicit null from the caller is treated the same as "not set".
        generate_chart = kwargs.get("generate_chart") or False

        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        try:
            # Built inside the ``try`` so request validation failures are
            # reported through the same path as pipeline failures.
            request = NL2SQLRequest(
                query=query,
                source=source,
                file_url=file_url,
                generate_chart=generate_chart,
            )

            result = await process_nl2sql(request, on_progress=on_progress)

            # Always publish the visualization payload, even on error, so the
            # generated SQL can be shown.
            viz_payload = _build_sql_chart_viz(result, query)
            existing_viz = current_viz_data.get()
            if isinstance(existing_viz, dict):
                # Update the existing dict in place instead of replacing it.
                # NOTE(review): presumably another holder of this dict reads
                # it after the tool returns; confirm before changing.
                existing_viz.clear()
                existing_viz.update(viz_payload)
                current_viz_data.set(existing_viz)
            else:
                current_viz_data.set(viz_payload)

            if result.error:
                return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"

            # Keep the rows available for a later visualization request.
            if result.result:
                current_data.set(result.result)

            row_count = len(result.result) if result.result else 0

            summary_parts = [
                "Successfully executed SQL query.",
                f"SQL: {result.sql}",
                f"Rows returned: {row_count}",
            ]

            if generate_chart:
                if result.chart and result.chart.can_visualize:
                    summary_parts.append("Chart was successfully generated.")
                    if result.chart.reasoning:
                        summary_parts.append(f"Chart Reasoning: {result.chart.reasoning}")
                else:
                    summary_parts.append("Requested a chart, but the data was not suitable for visualization.")

            # Only a small sample is returned to keep the agent's context short.
            summary_parts.append("\nSample data (first 5 rows):")
            sample = result.result[:5] if result.result else []
            summary_parts.append(json.dumps(jsonable_encoder(sample), ensure_ascii=False))

            return "\n".join(summary_parts)

        except Exception as e:
            logger.exception("NL2SQL Tool error: %s", e)
            return f"An unexpected error occurred during NL2SQL execution: {str(e)}"
|
||||
@@ -0,0 +1,166 @@
|
||||
from typing import Any, Optional
|
||||
import json
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.database import SessionLocal
|
||||
from app.models.subagent import Subagent
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
from app.services.llm_cache import get_llm_configs, get_active_llm_config
|
||||
|
||||
|
||||
def _resolve_project_id(preferred_project_id: Optional[int]) -> Optional[int]:
|
||||
if preferred_project_id is not None:
|
||||
return preferred_project_id
|
||||
from app.context import current_session_id
|
||||
session_id = (current_session_id.get() or "").strip()
|
||||
if not session_id:
|
||||
return None
|
||||
alias_meta = session_alias_store.get_alias_meta(session_id)
|
||||
if not alias_meta:
|
||||
return None
|
||||
project_id = alias_meta.get("project_id")
|
||||
return project_id if isinstance(project_id, int) else None
|
||||
|
||||
class ListSubagentsTool(Tool):
    """Tool that lists the subagents defined for a project.

    The project is the one given at construction time, or the one resolved
    from the current session by ``_resolve_project_id``.
    """

    def __init__(self, project_id: Optional[int] = None):
        super().__init__()
        self.project_id = project_id

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "list_subagents"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return "List all available subagents in the current project, including their names and descriptions."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the (empty) argument object."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Return the project's subagents as a JSON list of id/name/description."""
        project_id = _resolve_project_id(self.project_id)
        if project_id is None:
            return "Error: No project context available to list subagents."

        with SessionLocal() as session:
            rows = session.query(Subagent).filter(Subagent.project_id == project_id).all()
            if not rows:
                return "No subagents found in the current project."

            listing = [
                {"id": row.id, "name": row.name, "description": row.description}
                for row in rows
            ]

        return json.dumps(listing, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
class InvokeSubagentTool(Tool):
    """Tool that hands a task to a named subagent of the project.

    The subagent's instructions are prepended to the task and the combined
    message is processed in a dedicated session derived from the current one.
    """

    def __init__(self, project_id: Optional[int] = None):
        super().__init__()
        self.project_id = project_id

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "invoke_subagent"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return (
            "Invoke a subagent by name to perform a specific task. "
            "You should first use list_subagents to find the correct subagent name."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "subagent_name": {
                    "type": "string",
                    "description": "The name of the subagent to invoke.",
                },
                "task": {
                    "type": "string",
                    "description": "The specific task or query to send to the subagent.",
                },
            },
            "required": ["subagent_name", "task"],
        }

    @staticmethod
    def _pick_model_id(raw_model: str) -> Any:
        """Map a subagent's configured model string to an LLM config id.

        Lookup order: exact match on config ``id``; then case-insensitive
        match on config ``model`` or ``name``; then the active LLM config.
        Returns ``None`` when the chosen config has no truthy ``id``.
        """
        configs = get_llm_configs()
        chosen = None
        if raw_model:
            for cfg in configs:
                if cfg.get("id") == raw_model:
                    chosen = cfg
                    break
            if chosen is None:
                wanted = raw_model.lower()
                for cfg in configs:
                    model_label = str(cfg.get("model") or "").strip().lower()
                    name_label = str(cfg.get("name") or "").strip().lower()
                    if model_label == wanted or name_label == wanted:
                        chosen = cfg
                        break
        if chosen is None:
            chosen = get_active_llm_config()
        if chosen and chosen.get("id"):
            return chosen.get("id")
        return None

    async def execute(self, **kwargs: Any) -> str:
        """Look up the subagent, run the task through it and return its reply.

        All failures (missing project, missing arguments, unknown subagent,
        errors while processing) are returned as messages, not raised.
        """
        subagent_name = kwargs.get("subagent_name")
        task = kwargs.get("task")
        project_id = _resolve_project_id(self.project_id)

        if project_id is None:
            return "Error: No project context available to invoke subagent."
        if not (subagent_name and task):
            return "Error: subagent_name and task are required."

        with SessionLocal() as session:
            record = (
                session.query(Subagent)
                .filter(
                    Subagent.project_id == project_id,
                    Subagent.name == subagent_name,
                )
                .first()
            )
            if not record:
                return f"Error: Subagent '{subagent_name}' not found."

            # Copy the fields needed below into plain locals.
            agent_id = record.id
            agent_name = record.name
            instructions = record.instructions or "You are a helpful assistant."
            raw_model = (getattr(record, "model", None) or "").strip()

        # The subagent's instructions travel inside the message itself.
        message = (
            f"[System: You are acting as subagent '{agent_name}'. "
            f"Instructions: {instructions}]\n\nTask: {task}"
        )
        model_id = self._pick_model_id(raw_model)

        try:
            from app.context import current_session_id

            parent_session_id = current_session_id.get() or "default"
            # One session per (parent session, subagent) pair.
            child_session_id = f"{parent_session_id}:subagent:{agent_id}"

            response = await nanobot_service.process_message(
                message=message,
                session_id=child_session_id,
                project_id=project_id,
                model_id=model_id,
            )
        except Exception as e:
            return f"Error invoking subagent '{agent_name}': {str(e)}"

        return f"Subagent '{agent_name}' completed the task.\nResult:\n{response}"
|
||||
@@ -0,0 +1,103 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.agent.chart import generate_chart
|
||||
from app.context import current_data, current_viz_data, current_progress_callback
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_query(value: str) -> str:
|
||||
return re.sub(r"\s+", "", (value or "")).lower()
|
||||
|
||||
class VisualizationTool(Tool):
    """Tool that generates a chart from the data stored in ``current_data``."""

    @property
    def name(self) -> str:
        """Identifier under which the tool is exposed to the agent."""
        return "visualization"

    @property
    def description(self) -> str:
        """Natural-language guidance telling the agent when to call the tool."""
        return (
            "Generate a chart or visualization based on the most recently queried data. "
            "Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. "
            "Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first. "
            "Do not call this tool right after nl2sql(generate_chart=True) for the same request."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The user's request describing how they want the data visualized (e.g., 'plot sales by month as a bar chart').",
                }
            },
            "required": ["query"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Generate a chart for ``current_data`` and publish it to ``current_viz_data``.

        Returns a short status message for the agent. If a chart for the same
        normalized query and the same data is already stored, no new chart is
        generated. Errors are logged and returned as messages, not raised.
        """
        query = kwargs.get("query", "")
        data = current_data.get()
        on_progress = current_progress_callback.get()

        if not data:
            return "Error: No data available to visualize. Please query the data first using the nl2sql tool."

        try:
            shared_viz = current_viz_data.get()
            # Read-only view; a missing or non-dict value behaves like "nothing stored".
            stored = shared_viz if isinstance(shared_viz, dict) else {}
            stored_query = stored.get("chart_query_normalized")
            if (
                stored.get("chart")
                and stored.get("result") == data
                and stored_query
                and stored_query == _normalize_query(query)
            ):
                return "Chart already exists for this query and dataset. Reusing existing Vega visualization."

            if on_progress:
                await on_progress("正在分析数据特征并生成可视化方案...")

            chart_response = await generate_chart(data, query)

            if not chart_response.can_visualize:
                return f"Could not generate a chart: {chart_response.reasoning}"

            # Re-read after the await. Keep the actual object (no ``or {}``):
            # an empty dict is falsy, and swapping it for a fresh one would
            # defeat the in-place update below.
            shared_viz = current_viz_data.get()
            previous_sql = shared_viz.get("sql", "") if isinstance(shared_viz, dict) else ""

            encoded_viz = jsonable_encoder(
                {
                    # Carry over the SQL of the earlier query, if any.
                    "sql": previous_sql,
                    "result": data,
                    "chart": chart_response.model_dump(by_alias=True, exclude_none=True),
                    "chart_query": query,
                    "chart_query_normalized": _normalize_query(query),
                    "chart_generated_by": "visualization",
                    "error": None,
                }
            )

            if isinstance(shared_viz, dict):
                # Update the existing dict in place instead of replacing it.
                # NOTE(review): presumably another holder of this dict reads
                # it after the tool returns; confirm before changing.
                shared_viz.clear()
                shared_viz.update(encoded_viz)
                current_viz_data.set(shared_viz)
            else:
                current_viz_data.set(encoded_viz)

            return f"Successfully generated a {chart_response.chart_type} chart.\nReasoning: {chart_response.reasoning}"

        except Exception as e:
            logger.exception("Visualization Tool error: %s", e)
            return f"An error occurred while generating the visualization: {str(e)}"
|
||||
Reference in New Issue
Block a user