Update 2026-05-13 16:43:53

This commit is contained in:
yi
2026-05-13 16:43:53 +08:00
parent 6af5c584f4
commit afd7c5fe85
490 changed files with 850 additions and 922 deletions
+131
View File
@@ -0,0 +1,131 @@
import json
import logging
from typing import Any
import asyncio
from nanobot.agent.tools.base import Tool
from app.context import current_data_source, current_file_url, current_progress_callback
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.database import SessionLocal
from app.models.datasource import DataSource
# Import schema logic from nl2sql
from app.agent.nl2sql import (
_get_cached_schema,
_set_cached_schema,
_check_connection_with_cache,
_get_upload_payload
)
logger = logging.getLogger(__name__)
class GetDatabaseSchemaTool(Tool):
    """Tool for fetching the database schema directly without SQL generation.

    The active data source is read from the ``current_data_source`` context
    variable:

    * ``postgres`` / ``clickhouse`` use the module-level connectors,
    * ``upload`` takes the schema from the uploaded file's payload,
    * ``ds:<id>`` looks up a ``DataSource`` row and builds a connector for it.

    Every failure is reported to the caller as a plain string (never raised)
    and is additionally written to the module logger.
    """

    # Upper bound, in seconds, for one live ``connector.get_schema`` call.
    _SCHEMA_FETCH_TIMEOUT_SECONDS = 120.0

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments (the tool takes none)."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    @staticmethod
    def _load_datasource_connector(ds_id: int) -> Any:
        """Return a connector for DataSource ``ds_id``, or None if no such row exists."""
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
            if not ds:
                return None
            return get_connector(ds)
        finally:
            db.close()

    @staticmethod
    def _format_columns(columns: Any) -> str:
        """Render column dicts as ``name (type), name (type), ...``."""
        return ", ".join(f"{c['name']} ({c['type']})" for c in columns)

    @classmethod
    def _format_schema(cls, schema: Any) -> str:
        """Render a schema mapping as compact, readable text for the LLM."""
        lines = []
        # A missing schema (None) is treated like an empty one.
        for table_name, table_info in (schema or {}).items():
            if isinstance(table_info, list):
                # Clickhouse/Upload format: [{"name": "col", "type": "type"}]
                cols = cls._format_columns(table_info)
                lines.append(f"Table: {table_name}\n Columns: {cols}")
            elif isinstance(table_info, dict):
                # Postgres format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
                cols = cls._format_columns(table_info.get("columns") or [])
                # str() so that non-string key entries cannot break the join.
                pks = ", ".join(str(pk) for pk in (table_info.get("primary_keys") or []))
                lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}")
        return "\n\n".join(lines) if lines else "No tables found in schema."

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the schema of the current data source and return it as text.

        Returns:
            The formatted schema, ``"No tables found in schema."`` when it is
            empty, or a human-readable error message.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        if not source:
            return "Error: No data source connected."

        connector = None
        schema = {}

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema for %s", file_url)
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])
                connector = await asyncio.to_thread(self._load_datasource_connector, ds_id)
                if not connector:
                    return f"Data source not found: {source}"
            except Exception as e:
                logger.exception("Failed to load data source %s", source)
                return f"Failed to load data source: {e}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    return f"Failed to connect to {source}"
                timeout = self._SCHEMA_FETCH_TIMEOUT_SECONDS
                try:
                    # get_schema is synchronous, so it runs in a worker thread.
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=timeout,
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning("Schema fetch for %s timed out after %ss", source, timeout)
                    return f"Failed to fetch schema: Timeout after {timeout:g} seconds."
                except Exception as e:
                    logger.exception("Failed to fetch schema for %s", source)
                    return f"Failed to fetch schema: {e}"

        # Format the output for the LLM to make it readable and token-efficient
        return self._format_schema(schema)
