reorg skill folder

This commit is contained in:
qixinbo
2026-03-19 12:27:31 +08:00
parent baec21c774
commit cca492cfdb
13 changed files with 232 additions and 81 deletions
+13 -1
View File
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
- For daily question, the time unit should be "yearmonthdate".
- Default time unit is "yearmonth".
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data.
- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
### GUIDELINES TO PLOT CHART ###
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
content = content.strip()
result = json.loads(content)
# Post-process to fix common LLM hallucinations (translating field names)
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
encoding = result["chart_spec"].get("encoding", {})
for channel, enc_def in encoding.items():
if isinstance(enc_def, dict) and "field" in enc_def:
field = enc_def["field"]
# If field is not in columns, try to find a match or let it be (Vega will render empty)
# But if we can detect it was translated, we might not be able to fix it perfectly.
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
# but it's safer to just rely on the stricter prompt.
return ChartGenerationResponse(**result)
except Exception as e:
+77 -42
View File
@@ -4,7 +4,6 @@ import os
import json
import time
import threading
import re
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
NL2SQL_CHART_TIMEOUT_SECONDS = 45
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
conn.close()
return result_df.to_dict(orient="records")
def _to_number(value: Any) -> Optional[float]:
if isinstance(value, bool):
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
text = value.strip().replace(",", "")
if not text:
return None
try:
return float(text)
except ValueError:
return None
return None
# _build_fallback_chart removed as per user request to not hardcode fallbacks
def _build_schema_cache_key(source: str, connector: Any) -> str:
# If source is ds:ID, that's already a good key
if source.startswith("ds:"):
@@ -270,7 +288,7 @@ async def process_nl2sql(
if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
if cached_schema is not None:
schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else:
@@ -289,23 +307,6 @@ async def process_nl2sql(
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not await _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -378,42 +379,66 @@ Language: Chinese (Simplified)
try:
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
if response.finish_reason == "error":
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
response = None
last_error = ""
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
if response.finish_reason == "error":
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip()
if not content:
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
# Clean up code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
try:
result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip()
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
if not chart_response or not chart_response.chart_spec:
# Do not fallback automatically if the LLM explicitly decided not to or failed.
# Just pass whatever it returned (or lack thereof)
pass
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError:
if timeout_stage == "chart_generation":
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
fallback_chart = ChartGenerationResponse(
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
chart_type="",
can_visualize=False,
chart_spec=None,
)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
+38 -3
View File
@@ -3,6 +3,7 @@ import os
import shutil
import zipfile
import tarfile
import re
import yaml
from typing import List, Optional, Dict, Any
from datetime import datetime
@@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]):
with open(DATA_FILE, "w") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
def _safe_skill_dir_name(value: str) -> str:
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
return safe or "skill"
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
os.makedirs(skill_dir, exist_ok=True)
skill_md_path = os.path.join(skill_dir, "SKILL.md")
final_description = description or "No description provided"
body = content or ""
markdown = (
f"---\n"
f"name: {skill_name}\n"
f"description: {final_description}\n"
f"---\n\n"
f"{body}\n"
)
with open(skill_md_path, "w", encoding="utf-8") as f:
f.write(markdown)
return skill_md_path
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
data = _load_data()
if project_id is not None:
@@ -205,9 +226,7 @@ async def upload_skill(
skill_name = os.path.splitext(filename)[0]
# Create a safe directory name for the skill
import re
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower()
if not safe_name: safe_name = "skill"
safe_name = _safe_skill_dir_name(skill_name)
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
@@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate):
new_skill_dict = skill.dict()
if not new_skill_dict.get("installation_time"):
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d")
if not new_skill_dict.get("file_path"):
skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"]))
_write_skill_markdown(
skill_dir=skill_dir,
skill_name=new_skill_dict["name"],
description=new_skill_dict.get("description"),
content=new_skill_dict.get("content", ""),
)
new_skill_dict["file_path"] = skill_dir
data.append(new_skill_dict)
_save_data(data)
@@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] =
updated_item = item.copy()
update_data = skill.dict(exclude_unset=True)
updated_item.update(update_data)
if updated_item.get("file_path"):
_write_skill_markdown(
skill_dir=updated_item["file_path"],
skill_name=updated_item.get("name") or item.get("name") or "skill",
description=updated_item.get("description"),
content=updated_item.get("content", ""),
)
data[i] = updated_item
_save_data(data)
return Skill(**updated_item)
+16
View File
@@ -1,6 +1,7 @@
import asyncio
import sys
import os
import shutil
from pathlib import Path
from typing import List, Callable, Awaitable, Any, Dict
@@ -47,6 +48,7 @@ class NanobotIntegration:
# Set workspace path to backend/data/workspace
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
workspace_path.mkdir(parents=True, exist_ok=True)
self._sync_builtin_skills_to_workspace(workspace_path)
# Override config workspace path via environment variable (since config is loaded from env)
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
@@ -87,6 +89,20 @@ class NanobotIntegration:
self._register_custom_tools(self.agent)
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
workspace_skills_root = workspace_path / "skills"
workspace_skills_root.mkdir(parents=True, exist_ok=True)
for skill_name in ("nl2sql", "visualization"):
source_dir = builtin_root / skill_name
source_skill_file = source_dir / "SKILL.md"
if not source_skill_file.exists():
continue
target_dir = workspace_skills_root / skill_name
target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
def _register_custom_tools(self, agent: AgentLoop):
from app.tools.nl2sql import NL2SQLTool
from app.tools.visualization import VisualizationTool
+2 -9
View File
@@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel):
chart_type: Literal[
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
] = Field(..., description="The type of chart generated, or empty string if none")
chart_spec: Optional[Union[
LineChartSchema,
MultiLineChartSchema,
BarChartSchema,
PieChartSchema,
GroupedBarChartSchema,
StackedBarChartSchema,
AreaChartSchema
]] = Field(None, description="The generated Vega-Lite chart specification")
# Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields
chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification")
can_visualize: bool = Field(..., description="Whether the data can be visualized")
@@ -0,0 +1,24 @@
---
description: Data Analysis and SQL Generation
metadata:
nanobot:
always: true
---
# NL2SQL Data Analysis Skill
You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language.
## When to use this skill
- When the user asks to query, analyze, aggregate, or fetch data from the database.
- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?".
## How to use this skill
- Call the `nl2sql` tool with the user's natural language query.
- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool.
- If the user ONLY asks to query data, set `generate_chart=False` (default).
## After using the tool
- The tool will return a summary of the executed query and a sample of the results.
- Use this information to provide a clear, concise, and helpful response to the user.
- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel.
@@ -0,0 +1,23 @@
---
description: Data Visualization and Chart Generation
metadata:
nanobot:
always: true
---
# Data Visualization Skill
You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data.
## When to use this skill
- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context.
- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions".
- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool.
## How to use this skill
- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart").
- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context.
## After using the tool
- The tool will return a success message and the reasoning for the chosen chart type.
- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful.
+2 -2
View File
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict
from typing import Any
from nanobot.agent.tools.base import Tool
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
@@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
payload = {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
"error": nl2sql_result.error,
}
return jsonable_encoder(payload)
+1 -1
View File
@@ -60,7 +60,7 @@ class VisualizationTool(Tool):
viz_payload = {
"sql": existing_viz.get("sql", ""),
"result": data,
"chart": chart_response.model_dump(),
"chart": chart_response.model_dump(by_alias=True, exclude_none=True),
"error": None,
}
encoded_viz = jsonable_encoder(viz_payload)