reorg skill folder

This commit is contained in:
qixinbo
2026-03-19 12:27:31 +08:00
parent baec21c774
commit cca492cfdb
13 changed files with 232 additions and 81 deletions
+13 -1
View File
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
- For daily question, the time unit should be "yearmonthdate".
- Default time unit is "yearmonth".
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data.
- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
### GUIDELINES TO PLOT CHART ###
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
content = content.strip()
result = json.loads(content)
# Post-process to fix common LLM hallucinations (translating field names)
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
encoding = result["chart_spec"].get("encoding", {})
for channel, enc_def in encoding.items():
if isinstance(enc_def, dict) and "field" in enc_def:
field = enc_def["field"]
# If field is not in columns, try to find a match or let it be (Vega will render empty)
# But if we can detect it was translated, we might not be able to fix it perfectly.
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
# but it's safer to just rely on the stricter prompt.
return ChartGenerationResponse(**result)
except Exception as e:
+61 -26
View File
@@ -4,7 +4,6 @@ import os
import json
import time
import threading
import re
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
NL2SQL_CHART_TIMEOUT_SECONDS = 45
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
conn.close()
return result_df.to_dict(orient="records")
def _to_number(value: Any) -> Optional[float]:
if isinstance(value, bool):
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
text = value.strip().replace(",", "")
if not text:
return None
try:
return float(text)
except ValueError:
return None
return None
# _build_fallback_chart removed as per user request to not hardcode fallbacks
def _build_schema_cache_key(source: str, connector: Any) -> str:
# If source is ds:ID, that's already a good key
if source.startswith("ds:"):
@@ -270,7 +288,7 @@ async def process_nl2sql(
if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
if cached_schema is not None:
schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else:
@@ -290,23 +308,6 @@ async def process_nl2sql(
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not await _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
# Try to load MDL context
@@ -378,18 +379,45 @@ Language: Chinese (Simplified)
try:
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = None
last_error = ""
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
if response.finish_reason == "error":
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip()
if not content:
@@ -406,14 +434,11 @@ Language: Chinese (Simplified)
try:
result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip()
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
if not chart_response or not chart_response.chart_spec:
# Do not fallback automatically if the LLM explicitly decided not to or failed.
# Just pass whatever it returned (or lack thereof)
pass
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError:
if timeout_stage == "chart_generation":
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
fallback_chart = ChartGenerationResponse(
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
chart_type="",
can_visualize=False,
chart_spec=None,
)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
+38 -3
View File
@@ -3,6 +3,7 @@ import os
import shutil
import zipfile
import tarfile
import re
import yaml
from typing import List, Optional, Dict, Any
from datetime import datetime
@@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]):
with open(DATA_FILE, "w") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
def _safe_skill_dir_name(value: str) -> str:
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
return safe or "skill"
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
os.makedirs(skill_dir, exist_ok=True)
skill_md_path = os.path.join(skill_dir, "SKILL.md")
final_description = description or "No description provided"
body = content or ""
markdown = (
f"---\n"
f"name: {skill_name}\n"
f"description: {final_description}\n"
f"---\n\n"
f"{body}\n"
)
with open(skill_md_path, "w", encoding="utf-8") as f:
f.write(markdown)
return skill_md_path
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
data = _load_data()
if project_id is not None:
@@ -205,9 +226,7 @@ async def upload_skill(
skill_name = os.path.splitext(filename)[0]
# Create a safe directory name for the skill
import re
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower()
if not safe_name: safe_name = "skill"
safe_name = _safe_skill_dir_name(skill_name)
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
@@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate):
new_skill_dict = skill.dict()
if not new_skill_dict.get("installation_time"):
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d")
if not new_skill_dict.get("file_path"):
skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"]))
_write_skill_markdown(
skill_dir=skill_dir,
skill_name=new_skill_dict["name"],
description=new_skill_dict.get("description"),
content=new_skill_dict.get("content", ""),
)
new_skill_dict["file_path"] = skill_dir
data.append(new_skill_dict)
_save_data(data)
@@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] =
updated_item = item.copy()
update_data = skill.dict(exclude_unset=True)
updated_item.update(update_data)
if updated_item.get("file_path"):
_write_skill_markdown(
skill_dir=updated_item["file_path"],
skill_name=updated_item.get("name") or item.get("name") or "skill",
description=updated_item.get("description"),
content=updated_item.get("content", ""),
)
data[i] = updated_item
_save_data(data)
return Skill(**updated_item)
+16
View File
@@ -1,6 +1,7 @@
import asyncio
import sys
import os
import shutil
from pathlib import Path
from typing import List, Callable, Awaitable, Any, Dict
@@ -47,6 +48,7 @@ class NanobotIntegration:
# Set workspace path to backend/data/workspace
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
workspace_path.mkdir(parents=True, exist_ok=True)
self._sync_builtin_skills_to_workspace(workspace_path)
# Override config workspace path via environment variable (since config is loaded from env)
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
@@ -87,6 +89,20 @@ class NanobotIntegration:
self._register_custom_tools(self.agent)
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
workspace_skills_root = workspace_path / "skills"
workspace_skills_root.mkdir(parents=True, exist_ok=True)
for skill_name in ("nl2sql", "visualization"):
source_dir = builtin_root / skill_name
source_skill_file = source_dir / "SKILL.md"
if not source_skill_file.exists():
continue
target_dir = workspace_skills_root / skill_name
target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
def _register_custom_tools(self, agent: AgentLoop):
from app.tools.nl2sql import NL2SQLTool
from app.tools.visualization import VisualizationTool
+2 -9
View File
@@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel):
chart_type: Literal[
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
] = Field(..., description="The type of chart generated, or empty string if none")
chart_spec: Optional[Union[
LineChartSchema,
MultiLineChartSchema,
BarChartSchema,
PieChartSchema,
GroupedBarChartSchema,
StackedBarChartSchema,
AreaChartSchema
]] = Field(None, description="The generated Vega-Lite chart specification")
# Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields
chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification")
can_visualize: bool = Field(..., description="Whether the data can be visualized")
@@ -0,0 +1,24 @@
---
description: Data Analysis and SQL Generation
metadata:
nanobot:
always: true
---
# NL2SQL Data Analysis Skill
You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language.
## When to use this skill
- When the user asks to query, analyze, aggregate, or fetch data from the database.
- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?".
## How to use this skill
- Call the `nl2sql` tool with the user's natural language query.
- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool.
- If the user ONLY asks to query data, set `generate_chart=False` (default).
## After using the tool
- The tool will return a summary of the executed query and a sample of the results.
- Use this information to provide a clear, concise, and helpful response to the user.
- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel.
@@ -0,0 +1,23 @@
---
description: Data Visualization and Chart Generation
metadata:
nanobot:
always: true
---
# Data Visualization Skill
You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data.
## When to use this skill
- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context.
- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions".
- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool.
## How to use this skill
- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart").
- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context.
## After using the tool
- The tool will return a success message and the reasoning for the chosen chart type.
- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful.
+2 -2
View File
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict
from typing import Any
from nanobot.agent.tools.base import Tool
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
@@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
payload = {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
"error": nl2sql_result.error,
}
return jsonable_encoder(payload)
+1 -1
View File
@@ -60,7 +60,7 @@ class VisualizationTool(Tool):
viz_payload = {
"sql": existing_viz.get("sql", ""),
"result": data,
"chart": chart_response.model_dump(),
"chart": chart_response.model_dump(by_alias=True, exclude_none=True),
"error": None,
}
encoded_viz = jsonable_encoder(viz_payload)
+6 -10
View File
@@ -313,8 +313,8 @@ export function ChatInterface() {
}): MessageViz => {
const rows = Array.isArray(payload.result) ? payload.result : [];
const chart = payload.chart ?? undefined;
const canVisualize = Boolean(chart?.can_visualize);
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
const canVisualize = chart?.can_visualize ?? Boolean(chart?.chart_spec);
const chartSpec = chart?.chart_spec ?? null;
return {
sql: typeof payload.sql === "string" ? payload.sql : "",
rows,
@@ -553,12 +553,12 @@ export function ChatInterface() {
let renderedText = "";
const flushAssistant = (force = false) => {
if (streamedText === renderedText) return;
if (streamedText === renderedText && !force) return;
if (force) {
renderedText = streamedText;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
)
);
return;
@@ -571,7 +571,7 @@ export function ChatInterface() {
renderedText = streamedText;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
)
);
});
@@ -645,11 +645,7 @@ export function ChatInterface() {
if (payload.type === "viz") {
pushProgressLog("可视化结果已生成");
streamedViz = buildMessageViz(payload);
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, viz: streamedViz || undefined } : msg
)
);
flushAssistant(true); // 立即把 viz 状态刷入 messages
}
}
}
@@ -171,8 +171,8 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) {
</div>
{view === "chart" ? (
viz.canVisualize && viz.chartSpec && objectRows.length > 0 ? (
<div className="w-full h-80 rounded-xl border border-zinc-100 p-2">
viz.chartSpec && objectRows.length > 0 ? (
<div className="w-full h-80 min-h-[320px] rounded-xl border border-zinc-100 p-2">
<VegaChart data={objectRows} spec={viz.chartSpec} />
</div>
) : (
+18 -9
View File
@@ -30,23 +30,27 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
}, []);
const vegaSpec: any = useMemo(() => {
// Clone spec and ensure tooltip is enabled in mark if not already specified
const baseSpec = { ...spec };
// Deep clone spec to avoid mutating React state/props
const baseSpec = JSON.parse(JSON.stringify(spec));
// Ensure tooltip is enabled in mark if not already specified
if (typeof baseSpec.mark === 'string') {
baseSpec.mark = { type: baseSpec.mark, tooltip: true };
} else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
baseSpec.mark = { ...baseSpec.mark, tooltip: true };
baseSpec.mark.tooltip = true;
}
// Add highlight effect: hover over an element makes others transparent
// 1. Define hover param
if (!baseSpec.params) {
baseSpec.params = [
{
baseSpec.params = [];
}
const hasHighlight = baseSpec.params.some((p: any) => p.name === "highlight");
if (!hasHighlight) {
baseSpec.params.push({
name: "highlight",
select: { type: "point", on: "mouseover", clear: "mouseout" }
}
];
});
}
// 2. Add conditional opacity to encoding
@@ -64,7 +68,7 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
// Also add cursor: pointer for marks
if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
(baseSpec.mark as any).cursor = "pointer";
baseSpec.mark.cursor = "pointer";
}
return {
@@ -77,12 +81,17 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
};
}, [data, size.height, size.width, spec]);
const handleError = (error: any) => {
console.error("VegaEmbed rendering error:", error, "Spec:", vegaSpec);
};
return (
<div className="w-full h-full" ref={containerRef}>
<div className="w-full h-full min-h-[300px]" ref={containerRef}>
<VegaEmbed
spec={vegaSpec}
options={{ actions: false }}
style={{width: '100%', height: '100%'}}
onError={handleError}
/>
</div>
);
@@ -214,6 +214,8 @@ class LiteLLMProvider(LLMProvider):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
request_timeout: float | None = None,
num_retries: int | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
@@ -269,6 +271,12 @@ class LiteLLMProvider(LLMProvider):
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
if request_timeout is not None:
kwargs["timeout"] = request_timeout
if num_retries is not None:
kwargs["num_retries"] = max(0, int(num_retries))
try:
response = await acompletion(**kwargs)
return self._parse_response(response)