+59
View File
@@ -0,0 +1,59 @@
import json
from typing import Any
from nanobot.agent.tools.base import Tool
from app.context import current_knowledge_base_id
from app.services.knowledge_index import knowledge_index_service
class KnowledgeBaseRetrieveTool(Tool):
    """Tool that searches a knowledge base and returns the hits as JSON.

    The knowledge base is taken from the ``knowledge_base_id`` argument or,
    when that is absent, from the ``current_knowledge_base_id`` context
    variable.
    """

    # Bounds advertised for ``top_k`` in the parameter schema below.
    _TOP_K_MIN = 1
    _TOP_K_MAX = 20

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "knowledge_retrieve"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return "Retrieve relevant context from the selected knowledge base to answer user questions."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "User question or retrieval query.",
                },
                "knowledge_base_id": {
                    "type": "string",
                    "description": "Optional knowledge base id, defaults to current session setting.",
                },
                "top_k": {
                    "type": "integer",
                    "description": "Maximum number of returned chunks.",
                    "minimum": self._TOP_K_MIN,
                    "maximum": self._TOP_K_MAX,
                },
            },
            "required": ["query"],
        }

    @classmethod
    def _normalize_top_k(cls, raw: Any) -> Any:
        """Coerce ``top_k`` to an int within the advertised bounds.

        Returns None when the value is missing or not numeric; None is
        forwarded to the search service unchanged, exactly as an omitted
        ``top_k`` was before.
        """
        # bool is an int subclass; True/False is never a meaningful top_k.
        if raw is None or isinstance(raw, bool):
            return None
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            return None
        return max(cls._TOP_K_MIN, min(cls._TOP_K_MAX, value))

    async def execute(self, **kwargs: Any) -> str:
        """Run the retrieval and return a JSON document with answer and hits.

        Returns:
            A JSON string with ``knowledge_base_id``, ``answer`` and ``hits``,
            or a plain error message when the query or knowledge base is
            missing or the service raises ``ValueError``.
        """
        # str() so that a non-string argument cannot crash on .strip().
        query = str(kwargs.get("query") or "").strip()
        if not query:
            return "Query is required."

        kb_id = str(kwargs.get("knowledge_base_id") or current_knowledge_base_id.get() or "").strip()
        if not kb_id:
            return "No knowledge base is selected in this session."

        top_k = self._normalize_top_k(kwargs.get("top_k"))

        try:
            # NOTE(review): search() is invoked synchronously; if it performs
            # blocking I/O it will stall the event loop - confirm, and consider
            # asyncio.to_thread if the service is safe to call from a thread.
            result = knowledge_index_service.search(kb_id=kb_id, query=query, top_k=top_k)
        except ValueError as exc:
            return str(exc)

        payload = {
            "knowledge_base_id": kb_id,
            "answer": result.get("answer", ""),
            "hits": result.get("hits", []),
        }
        return json.dumps(payload, ensure_ascii=False)
+127
View File
@@ -0,0 +1,127 @@
import json
import logging
import re
from typing import Any
from nanobot.agent.tools.base import Tool
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
from app.context import current_progress_callback, current_viz_data, current_data_source, current_file_url, current_data
from fastapi.encoders import jsonable_encoder
logger = logging.getLogger(__name__)
def _normalize_query(value: str) -> str:
return re.sub(r"\s+", "", (value or "")).lower()
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse, query: str) -> dict:
    """Assemble the visualization payload for an NL2SQL result.

    The payload carries the SQL, the result, the dumped chart (or None when
    the response has no chart), the raw and normalized query, the generator
    tag ``"nl2sql"`` and the error. It is passed through ``jsonable_encoder``
    before being returned.
    """
    chart_model = nl2sql_result.chart
    if chart_model:
        chart_dump = chart_model.model_dump(by_alias=True, exclude_none=True)
    else:
        chart_dump = None

    return jsonable_encoder(
        {
            "sql": nl2sql_result.sql,
            "result": nl2sql_result.result,
            "chart": chart_dump,
            "chart_query": query,
            "chart_query_normalized": _normalize_query(query),
            "chart_generated_by": "nl2sql",
            "error": nl2sql_result.error,
        }
    )
