refactor: convert to nl2sql skills
This commit is contained in:
+44
-11
@@ -191,14 +191,24 @@ def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[
|
|||||||
with _cache_lock:
|
with _cache_lock:
|
||||||
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||||
|
|
||||||
def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
async def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||||
cache_key = _build_schema_cache_key(source, connector)
|
cache_key = _build_schema_cache_key(source, connector)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
with _cache_lock:
|
with _cache_lock:
|
||||||
cached = _connection_cache.get(cache_key)
|
cached = _connection_cache.get(cache_key)
|
||||||
if cached and now < cached["expires_at"]:
|
if cached and now < cached["expires_at"]:
|
||||||
return bool(cached["ok"])
|
return bool(cached["ok"])
|
||||||
ok = connector.test_connection()
|
|
||||||
|
# Run synchronous test_connection in a separate thread to avoid blocking event loop
|
||||||
|
try:
|
||||||
|
ok = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(connector.test_connection),
|
||||||
|
timeout=10.0
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Connection test failed or timed out: {e}")
|
||||||
|
ok = False
|
||||||
|
|
||||||
with _cache_lock:
|
with _cache_lock:
|
||||||
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||||
return ok
|
return ok
|
||||||
@@ -224,7 +234,7 @@ async def process_nl2sql(
|
|||||||
elif request.source == "upload":
|
elif request.source == "upload":
|
||||||
try:
|
try:
|
||||||
upload_started = time.perf_counter()
|
upload_started = time.perf_counter()
|
||||||
upload_payload = _get_upload_payload(request.file_url)
|
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||||
upload_df = upload_payload["df"]
|
upload_df = upload_payload["df"]
|
||||||
schema = upload_payload["schema"]
|
schema = upload_payload["schema"]
|
||||||
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
|
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
|
||||||
@@ -234,14 +244,21 @@ async def process_nl2sql(
|
|||||||
try:
|
try:
|
||||||
ds_started = time.perf_counter()
|
ds_started = time.perf_counter()
|
||||||
ds_id = int(request.source.split(":")[1])
|
ds_id = int(request.source.split(":")[1])
|
||||||
|
|
||||||
|
def _get_ds_connector():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||||
if not ds:
|
if not ds:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
return None
|
||||||
connector = get_connector(ds)
|
return get_connector(ds)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
connector = await asyncio.to_thread(_get_ds_connector)
|
||||||
|
if not connector:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||||
|
|
||||||
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
|
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||||
@@ -258,20 +275,35 @@ async def process_nl2sql(
|
|||||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||||
else:
|
else:
|
||||||
conn_started = time.perf_counter()
|
conn_started = time.perf_counter()
|
||||||
if not _check_connection_with_cache(request.source, connector):
|
if not await _check_connection_with_cache(request.source, connector):
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||||
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
|
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
|
||||||
schema_started = time.perf_counter()
|
schema_started = time.perf_counter()
|
||||||
schema = connector.get_schema()
|
try:
|
||||||
|
schema = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(connector.get_schema),
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
||||||
|
|
||||||
_set_cached_schema(request.source, connector, schema)
|
_set_cached_schema(request.source, connector, schema)
|
||||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||||
|
|
||||||
if connector and not schema:
|
if connector and not schema:
|
||||||
retry_started = time.perf_counter()
|
retry_started = time.perf_counter()
|
||||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||||
if not _check_connection_with_cache(request.source, connector):
|
if not await _check_connection_with_cache(request.source, connector):
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||||
schema = connector.get_schema()
|
|
||||||
|
try:
|
||||||
|
schema = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(connector.get_schema),
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
||||||
|
|
||||||
_set_cached_schema(request.source, connector, schema)
|
_set_cached_schema(request.source, connector, schema)
|
||||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||||
|
|
||||||
@@ -282,7 +314,7 @@ async def process_nl2sql(
|
|||||||
if request.source.startswith("ds:"):
|
if request.source.startswith("ds:"):
|
||||||
try:
|
try:
|
||||||
ds_id = int(request.source.split(":")[1])
|
ds_id = int(request.source.split(":")[1])
|
||||||
mdl = MDLService.get_mdl(ds_id)
|
mdl = await asyncio.to_thread(MDLService.get_mdl, ds_id)
|
||||||
if mdl:
|
if mdl:
|
||||||
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
|
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
|
||||||
|
|
||||||
@@ -392,7 +424,8 @@ Language: Chinese (Simplified)
|
|||||||
await emit_progress("正在执行 SQL 查询")
|
await emit_progress("正在执行 SQL 查询")
|
||||||
if request.source == "upload":
|
if request.source == "upload":
|
||||||
if upload_df is None:
|
if upload_df is None:
|
||||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||||
|
upload_df = upload_payload["df"]
|
||||||
timeout_stage = "sql_execution"
|
timeout_stage = "sql_execution"
|
||||||
formatted_results = await asyncio.wait_for(
|
formatted_results = await asyncio.wait_for(
|
||||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Callable, Awaitable, Dict, Optional
|
||||||
|
|
||||||
|
# Session ID of the request currently being processed.
current_session_id: ContextVar[str] = ContextVar(
    "current_session_id",
    default="",
)

# Async callback used to push progress messages to the frontend while a tool
# is executing; None when no callback has been installed.
current_progress_callback: ContextVar[
    Optional[Callable[[str], Awaitable[None]]]
] = ContextVar(
    "current_progress_callback",
    default=None,
)

# Dictionary holding the visualization result produced by tools. The stream
# handler picks it up and forwards it to the frontend.
current_viz_data: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
    "current_viz_data",
    default=None,
)

# Rows returned by the most recent query, kept so the Visualization Tool can
# chart them.
current_data: ContextVar[Optional[list]] = ContextVar(
    "current_data",
    default=None,
)

# Data source requested by the user or bound to the session.
current_data_source: ContextVar[str] = ContextVar(
    "current_data_source",
    default="postgres",
)

# URL of a file attached to the request, if any.
current_file_url: ContextVar[Optional[str]] = ContextVar(
    "current_file_url",
    default=None,
)
|
||||||
@@ -85,6 +85,14 @@ class NanobotIntegration:
|
|||||||
channels_config=self.config.channels,
|
channels_config=self.config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._register_custom_tools(self.agent)
|
||||||
|
|
||||||
|
def _register_custom_tools(self, agent: AgentLoop):
    """Register the NL2SQL and visualization tools on ``agent``.

    Args:
        agent: Agent loop whose ``tools`` registry receives the two tools.
    """
    # NOTE(review): the tool modules are imported inside the method rather
    # than at module level -- presumably to avoid an import cycle; confirm.
    from app.tools.nl2sql import NL2SQLTool
    from app.tools.visualization import VisualizationTool

    # Instantiate and register one tool at a time, NL2SQL first.
    for tool_cls in (NL2SQLTool, VisualizationTool):
        agent.tools.register(tool_cls())
|
||||||
|
|
||||||
def _make_provider(self, config: Config):
|
def _make_provider(self, config: Config):
|
||||||
# Logic adapted from nanobot/cli/commands.py
|
# Logic adapted from nanobot/cli/commands.py
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
@@ -195,6 +203,7 @@ class NanobotIntegration:
|
|||||||
provider_name=target_config.get("provider"),
|
provider_name=target_config.get("provider"),
|
||||||
)
|
)
|
||||||
agent = self._build_agent_for_provider(provider)
|
agent = self._build_agent_for_provider(provider)
|
||||||
|
self._register_custom_tools(agent)
|
||||||
self._model_agent_cache[model_id] = agent
|
self._model_agent_cache[model_id] = agent
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,117 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||||
|
from app.context import current_progress_callback, current_viz_data, current_data_source, current_file_url, current_data
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
    """Convert an NL2SQL response into a JSON-encodable visualization payload.

    Args:
        nl2sql_result: Response whose ``sql``, ``result``, ``chart`` and
            ``error`` fields are copied into the payload.

    Returns:
        The payload after ``jsonable_encoder``, with keys ``sql``, ``result``,
        ``chart`` (the chart's ``model_dump()`` or ``None``) and ``error``.
    """
    chart = nl2sql_result.chart
    chart_dump = None
    if chart:
        chart_dump = chart.model_dump()

    return jsonable_encoder(
        {
            "sql": nl2sql_result.sql,
            "result": nl2sql_result.result,
            "chart": chart_dump,
            "error": nl2sql_result.error,
        }
    )
|
||||||
|
|
||||||
|
class NL2SQLTool(Tool):
    """Tool that translates a natural-language question into SQL and runs it.

    The work is delegated to ``process_nl2sql``. Request-scoped inputs (data
    source, file URL, progress callback) are read from the context variables
    in ``app.context``; the result rows and the visualization payload are
    published back through ``current_data`` and ``current_viz_data``.
    """

    @property
    def name(self) -> str:
        """Name of the tool."""
        return "nl2sql"

    @property
    def description(self) -> str:
        """Usage guidance for the tool, written for the agent."""
        return (
            "Query the connected database or data source using natural language. "
            "Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. "
            "Set generate_chart=True if the user also wants to visualize or plot the data."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-Schema-shaped description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The natural language query describing what data to fetch or analyze.",
                },
                "generate_chart": {
                    "type": "boolean",
                    "description": "Whether to automatically generate a visualization chart for the result. Default is False.",
                },
            },
            "required": ["query"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Run the query and return a plain-text summary for the agent.

        Args:
            **kwargs: ``query`` (str, defaults to ``""``) and
                ``generate_chart`` (bool, defaults to ``False``).

        Returns:
            A summary containing the SQL, the row count, the chart outcome
            (when a chart was requested) and up to five sample rows; or an
            error message when the query fails or an exception is raised
            while processing it.
        """
        query = kwargs.get("query", "")
        generate_chart = kwargs.get("generate_chart", False)

        # Request-scoped inputs come from context variables.
        request = NL2SQLRequest(
            query=query,
            source=current_data_source.get(),
            file_url=current_file_url.get(),
            generate_chart=generate_chart,
        )

        try:
            result = await process_nl2sql(
                request,
                on_progress=current_progress_callback.get(),
            )

            if result.error:
                return f"Error executing query: {result.error}"

            # Always publish the rows of the latest successful query, even
            # when there are none. Otherwise a zero-row result would leave the
            # previous query's rows in place and the visualization tool would
            # chart stale data.
            rows = result.result or []
            current_data.set(rows)

            # Publish the visualization payload. An existing dict is updated
            # in place rather than replaced, so that code holding a reference
            # to the same dict object sees the new content.
            # NOTE(review): presumably the chat handler that seeded the dict
            # relies on this -- confirm.
            viz_payload = _build_sql_chart_viz(result)
            existing_viz = current_viz_data.get()
            if isinstance(existing_viz, dict):
                existing_viz.clear()
                existing_viz.update(viz_payload)
                current_viz_data.set(existing_viz)
            else:
                current_viz_data.set(viz_payload)

            # Build the summary the agent reads.
            summary_parts = [
                "Successfully executed SQL query.",
                f"SQL: {result.sql}",
                f"Rows returned: {len(rows)}",
            ]

            if generate_chart:
                chart = result.chart
                if chart and chart.can_visualize:
                    summary_parts.append("Chart was successfully generated.")
                    if chart.reasoning:
                        summary_parts.append(f"Chart Reasoning: {chart.reasoning}")
                else:
                    summary_parts.append("Requested a chart, but the data was not suitable for visualization.")

            summary_parts.append("\nSample data (first 5 rows):")
            summary_parts.append(json.dumps(jsonable_encoder(rows[:5]), ensure_ascii=False))

            return "\n".join(summary_parts)

        except Exception as e:
            # Tool boundary: report the failure to the agent as text instead
            # of letting the exception propagate; the traceback goes to the log.
            logger.exception("NL2SQL Tool error: %s", e)
            return f"An unexpected error occurred during NL2SQL execution: {str(e)}"
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from app.agent.chart import generate_chart
|
||||||
|
from app.context import current_data, current_viz_data, current_progress_callback
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class VisualizationTool(Tool):
    """Tool that generates a chart from the rows stored in ``current_data``.

    The chart is produced by ``generate_chart`` and published through
    ``current_viz_data``.
    """

    @property
    def name(self) -> str:
        """Name of the tool."""
        return "visualization"

    @property
    def description(self) -> str:
        """Usage guidance for the tool, written for the agent."""
        return (
            "Generate a chart or visualization based on the most recently queried data. "
            "Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. "
            "Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-Schema-shaped description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The user's request describing how they want the data visualized (e.g., 'plot sales by month as a bar chart').",
                },
            },
            "required": ["query"],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Generate a chart for the stored rows and return a text summary.

        Args:
            **kwargs: ``query`` (str, defaults to ``""``) describing the
                desired visualization.

        Returns:
            A success message naming the chart type and reasoning, a message
            explaining why no chart could be produced, or an error message
            when no data is available or an exception is raised.
        """
        query = kwargs.get("query", "")
        data = current_data.get()
        on_progress = current_progress_callback.get()

        if not data:
            return "Error: No data available to visualize. Please query the data first using the nl2sql tool."

        try:
            if on_progress:
                await on_progress("正在分析数据特征并生成可视化方案...")

            chart_response = await generate_chart(data, query)

            if not chart_response.can_visualize:
                return f"Could not generate a chart: {chart_response.reasoning}"

            # Keep the exact object stored in the context variable. Using
            # `get() or {}` would replace an empty dict with a new one, and
            # the in-place update below would then not reach code holding the
            # original dict.
            shared_viz = current_viz_data.get()
            has_shared_dict = isinstance(shared_viz, dict)

            # Carry over the SQL of the previous payload, if there is one.
            previous_sql = shared_viz.get("sql", "") if has_shared_dict else ""

            encoded_viz = jsonable_encoder(
                {
                    "sql": previous_sql,
                    "result": data,
                    "chart": chart_response.model_dump(),
                    "error": None,
                }
            )

            if has_shared_dict:
                # Update in place so existing references see the new payload.
                shared_viz.clear()
                shared_viz.update(encoded_viz)
                current_viz_data.set(shared_viz)
            else:
                current_viz_data.set(encoded_viz)

            return f"Successfully generated a {chart_response.chart_type} chart.\nReasoning: {chart_response.reasoning}"

        except Exception as e:
            # Tool boundary: report the failure to the agent as text instead
            # of letting the exception propagate; the traceback goes to the log.
            logger.exception("Visualization Tool error: %s", e)
            return f"An error occurred while generating the visualization: {str(e)}"
|
||||||
+72
-162
@@ -14,7 +14,7 @@ from app.connectors.postgres import postgres_connector
|
|||||||
from app.connectors.clickhouse import clickhouse_connector
|
from app.connectors.clickhouse import clickhouse_connector
|
||||||
from app.core.nanobot import nanobot_service
|
from app.core.nanobot import nanobot_service
|
||||||
from app.core.session_alias_store import session_alias_store
|
from app.core.session_alias_store import session_alias_store
|
||||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
from app.context import current_session_id, current_progress_callback, current_viz_data, current_data_source, current_file_url
|
||||||
from app.database import engine, Base
|
from app.database import engine, Base
|
||||||
# Import all models to ensure they are registered
|
# Import all models to ensure they are registered
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
@@ -44,21 +44,6 @@ app.include_router(semantic.router, prefix="/api/v1")
|
|||||||
|
|
||||||
STREAM_DELTA_CHUNK_SIZE = 48
|
STREAM_DELTA_CHUNK_SIZE = 48
|
||||||
|
|
||||||
SQL_INTENT_DENY_PATTERNS = [
|
|
||||||
re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE),
|
|
||||||
re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
|
|
||||||
]
|
|
||||||
|
|
||||||
SQL_INTENT_POSITIVE_PATTERNS = [
|
|
||||||
re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE),
|
|
||||||
re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE),
|
|
||||||
re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE),
|
|
||||||
re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE),
|
|
||||||
]
|
|
||||||
|
|
||||||
VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE)
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
# Initialize nanobot in background
|
# Initialize nanobot in background
|
||||||
@@ -110,56 +95,15 @@ def _session_context_for_routing(session_id: str) -> Dict[str, Any]:
|
|||||||
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
||||||
return session.metadata or {}
|
return session.metadata or {}
|
||||||
|
|
||||||
|
def _resolve_effective_source(request: ChatRequest) -> str:
    """Pick the data source to use for ``request``.

    A source bound to the session (``selected_data_source`` in the session
    metadata) wins when it is ``"upload"`` or starts with ``"ds:"``;
    otherwise the source named in the request is used. Both values are
    stripped and lower-cased before comparison.

    Args:
        request: Chat request providing ``session_id`` and ``source``.

    Returns:
        The normalized effective source; an empty string when neither the
        session nor the request supplies one.
    """
    metadata = _session_context_for_routing(request.session_id)
    bound_source = (metadata.get("selected_data_source") or "").strip().lower()
    requested_source = (request.source or "").strip().lower()

    session_is_bound = bound_source == "upload" or bound_source.startswith("ds:")
    return bound_source if session_is_bound else requested_source
