reorg skill folder
This commit is contained in:
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
|
|||||||
- For daily question, the time unit should be "yearmonthdate".
|
- For daily question, the time unit should be "yearmonthdate".
|
||||||
- Default time unit is "yearmonth".
|
- Default time unit is "yearmonth".
|
||||||
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
|
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
|
||||||
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data.
|
- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
|
||||||
|
|
||||||
### GUIDELINES TO PLOT CHART ###
|
### GUIDELINES TO PLOT CHART ###
|
||||||
|
|
||||||
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
|
|||||||
|
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
|
|
||||||
|
# Post-process to fix common LLM hallucinations (translating field names)
|
||||||
|
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
|
||||||
|
encoding = result["chart_spec"].get("encoding", {})
|
||||||
|
for channel, enc_def in encoding.items():
|
||||||
|
if isinstance(enc_def, dict) and "field" in enc_def:
|
||||||
|
field = enc_def["field"]
|
||||||
|
# If field is not in columns, try to find a match or let it be (Vega will render empty)
|
||||||
|
# But if we can detect it was translated, we might not be able to fix it perfectly.
|
||||||
|
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
|
||||||
|
# but it's safer to just rely on the stricter prompt.
|
||||||
|
|
||||||
return ChartGenerationResponse(**result)
|
return ChartGenerationResponse(**result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+77
-42
@@ -4,7 +4,6 @@ import os
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
|||||||
NL2SQL_MAX_TOKENS = 900
|
NL2SQL_MAX_TOKENS = 900
|
||||||
NL2SQL_TEMPERATURE = 0.1
|
NL2SQL_TEMPERATURE = 0.1
|
||||||
NL2SQL_REASONING_EFFORT = "low"
|
NL2SQL_REASONING_EFFORT = "low"
|
||||||
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
|
||||||
|
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
|
||||||
|
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
|
||||||
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||||
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
|
||||||
|
|
||||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
|||||||
conn.close()
|
conn.close()
|
||||||
return result_df.to_dict(orient="records")
|
return result_df.to_dict(orient="records")
|
||||||
|
|
||||||
|
def _to_number(value: Any) -> Optional[float]:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return None
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
text = value.strip().replace(",", "")
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return float(text)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# _build_fallback_chart removed as per user request to not hardcode fallbacks
|
||||||
|
|
||||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||||
# If source is ds:ID, that's already a good key
|
# If source is ds:ID, that's already a good key
|
||||||
if source.startswith("ds:"):
|
if source.startswith("ds:"):
|
||||||
@@ -270,7 +288,7 @@ async def process_nl2sql(
|
|||||||
if connector:
|
if connector:
|
||||||
await emit_progress("正在检测数据源连通性")
|
await emit_progress("正在检测数据源连通性")
|
||||||
cached_schema = _get_cached_schema(request.source, connector)
|
cached_schema = _get_cached_schema(request.source, connector)
|
||||||
if cached_schema:
|
if cached_schema is not None:
|
||||||
schema = cached_schema
|
schema = cached_schema
|
||||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||||
else:
|
else:
|
||||||
@@ -289,23 +307,6 @@ async def process_nl2sql(
|
|||||||
|
|
||||||
_set_cached_schema(request.source, connector, schema)
|
_set_cached_schema(request.source, connector, schema)
|
||||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||||
|
|
||||||
if connector and not schema:
|
|
||||||
retry_started = time.perf_counter()
|
|
||||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
|
||||||
if not await _check_connection_with_cache(request.source, connector):
|
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
schema = await asyncio.wait_for(
|
|
||||||
asyncio.to_thread(connector.get_schema),
|
|
||||||
timeout=30.0
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
|
||||||
|
|
||||||
_set_cached_schema(request.source, connector, schema)
|
|
||||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
|
||||||
|
|
||||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|
||||||
@@ -378,42 +379,66 @@ Language: Chinese (Simplified)
|
|||||||
try:
|
try:
|
||||||
llm_started = time.perf_counter()
|
llm_started = time.perf_counter()
|
||||||
await emit_progress("正在生成 SQL")
|
await emit_progress("正在生成 SQL")
|
||||||
response = await asyncio.wait_for(
|
response = None
|
||||||
provider.chat(
|
last_error = ""
|
||||||
messages=messages,
|
|
||||||
max_tokens=NL2SQL_MAX_TOKENS,
|
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||||
temperature=NL2SQL_TEMPERATURE,
|
try:
|
||||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
response = await asyncio.wait_for(
|
||||||
),
|
provider.chat(
|
||||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
messages=messages,
|
||||||
)
|
max_tokens=NL2SQL_MAX_TOKENS,
|
||||||
|
temperature=NL2SQL_TEMPERATURE,
|
||||||
if response.finish_reason == "error":
|
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||||
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
|
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||||
|
num_retries=0,
|
||||||
|
),
|
||||||
|
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||||
|
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||||
|
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||||
|
continue
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||||
|
except Exception as e:
|
||||||
|
last_error = f"LLM generation failed: {e}"
|
||||||
|
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||||
|
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||||
|
continue
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||||
|
|
||||||
|
if response.finish_reason == "error":
|
||||||
|
last_error = response.content or "LLM Error"
|
||||||
|
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||||
|
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||||
|
continue
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||||
|
break
|
||||||
|
|
||||||
|
if response is None:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
|
||||||
|
|
||||||
content = (response.content or "").strip()
|
content = (response.content or "").strip()
|
||||||
if not content:
|
if not content:
|
||||||
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||||
|
|
||||||
# Clean up code blocks
|
# Clean up code blocks
|
||||||
if "```json" in content:
|
if "```json" in content:
|
||||||
content = content.split("```json")[1].split("```")[0]
|
content = content.split("```json")[1].split("```")[0]
|
||||||
elif "```" in content:
|
elif "```" in content:
|
||||||
content = content.split("```")[1].split("```")[0]
|
content = content.split("```")[1].split("```")[0]
|
||||||
|
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_json = json.loads(content)
|
result_json = json.loads(content)
|
||||||
sql_query = result_json.get("sql", "").strip()
|
sql_query = result_json.get("sql", "").strip()
|
||||||
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||||
sql_query = content
|
sql_query = content
|
||||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||||
|
|
||||||
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
|
|||||||
generate_chart(formatted_results, request.query),
|
generate_chart(formatted_results, request.query),
|
||||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
|
if not chart_response or not chart_response.chart_spec:
|
||||||
|
# Do not fallback automatically if the LLM explicitly decided not to or failed.
|
||||||
|
# Just pass whatever it returned (or lack thereof)
|
||||||
|
pass
|
||||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||||
|
|
||||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if timeout_stage == "chart_generation":
|
if timeout_stage == "chart_generation":
|
||||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
fallback_chart = ChartGenerationResponse(
|
||||||
|
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
|
||||||
|
chart_type="",
|
||||||
|
can_visualize=False,
|
||||||
|
chart_spec=None,
|
||||||
|
)
|
||||||
|
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
|
||||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import zipfile
|
import zipfile
|
||||||
import tarfile
|
import tarfile
|
||||||
|
import re
|
||||||
import yaml
|
import yaml
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]):
|
|||||||
with open(DATA_FILE, "w") as f:
|
with open(DATA_FILE, "w") as f:
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
def _safe_skill_dir_name(value: str) -> str:
|
||||||
|
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
|
||||||
|
return safe or "skill"
|
||||||
|
|
||||||
|
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
|
||||||
|
os.makedirs(skill_dir, exist_ok=True)
|
||||||
|
skill_md_path = os.path.join(skill_dir, "SKILL.md")
|
||||||
|
final_description = description or "No description provided"
|
||||||
|
body = content or ""
|
||||||
|
markdown = (
|
||||||
|
f"---\n"
|
||||||
|
f"name: {skill_name}\n"
|
||||||
|
f"description: {final_description}\n"
|
||||||
|
f"---\n\n"
|
||||||
|
f"{body}\n"
|
||||||
|
)
|
||||||
|
with open(skill_md_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(markdown)
|
||||||
|
return skill_md_path
|
||||||
|
|
||||||
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
data = _load_data()
|
data = _load_data()
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
@@ -205,9 +226,7 @@ async def upload_skill(
|
|||||||
skill_name = os.path.splitext(filename)[0]
|
skill_name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
# Create a safe directory name for the skill
|
# Create a safe directory name for the skill
|
||||||
import re
|
safe_name = _safe_skill_dir_name(skill_name)
|
||||||
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower()
|
|
||||||
if not safe_name: safe_name = "skill"
|
|
||||||
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
|
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
|
||||||
|
|
||||||
@@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate):
|
|||||||
new_skill_dict = skill.dict()
|
new_skill_dict = skill.dict()
|
||||||
if not new_skill_dict.get("installation_time"):
|
if not new_skill_dict.get("installation_time"):
|
||||||
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d日")
|
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d日")
|
||||||
|
if not new_skill_dict.get("file_path"):
|
||||||
|
skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"]))
|
||||||
|
_write_skill_markdown(
|
||||||
|
skill_dir=skill_dir,
|
||||||
|
skill_name=new_skill_dict["name"],
|
||||||
|
description=new_skill_dict.get("description"),
|
||||||
|
content=new_skill_dict.get("content", ""),
|
||||||
|
)
|
||||||
|
new_skill_dict["file_path"] = skill_dir
|
||||||
|
|
||||||
data.append(new_skill_dict)
|
data.append(new_skill_dict)
|
||||||
_save_data(data)
|
_save_data(data)
|
||||||
@@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] =
|
|||||||
updated_item = item.copy()
|
updated_item = item.copy()
|
||||||
update_data = skill.dict(exclude_unset=True)
|
update_data = skill.dict(exclude_unset=True)
|
||||||
updated_item.update(update_data)
|
updated_item.update(update_data)
|
||||||
|
if updated_item.get("file_path"):
|
||||||
|
_write_skill_markdown(
|
||||||
|
skill_dir=updated_item["file_path"],
|
||||||
|
skill_name=updated_item.get("name") or item.get("name") or "skill",
|
||||||
|
description=updated_item.get("description"),
|
||||||
|
content=updated_item.get("content", ""),
|
||||||
|
)
|
||||||
data[i] = updated_item
|
data[i] = updated_item
|
||||||
_save_data(data)
|
_save_data(data)
|
||||||
return Skill(**updated_item)
|
return Skill(**updated_item)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Callable, Awaitable, Any, Dict
|
from typing import List, Callable, Awaitable, Any, Dict
|
||||||
|
|
||||||
@@ -47,6 +48,7 @@ class NanobotIntegration:
|
|||||||
# Set workspace path to backend/data/workspace
|
# Set workspace path to backend/data/workspace
|
||||||
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
|
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
|
||||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._sync_builtin_skills_to_workspace(workspace_path)
|
||||||
|
|
||||||
# Override config workspace path via environment variable (since config is loaded from env)
|
# Override config workspace path via environment variable (since config is loaded from env)
|
||||||
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
|
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
|
||||||
@@ -87,6 +89,20 @@ class NanobotIntegration:
|
|||||||
|
|
||||||
self._register_custom_tools(self.agent)
|
self._register_custom_tools(self.agent)
|
||||||
|
|
||||||
|
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
|
||||||
|
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
|
||||||
|
workspace_skills_root = workspace_path / "skills"
|
||||||
|
workspace_skills_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
for skill_name in ("nl2sql", "visualization"):
|
||||||
|
source_dir = builtin_root / skill_name
|
||||||
|
source_skill_file = source_dir / "SKILL.md"
|
||||||
|
if not source_skill_file.exists():
|
||||||
|
continue
|
||||||
|
target_dir = workspace_skills_root / skill_name
|
||||||
|
target_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
|
||||||
|
|
||||||
def _register_custom_tools(self, agent: AgentLoop):
|
def _register_custom_tools(self, agent: AgentLoop):
|
||||||
from app.tools.nl2sql import NL2SQLTool
|
from app.tools.nl2sql import NL2SQLTool
|
||||||
from app.tools.visualization import VisualizationTool
|
from app.tools.visualization import VisualizationTool
|
||||||
|
|||||||
@@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel):
|
|||||||
chart_type: Literal[
|
chart_type: Literal[
|
||||||
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
|
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
|
||||||
] = Field(..., description="The type of chart generated, or empty string if none")
|
] = Field(..., description="The type of chart generated, or empty string if none")
|
||||||
chart_spec: Optional[Union[
|
# Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields
|
||||||
LineChartSchema,
|
chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification")
|
||||||
MultiLineChartSchema,
|
|
||||||
BarChartSchema,
|
|
||||||
PieChartSchema,
|
|
||||||
GroupedBarChartSchema,
|
|
||||||
StackedBarChartSchema,
|
|
||||||
AreaChartSchema
|
|
||||||
]] = Field(None, description="The generated Vega-Lite chart specification")
|
|
||||||
can_visualize: bool = Field(..., description="Whether the data can be visualized")
|
can_visualize: bool = Field(..., description="Whether the data can be visualized")
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
---
|
||||||
|
description: Data Analysis and SQL Generation
|
||||||
|
metadata:
|
||||||
|
nanobot:
|
||||||
|
always: true
|
||||||
|
---
|
||||||
|
|
||||||
|
# NL2SQL Data Analysis Skill
|
||||||
|
|
||||||
|
You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language.
|
||||||
|
|
||||||
|
## When to use this skill
|
||||||
|
- When the user asks to query, analyze, aggregate, or fetch data from the database.
|
||||||
|
- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?".
|
||||||
|
|
||||||
|
## How to use this skill
|
||||||
|
- Call the `nl2sql` tool with the user's natural language query.
|
||||||
|
- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool.
|
||||||
|
- If the user ONLY asks to query data, set `generate_chart=False` (default).
|
||||||
|
|
||||||
|
## After using the tool
|
||||||
|
- The tool will return a summary of the executed query and a sample of the results.
|
||||||
|
- Use this information to provide a clear, concise, and helpful response to the user.
|
||||||
|
- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel.
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
---
|
||||||
|
description: Data Visualization and Chart Generation
|
||||||
|
metadata:
|
||||||
|
nanobot:
|
||||||
|
always: true
|
||||||
|
---
|
||||||
|
|
||||||
|
# Data Visualization Skill
|
||||||
|
|
||||||
|
You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data.
|
||||||
|
|
||||||
|
## When to use this skill
|
||||||
|
- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context.
|
||||||
|
- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions".
|
||||||
|
- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool.
|
||||||
|
|
||||||
|
## How to use this skill
|
||||||
|
- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart").
|
||||||
|
- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context.
|
||||||
|
|
||||||
|
## After using the tool
|
||||||
|
- The tool will return a success message and the reasoning for the chosen chart type.
|
||||||
|
- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful.
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||||
@@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
|||||||
payload = {
|
payload = {
|
||||||
"sql": nl2sql_result.sql,
|
"sql": nl2sql_result.sql,
|
||||||
"result": nl2sql_result.result,
|
"result": nl2sql_result.result,
|
||||||
"chart": chart.model_dump() if chart else None,
|
"chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
|
||||||
"error": nl2sql_result.error,
|
"error": nl2sql_result.error,
|
||||||
}
|
}
|
||||||
return jsonable_encoder(payload)
|
return jsonable_encoder(payload)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class VisualizationTool(Tool):
|
|||||||
viz_payload = {
|
viz_payload = {
|
||||||
"sql": existing_viz.get("sql", ""),
|
"sql": existing_viz.get("sql", ""),
|
||||||
"result": data,
|
"result": data,
|
||||||
"chart": chart_response.model_dump(),
|
"chart": chart_response.model_dump(by_alias=True, exclude_none=True),
|
||||||
"error": None,
|
"error": None,
|
||||||
}
|
}
|
||||||
encoded_viz = jsonable_encoder(viz_payload)
|
encoded_viz = jsonable_encoder(viz_payload)
|
||||||
|
|||||||
@@ -313,8 +313,8 @@ export function ChatInterface() {
|
|||||||
}): MessageViz => {
|
}): MessageViz => {
|
||||||
const rows = Array.isArray(payload.result) ? payload.result : [];
|
const rows = Array.isArray(payload.result) ? payload.result : [];
|
||||||
const chart = payload.chart ?? undefined;
|
const chart = payload.chart ?? undefined;
|
||||||
const canVisualize = Boolean(chart?.can_visualize);
|
const canVisualize = chart?.can_visualize ?? Boolean(chart?.chart_spec);
|
||||||
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
|
const chartSpec = chart?.chart_spec ?? null;
|
||||||
return {
|
return {
|
||||||
sql: typeof payload.sql === "string" ? payload.sql : "",
|
sql: typeof payload.sql === "string" ? payload.sql : "",
|
||||||
rows,
|
rows,
|
||||||
@@ -553,12 +553,12 @@ export function ChatInterface() {
|
|||||||
let renderedText = "";
|
let renderedText = "";
|
||||||
|
|
||||||
const flushAssistant = (force = false) => {
|
const flushAssistant = (force = false) => {
|
||||||
if (streamedText === renderedText) return;
|
if (streamedText === renderedText && !force) return;
|
||||||
if (force) {
|
if (force) {
|
||||||
renderedText = streamedText;
|
renderedText = streamedText;
|
||||||
setMessagesForSession(targetSessionKey, (prev) =>
|
setMessagesForSession(targetSessionKey, (prev) =>
|
||||||
prev.map((msg) =>
|
prev.map((msg) =>
|
||||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
|
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
@@ -571,7 +571,7 @@ export function ChatInterface() {
|
|||||||
renderedText = streamedText;
|
renderedText = streamedText;
|
||||||
setMessagesForSession(targetSessionKey, (prev) =>
|
setMessagesForSession(targetSessionKey, (prev) =>
|
||||||
prev.map((msg) =>
|
prev.map((msg) =>
|
||||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
|
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -645,11 +645,7 @@ export function ChatInterface() {
|
|||||||
if (payload.type === "viz") {
|
if (payload.type === "viz") {
|
||||||
pushProgressLog("可视化结果已生成");
|
pushProgressLog("可视化结果已生成");
|
||||||
streamedViz = buildMessageViz(payload);
|
streamedViz = buildMessageViz(payload);
|
||||||
setMessagesForSession(targetSessionKey, (prev) =>
|
flushAssistant(true); // 立即把 viz 状态刷入 messages
|
||||||
prev.map((msg) =>
|
|
||||||
msg.id === assistantId ? { ...msg, viz: streamedViz || undefined } : msg
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -171,8 +171,8 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{view === "chart" ? (
|
{view === "chart" ? (
|
||||||
viz.canVisualize && viz.chartSpec && objectRows.length > 0 ? (
|
viz.chartSpec && objectRows.length > 0 ? (
|
||||||
<div className="w-full h-80 rounded-xl border border-zinc-100 p-2">
|
<div className="w-full h-80 min-h-[320px] rounded-xl border border-zinc-100 p-2">
|
||||||
<VegaChart data={objectRows} spec={viz.chartSpec} />
|
<VegaChart data={objectRows} spec={viz.chartSpec} />
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@@ -30,23 +30,27 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const vegaSpec: any = useMemo(() => {
|
const vegaSpec: any = useMemo(() => {
|
||||||
// Clone spec and ensure tooltip is enabled in mark if not already specified
|
// Deep clone spec to avoid mutating React state/props
|
||||||
const baseSpec = { ...spec };
|
const baseSpec = JSON.parse(JSON.stringify(spec));
|
||||||
|
|
||||||
|
// Ensure tooltip is enabled in mark if not already specified
|
||||||
if (typeof baseSpec.mark === 'string') {
|
if (typeof baseSpec.mark === 'string') {
|
||||||
baseSpec.mark = { type: baseSpec.mark, tooltip: true };
|
baseSpec.mark = { type: baseSpec.mark, tooltip: true };
|
||||||
} else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
} else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
||||||
baseSpec.mark = { ...baseSpec.mark, tooltip: true };
|
baseSpec.mark.tooltip = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add highlight effect: hover over an element makes others transparent
|
// Add highlight effect: hover over an element makes others transparent
|
||||||
// 1. Define hover param
|
// 1. Define hover param
|
||||||
if (!baseSpec.params) {
|
if (!baseSpec.params) {
|
||||||
baseSpec.params = [
|
baseSpec.params = [];
|
||||||
{
|
}
|
||||||
name: "highlight",
|
const hasHighlight = baseSpec.params.some((p: any) => p.name === "highlight");
|
||||||
select: { type: "point", on: "mouseover", clear: "mouseout" }
|
if (!hasHighlight) {
|
||||||
}
|
baseSpec.params.push({
|
||||||
];
|
name: "highlight",
|
||||||
|
select: { type: "point", on: "mouseover", clear: "mouseout" }
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Add conditional opacity to encoding
|
// 2. Add conditional opacity to encoding
|
||||||
@@ -64,7 +68,7 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
|||||||
|
|
||||||
// Also add cursor: pointer for marks
|
// Also add cursor: pointer for marks
|
||||||
if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
||||||
(baseSpec.mark as any).cursor = "pointer";
|
baseSpec.mark.cursor = "pointer";
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -77,12 +81,17 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
|||||||
};
|
};
|
||||||
}, [data, size.height, size.width, spec]);
|
}, [data, size.height, size.width, spec]);
|
||||||
|
|
||||||
|
const handleError = (error: any) => {
|
||||||
|
console.error("VegaEmbed rendering error:", error, "Spec:", vegaSpec);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full h-full" ref={containerRef}>
|
<div className="w-full h-full min-h-[300px]" ref={containerRef}>
|
||||||
<VegaEmbed
|
<VegaEmbed
|
||||||
spec={vegaSpec}
|
spec={vegaSpec}
|
||||||
options={{ actions: false }}
|
options={{ actions: false }}
|
||||||
style={{width: '100%', height: '100%'}}
|
style={{width: '100%', height: '100%'}}
|
||||||
|
onError={handleError}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -214,6 +214,8 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
request_timeout: float | None = None,
|
||||||
|
num_retries: int | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request via LiteLLM.
|
Send a chat completion request via LiteLLM.
|
||||||
@@ -268,6 +270,12 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
kwargs["tool_choice"] = "auto"
|
kwargs["tool_choice"] = "auto"
|
||||||
|
|
||||||
|
if request_timeout is not None:
|
||||||
|
kwargs["timeout"] = request_timeout
|
||||||
|
|
||||||
|
if num_retries is not None:
|
||||||
|
kwargs["num_retries"] = max(0, int(num_retries))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await acompletion(**kwargs)
|
response = await acompletion(**kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user