class NL2SQLTool(Tool):
    """
    Tool for translating natural language queries into SQL, executing them,
    and optionally generating visualizations.

    Besides returning a text summary, the tool publishes two context values:
    ``current_viz_data`` (SQL, result, chart, error) and ``current_data``
    (the rows of the last successful query).
    """

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "nl2sql"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return (
            "Query the connected database or data source using natural language. "
            "Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. "
            "Set generate_chart=True if the user also wants to visualize or plot the data. "
            "If generate_chart=True, do not call visualization again for the same request."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The natural language query describing what data to fetch or analyze.",
                },
                "generate_chart": {
                    "type": "boolean",
                    "description": "Whether to automatically generate a visualization chart for the result. Default is False.",
                }
            },
            "required": ["query"],
        }

    @staticmethod
    def _publish_viz(viz_payload: dict) -> None:
        """Store ``viz_payload`` in ``current_viz_data``.

        When the context already holds a dict, that dict is refilled in place
        (and re-set) rather than replaced by a new object.
        """
        existing_viz = current_viz_data.get()
        if isinstance(existing_viz, dict):
            existing_viz.clear()
            existing_viz.update(viz_payload)
            current_viz_data.set(existing_viz)
        else:
            current_viz_data.set(viz_payload)

    async def execute(self, **kwargs: Any) -> str:
        """Run the NL2SQL pipeline and return a text summary for the agent.

        Returns:
            On success, a summary with the SQL, the row count, the chart
            status (if requested) and up to five sample rows. On failure, an
            error message; exceptions raised inside the pipeline are logged
            and reported as text instead of being raised.
        """
        query = kwargs.get("query", "")
        generate_chart = kwargs.get("generate_chart", False)

        # Get context
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        request = NL2SQLRequest(
            query=query,
            source=source,
            file_url=file_url,
            generate_chart=generate_chart,
        )

        try:
            # Call the core logic
            result = await process_nl2sql(request, on_progress=on_progress)

            # Always save visualization payload to context so the chat stream can pick it up
            # Even if there's an error, we want the frontend to see the generated SQL
            self._publish_viz(_build_sql_chart_viz(result, query))

            if result.error:
                return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"

            # Save the result data to context for potential later use by
            # VisualizationTool. This is done even for an empty result:
            # otherwise the rows of an earlier query would stay in the context
            # and a later visualization call would chart stale data.
            rows = result.result or []
            current_data.set(rows)

            # Build a summary string for the Agent to read
            summary_parts = [
                "Successfully executed SQL query.",
                f"SQL: {result.sql}",
                f"Rows returned: {len(rows)}",
            ]

            if generate_chart:
                if result.chart and result.chart.can_visualize:
                    summary_parts.append("Chart was successfully generated.")
                    if result.chart.reasoning:
                        summary_parts.append(f"Chart Reasoning: {result.chart.reasoning}")
                else:
                    summary_parts.append("Requested a chart, but the data was not suitable for visualization.")

            summary_parts.append("\nSample data (first 5 rows):")
            summary_parts.append(json.dumps(jsonable_encoder(rows[:5]), ensure_ascii=False))

            return "\n".join(summary_parts)
        except Exception as e:
            # logger.exception records the traceback, like exc_info=True did.
            logger.exception("NL2SQL Tool error: %s", e)
            return f"An unexpected error occurred during NL2SQL execution: {str(e)}"