|
||||||
if request.route_mode == "sql":
|
|
||||||
return True, "route_mode=sql", effective_source
|
|
||||||
if request.route_mode == "chat":
|
|
||||||
return False, "route_mode=chat", effective_source
|
|
||||||
if request.prefer_sql_chart:
|
|
||||||
return True, "prefer_sql_chart=true", effective_source
|
|
||||||
|
|
||||||
has_sql_intent = _looks_like_sql_intent(request.message)
|
|
||||||
if not has_sql_intent:
|
|
||||||
return False, "message_non_sql_intent", effective_source
|
|
||||||
|
|
||||||
# If we have intent, check if we have a valid source context
|
|
||||||
if effective_source.startswith("ds:") or effective_source == "upload":
|
|
||||||
return True, "message_sql_intent_with_datasource", effective_source
|
|
||||||
|
|
||||||
# Even if just "postgres" (default), if intent is strong, we might allow it?
|
|
||||||
# But usually we want a bound source.
|
|
||||||
# Let's keep existing logic: if intent is strong, return True.
|
|
||||||
# But effectively, if source is "postgres", it might fail later if no tables are there.
|
|
||||||
return True, "message_sql_intent", effective_source
|
|
||||||
|
|
||||||
|
|
||||||
class SessionAliasUpdateRequest(BaseModel):
|
class SessionAliasUpdateRequest(BaseModel):
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
@@ -175,71 +119,42 @@ class SessionFileContextUpdateRequest(BaseModel):
|
|||||||
active_data_file: Optional[Dict[str, Any]] = None
|
active_data_file: Optional[Dict[str, Any]] = None
|
||||||
selected_data_source: Optional[str] = None
|
selected_data_source: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
|
|
||||||
chart = nl2sql_result.chart
|
|
||||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
|
||||||
text = (
|
|
||||||
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
|
||||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
|
||||||
)
|
|
||||||
if chart and chart.reasoning:
|
|
||||||
return f"{text}\n\n可视化说明:{chart.reasoning}"
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
|
||||||
chart = nl2sql_result.chart
|
|
||||||
payload = {
|
|
||||||
"sql": nl2sql_result.sql,
|
|
||||||
"result": nl2sql_result.result,
|
|
||||||
"chart": chart.model_dump() if chart else None,
|
|
||||||
"error": nl2sql_result.error,
|
|
||||||
}
|
|
||||||
return jsonable_encoder(payload)
|
|
||||||
|
|
||||||
|
|
||||||
def _persist_session_turn(
|
|
||||||
session_id: str,
|
|
||||||
user_message: str,
|
|
||||||
assistant_message: str,
|
|
||||||
assistant_extra: Optional[dict] = None,
|
|
||||||
) -> None:
|
|
||||||
if not nanobot_service.agent:
|
|
||||||
return
|
|
||||||
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
|
||||||
session.add_message("user", user_message)
|
|
||||||
session.add_message("assistant", assistant_message, **(assistant_extra or {}))
|
|
||||||
nanobot_service.agent.sessions.save(session)
|
|
||||||
|
|
||||||
@app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest):
    """Handle a non-streaming chat request by handing it to the agent.

    Request-scoped state (data source, file URL, session ID, an empty
    visualization dict) is published through context variables before the
    agent runs. Any visualization payload found afterwards is attached to
    the last assistant message of the session and returned to the caller.

    Args:
        request: Incoming chat request.

    Returns:
        A dict with ``response``, ``viz`` and ``routing`` keys.

    Raises:
        HTTPException: Status 500 carrying the message of any exception
            raised while handling the request.
    """
    try:
        # Publish request-scoped state for the tools to read.
        current_data_source.set(_resolve_effective_source(request))
        current_file_url.set(request.file_url)
        current_session_id.set(request.session_id)
        current_viz_data.set({})

        # Prepend an instruction when the caller forced a route.
        if request.route_mode == "sql" or request.prefer_sql_chart:
            directive = "[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n"
        elif request.route_mode == "chat":
            directive = "[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n"
        else:
            directive = None

        prompt = request.message
        if directive is not None:
            prompt = f"{directive}{prompt}"

        response = await nanobot_service.process_message(
            prompt,
            session_id=request.session_id,
            skill_ids=request.skill_ids,
            model_id=request.model_id,
        )

        viz_payload = current_viz_data.get()
        agent = nanobot_service.agent
        if viz_payload and agent:
            # Attach the visualization to the last assistant message.
            session = agent.sessions.get_or_create(request.session_id)
            history = session.messages
            if history and history[-1].get("role") == "assistant":
                history[-1]["viz"] = viz_payload
                # NOTE(review): nesting reconstructed from a flattened diff
                # view -- confirm save() belongs inside this branch.
                agent.sessions.save(session)

        return {
            "response": response,
            "viz": viz_payload,
            "routing": {"selected": "agent", "reason": "auto_routed_by_agent"},
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -248,69 +163,49 @@ async def nanobot_chat_stream(request: ChatRequest):
|
|||||||
async def event_generator():
|
async def event_generator():
|
||||||
current_task = None
|
current_task = None
|
||||||
try:
|
try:
|
||||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
resolved_source = _resolve_effective_source(request)
|
||||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
current_data_source.set(resolved_source)
|
||||||
if use_nl2sql:
|
current_file_url.set(request.file_url)
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n"
|
current_session_id.set(request.session_id)
|
||||||
sql_progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
current_viz_data.set({})
|
||||||
|
|
||||||
async def _on_sql_progress(content: str) -> None:
|
yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n"
|
||||||
if content:
|
|
||||||
await sql_progress_queue.put(content)
|
|
||||||
|
|
||||||
current_task = asyncio.create_task(
|
|
||||||
process_nl2sql(
|
|
||||||
NL2SQLRequest(
|
|
||||||
query=request.message,
|
|
||||||
source=resolved_source,
|
|
||||||
file_url=request.file_url,
|
|
||||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
|
||||||
),
|
|
||||||
on_progress=_on_sql_progress,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
while True:
|
|
||||||
if current_task.done() and sql_progress_queue.empty():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2)
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
nl2sql_result = await current_task
|
|
||||||
if nl2sql_result.error:
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n"
|
|
||||||
else:
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
|
|
||||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
|
||||||
viz_payload = {
|
|
||||||
"type": "viz",
|
|
||||||
**persisted_viz_payload,
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
|
||||||
text = _build_sql_chart_text(nl2sql_result)
|
|
||||||
_persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload})
|
|
||||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
|
||||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
|
||||||
return
|
|
||||||
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||||
|
|
||||||
async def _on_progress(content: str, **_: Any) -> None:
|
async def _on_progress(content: str, **kwargs: Any) -> None:
|
||||||
if content:
|
if content:
|
||||||
await progress_queue.put(content)
|
await progress_queue.put(content)
|
||||||
|
|
||||||
|
current_progress_callback.set(_on_progress)
|
||||||
|
|
||||||
|
# Inject instructions if explicitly routed
|
||||||
|
message = request.message
|
||||||
|
if request.route_mode == "sql" or request.prefer_sql_chart:
|
||||||
|
message = f"[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n{message}"
|
||||||
|
elif request.route_mode == "chat":
|
||||||
|
message = f"[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n{message}"
|
||||||
|
|
||||||
current_task = asyncio.create_task(
|
current_task = asyncio.create_task(
|
||||||
nanobot_service.process_message(
|
nanobot_service.process_message(
|
||||||
request.message,
|
message,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
skill_ids=request.skill_ids,
|
skill_ids=request.skill_ids,
|
||||||
model_id=request.model_id,
|
model_id=request.model_id,
|
||||||
on_progress=_on_progress,
|
on_progress=_on_progress,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n"
|
|
||||||
text = ""
|
text = ""
|
||||||
|
viz_sent = False
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
# Check for viz payload during processing
|
||||||
|
viz_payload = current_viz_data.get()
|
||||||
|
if viz_payload and not viz_sent:
|
||||||
|
yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n"
|
||||||
|
viz_sent = True
|
||||||
|
|
||||||
if current_task.done() and progress_queue.empty():
|
if current_task.done() and progress_queue.empty():
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
@@ -318,11 +213,26 @@ async def nanobot_chat_stream(request: ChatRequest):
|
|||||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response = await current_task
|
response = await current_task
|
||||||
text = response or ""
|
text = response or ""
|
||||||
|
|
||||||
|
# Check again for viz payload after task completes if not sent yet
|
||||||
|
viz_payload = current_viz_data.get()
|
||||||
|
if viz_payload and not viz_sent:
|
||||||
|
yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# Persist viz payload to session
|
||||||
|
if viz_payload and nanobot_service.agent:
|
||||||
|
session = nanobot_service.agent.sessions.get_or_create(request.session_id)
|
||||||
|
if session.messages and session.messages[-1].get("role") == "assistant":
|
||||||
|
session.messages[-1]["viz"] = viz_payload
|
||||||
|
nanobot_service.agent.sessions.save(session)
|
||||||
|
|
||||||
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
||||||
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
||||||
yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
Reference in New Issue
Block a user