+166
View File
@@ -0,0 +1,166 @@
from typing import Any, Optional
import json
from nanobot.agent.tools.base import Tool
from app.database import SessionLocal
from app.models.subagent import Subagent
from app.core.nanobot import nanobot_service
from app.core.session_alias_store import session_alias_store
from app.services.llm_cache import get_llm_configs, get_active_llm_config
def _resolve_project_id(preferred_project_id: Optional[int]) -> Optional[int]:
if preferred_project_id is not None:
return preferred_project_id
from app.context import current_session_id
session_id = (current_session_id.get() or "").strip()
if not session_id:
return None
alias_meta = session_alias_store.get_alias_meta(session_id)
if not alias_meta:
return None
project_id = alias_meta.get("project_id")
return project_id if isinstance(project_id, int) else None
class ListSubagentsTool(Tool):
    """Tool that lists the subagents defined for the current project.

    The project is either fixed at construction time or resolved per call via
    ``_resolve_project_id``.
    """

    def __init__(self, project_id: Optional[int] = None):
        """Remember an optional fixed project id for later calls."""
        super().__init__()
        self.project_id = project_id

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "list_subagents"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return "List all available subagents in the current project, including their names and descriptions."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments (the tool takes none)."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Return the project's subagents as an indented JSON array.

        Each entry has ``id``, ``name`` and ``description``. A plain message
        is returned instead when no project can be resolved or the project
        has no subagents.
        """
        project = _resolve_project_id(self.project_id)
        if project is None:
            return "Error: No project context available to list subagents."

        with SessionLocal() as db:
            rows = db.query(Subagent).filter(Subagent.project_id == project).all()
            if not rows:
                return "No subagents found in the current project."
            listing = [
                {"id": row.id, "name": row.name, "description": row.description}
                for row in rows
            ]

        return json.dumps(listing, ensure_ascii=False, indent=2)
class InvokeSubagentTool(Tool):
    """
    Tool to invoke a specific subagent to perform a task.

    The subagent is looked up by name within the resolved project, its
    instructions are prepended to the task, and the message is processed by
    ``nanobot_service`` in a session derived from the parent session id.
    """

    def __init__(self, project_id: Optional[int] = None):
        """Remember an optional fixed project id for later calls."""
        super().__init__()
        self.project_id = project_id

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "invoke_subagent"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return (
            "Invoke a subagent by name to perform a specific task. "
            "You should first use list_subagents to find the correct subagent name."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments."""
        return {
            "type": "object",
            "properties": {
                "subagent_name": {
                    "type": "string",
                    "description": "The name of the subagent to invoke.",
                },
                "task": {
                    "type": "string",
                    "description": "The specific task or query to send to the subagent.",
                }
            },
            "required": ["subagent_name", "task"],
        }

    @staticmethod
    def _resolve_model_id(raw_model: str) -> Any:
        """Map a subagent's ``model`` value to the id of an LLM config.

        Lookup order: exact match on config ``id``; case-insensitive match on
        config ``model`` or ``name``; the active LLM config. Returns None when
        no config with a truthy ``id`` is found.
        """
        llm_configs = get_llm_configs()
        target = None
        if raw_model:
            target = next((item for item in llm_configs if item.get("id") == raw_model), None)
            if target is None:
                normalized = raw_model.lower()
                target = next(
                    (
                        item for item in llm_configs
                        if (
                            str(item.get("model") or "").strip().lower() == normalized
                            or str(item.get("name") or "").strip().lower() == normalized
                        )
                    ),
                    None,
                )
        if target is None:
            target = get_active_llm_config()
        if target and target.get("id"):
            return target.get("id")
        return None

    async def execute(self, **kwargs: Any) -> str:
        """Run ``task`` through the named subagent and return its response.

        Returns:
            The subagent's response prefixed with a completion line, or an
            error message when the project, the arguments or the subagent are
            missing, or when processing the message raises.
        """
        subagent_name = kwargs.get("subagent_name")
        task = kwargs.get("task")

        resolved_project_id = _resolve_project_id(self.project_id)
        if resolved_project_id is None:
            return "Error: No project context available to invoke subagent."

        if not subagent_name or not task:
            return "Error: subagent_name and task are required."

        # Copy the needed columns into plain locals and close the session
        # before the (potentially long) subagent run, so no DB session is held
        # open across the await and no ORM object is touched after close.
        with SessionLocal() as db:
            subagent = db.query(Subagent).filter(
                Subagent.project_id == resolved_project_id,
                Subagent.name == subagent_name
            ).first()

            if not subagent:
                return f"Error: Subagent '{subagent_name}' not found."

            subagent_id = subagent.id
            found_name = subagent.name
            instructions = subagent.instructions or "You are a helpful assistant."
            raw_model = (getattr(subagent, "model", None) or "").strip()

        # Construct the message with subagent instructions
        message = f"[System: You are acting as subagent '{found_name}'. Instructions: {instructions}]\n\nTask: {task}"

        resolved_model_id = self._resolve_model_id(raw_model)

        try:
            from app.context import current_session_id

            parent_session_id = current_session_id.get() or "default"
            subagent_session_id = f"{parent_session_id}:subagent:{subagent_id}"

            response = await nanobot_service.process_message(
                message=message,
                session_id=subagent_session_id,
                project_id=resolved_project_id,
                model_id=resolved_model_id,
            )
            return f"Subagent '{found_name}' completed the task.\nResult:\n{response}"
        except Exception as e:
            return f"Error invoking subagent '{found_name}': {e}"
+103
View File
@@ -0,0 +1,103 @@
import logging
import re
from typing import Any
from nanobot.agent.tools.base import Tool
from app.agent.chart import generate_chart
from app.context import current_data, current_viz_data, current_progress_callback
from fastapi.encoders import jsonable_encoder
logger = logging.getLogger(__name__)
def _normalize_query(value: str) -> str:
return re.sub(r"\s+", "", (value or "")).lower()
class VisualizationTool(Tool):
    """
    Tool for generating a visualization (chart) from existing data.

    The data is read from the ``current_data`` context variable; the chart
    payload is published through ``current_viz_data``.
    """

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "visualization"

    @property
    def description(self) -> str:
        """Natural-language description shown to the model."""
        return (
            "Generate a chart or visualization based on the most recently queried data. "
            "Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. "
            "Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first. "
            "Do not call this tool right after nl2sql(generate_chart=True) for the same request."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON schema of the tool arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The user's request describing how they want the data visualized (e.g., 'plot sales by month as a bar chart').",
                }
            },
            "required": ["query"],
        }

    @staticmethod
    def _same_dataset(stored_result: Any, data: Any) -> bool:
        """Tell whether ``stored_result`` describes the same rows as ``data``.

        The payload stored in ``current_viz_data`` is passed through
        ``jsonable_encoder`` before it is saved, so the live data is also
        compared in its encoded form; a raw comparison alone does not match
        whenever encoding changes a value.
        """
        if stored_result == data:
            return True
        return stored_result == jsonable_encoder(data)

    async def execute(self, **kwargs: Any) -> str:
        """Generate a chart for the current data and publish it.

        Returns:
            A success message with chart type and reasoning, a notice that an
            identical chart already exists, or an error message. Exceptions
            are logged and reported as text instead of being raised.
        """
        query = kwargs.get("query", "")
        data = current_data.get()
        on_progress = current_progress_callback.get()

        if not data:
            return "Error: No data available to visualize. Please query the data first using the nl2sql tool."

        try:
            normalized_query = _normalize_query(query)

            viz_state = current_viz_data.get()
            previous = viz_state if isinstance(viz_state, dict) else {}
            previous_query = previous.get("chart_query_normalized")
            if (
                previous.get("chart")
                and previous_query
                and previous_query == normalized_query
                and self._same_dataset(previous.get("result"), data)
            ):
                return "Chart already exists for this query and dataset. Reusing existing Vega visualization."

            if on_progress:
                await on_progress("正在分析数据特征并生成可视化方案...")

            chart_response = await generate_chart(data, query)

            if not chart_response.can_visualize:
                return f"Could not generate a chart: {chart_response.reasoning}"

            # Build the viz payload (similar to NL2SQL, but without the SQL part)
            # We reuse the previous viz_data if it exists (to keep SQL), or create a new one
            viz_state = current_viz_data.get()
            # A non-dict state carries no SQL to keep; it is simply replaced.
            previous = viz_state if isinstance(viz_state, dict) else {}
            viz_payload = {
                "sql": previous.get("sql", ""),
                "result": data,
                "chart": chart_response.model_dump(by_alias=True, exclude_none=True),
                "chart_query": query,
                "chart_query_normalized": normalized_query,
                "chart_generated_by": "visualization",
                "error": None,
            }
            encoded_viz = jsonable_encoder(viz_payload)

            if isinstance(viz_state, dict):
                # Refill the existing dict in place rather than replacing it.
                viz_state.clear()
                viz_state.update(encoded_viz)
                current_viz_data.set(viz_state)
            else:
                current_viz_data.set(encoded_viz)

            return f"Successfully generated a {chart_response.chart_type} chart.\nReasoning: {chart_response.reasoning}"
        except Exception as e:
            # logger.exception records the traceback, like exc_info=True did.
            logger.exception("Visualization Tool error: %s", e)
            return f"An error occurred while generating the visualization: {str(e)}"