reorg skill folder

This commit is contained in:
qixinbo
2026-03-19 12:27:31 +08:00
parent baec21c774
commit cca492cfdb
13 changed files with 232 additions and 81 deletions
+13 -1
View File
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
- For daily question, the time unit should be "yearmonthdate". - For daily question, the time unit should be "yearmonthdate".
- Default time unit is "yearmonth". - Default time unit is "yearmonth".
- For each axis, generate the corresponding human-readable title based on the language provided by the user. - For each axis, generate the corresponding human-readable title based on the language provided by the user.
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data. - **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
### GUIDELINES TO PLOT CHART ### ### GUIDELINES TO PLOT CHART ###
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
content = content.strip() content = content.strip()
result = json.loads(content) result = json.loads(content)
# Post-process to fix common LLM hallucinations (translating field names)
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
encoding = result["chart_spec"].get("encoding", {})
for channel, enc_def in encoding.items():
if isinstance(enc_def, dict) and "field" in enc_def:
field = enc_def["field"]
# If field is not in columns, try to find a match or let it be (Vega will render empty)
# But if we can detect it was translated, we might not be able to fix it perfectly.
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
# but it's safer to just rely on the stricter prompt.
return ChartGenerationResponse(**result) return ChartGenerationResponse(**result)
except Exception as e: except Exception as e:
+77 -42
View File
@@ -4,7 +4,6 @@ import os
import json import json
import time import time
import threading import threading
import re
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900 NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1 NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low" NL2SQL_REASONING_EFFORT = "low"
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5 NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60 NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
NL2SQL_CHART_TIMEOUT_SECONDS = 45 NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
_schema_cache: Dict[str, Dict[str, Any]] = {} _schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {} _connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
conn.close() conn.close()
return result_df.to_dict(orient="records") return result_df.to_dict(orient="records")
def _to_number(value: Any) -> Optional[float]:
if isinstance(value, bool):
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
text = value.strip().replace(",", "")
if not text:
return None
try:
return float(text)
except ValueError:
return None
return None
# _build_fallback_chart removed as per user request to not hardcode fallbacks
def _build_schema_cache_key(source: str, connector: Any) -> str: def _build_schema_cache_key(source: str, connector: Any) -> str:
# If source is ds:ID, that's already a good key # If source is ds:ID, that's already a good key
if source.startswith("ds:"): if source.startswith("ds:"):
@@ -270,7 +288,7 @@ async def process_nl2sql(
if connector: if connector:
await emit_progress("正在检测数据源连通性") await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector) cached_schema = _get_cached_schema(request.source, connector)
if cached_schema: if cached_schema is not None:
schema = cached_schema schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表") await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else: else:
@@ -289,23 +307,6 @@ async def process_nl2sql(
_set_cached_schema(request.source, connector, schema) _set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)") await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not await _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":")) schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -378,42 +379,66 @@ Language: Chinese (Simplified)
try: try:
llm_started = time.perf_counter() llm_started = time.perf_counter()
await emit_progress("正在生成 SQL") await emit_progress("正在生成 SQL")
response = await asyncio.wait_for( response = None
provider.chat( last_error = ""
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS, for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
temperature=NL2SQL_TEMPERATURE, try:
reasoning_effort=NL2SQL_REASONING_EFFORT, response = await asyncio.wait_for(
), provider.chat(
timeout=NL2SQL_LLM_TIMEOUT_SECONDS, messages=messages,
) max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
if response.finish_reason == "error": reasoning_effort=NL2SQL_REASONING_EFFORT,
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error") request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
if response.finish_reason == "error":
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip() content = (response.content or "").strip()
if not content: if not content:
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response") return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
# Clean up code blocks # Clean up code blocks
if "```json" in content: if "```json" in content:
content = content.split("```json")[1].split("```")[0] content = content.split("```json")[1].split("```")[0]
elif "```" in content: elif "```" in content:
content = content.split("```")[1].split("```")[0] content = content.split("```")[1].split("```")[0]
content = content.strip() content = content.strip()
try: try:
result_json = json.loads(content) result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip() sql_query = result_json.get("sql", "").strip()
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
except json.JSONDecodeError: except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions # Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
generate_chart(formatted_results, request.query), generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS, timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
) )
if not chart_response or not chart_response.chart_spec:
# Do not fallback automatically if the LLM explicitly decided not to or failed.
# Just pass whatever it returned (or lack thereof)
pass
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)") await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if timeout_stage == "chart_generation": if timeout_stage == "chart_generation":
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s") fallback_chart = ChartGenerationResponse(
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
chart_type="",
can_visualize=False,
chart_spec=None,
)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s") return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}") return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
+38 -3
View File
@@ -3,6 +3,7 @@ import os
import shutil import shutil
import zipfile import zipfile
import tarfile import tarfile
import re
import yaml import yaml
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
@@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]):
with open(DATA_FILE, "w") as f: with open(DATA_FILE, "w") as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)
def _safe_skill_dir_name(value: str) -> str:
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
return safe or "skill"
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
os.makedirs(skill_dir, exist_ok=True)
skill_md_path = os.path.join(skill_dir, "SKILL.md")
final_description = description or "No description provided"
body = content or ""
markdown = (
f"---\n"
f"name: {skill_name}\n"
f"description: {final_description}\n"
f"---\n\n"
f"{body}\n"
)
with open(skill_md_path, "w", encoding="utf-8") as f:
f.write(markdown)
return skill_md_path
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]: def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
data = _load_data() data = _load_data()
if project_id is not None: if project_id is not None:
@@ -205,9 +226,7 @@ async def upload_skill(
skill_name = os.path.splitext(filename)[0] skill_name = os.path.splitext(filename)[0]
# Create a safe directory name for the skill # Create a safe directory name for the skill
import re safe_name = _safe_skill_dir_name(skill_name)
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower()
if not safe_name: safe_name = "skill"
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}" final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id) final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
@@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate):
new_skill_dict = skill.dict() new_skill_dict = skill.dict()
if not new_skill_dict.get("installation_time"): if not new_skill_dict.get("installation_time"):
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d") new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d")
if not new_skill_dict.get("file_path"):
skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"]))
_write_skill_markdown(
skill_dir=skill_dir,
skill_name=new_skill_dict["name"],
description=new_skill_dict.get("description"),
content=new_skill_dict.get("content", ""),
)
new_skill_dict["file_path"] = skill_dir
data.append(new_skill_dict) data.append(new_skill_dict)
_save_data(data) _save_data(data)
@@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] =
updated_item = item.copy() updated_item = item.copy()
update_data = skill.dict(exclude_unset=True) update_data = skill.dict(exclude_unset=True)
updated_item.update(update_data) updated_item.update(update_data)
if updated_item.get("file_path"):
_write_skill_markdown(
skill_dir=updated_item["file_path"],
skill_name=updated_item.get("name") or item.get("name") or "skill",
description=updated_item.get("description"),
content=updated_item.get("content", ""),
)
data[i] = updated_item data[i] = updated_item
_save_data(data) _save_data(data)
return Skill(**updated_item) return Skill(**updated_item)
+16
View File
@@ -1,6 +1,7 @@
import asyncio import asyncio
import sys import sys
import os import os
import shutil
from pathlib import Path from pathlib import Path
from typing import List, Callable, Awaitable, Any, Dict from typing import List, Callable, Awaitable, Any, Dict
@@ -47,6 +48,7 @@ class NanobotIntegration:
# Set workspace path to backend/data/workspace # Set workspace path to backend/data/workspace
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace")) workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
workspace_path.mkdir(parents=True, exist_ok=True) workspace_path.mkdir(parents=True, exist_ok=True)
self._sync_builtin_skills_to_workspace(workspace_path)
# Override config workspace path via environment variable (since config is loaded from env) # Override config workspace path via environment variable (since config is loaded from env)
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path) os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
@@ -87,6 +89,20 @@ class NanobotIntegration:
self._register_custom_tools(self.agent) self._register_custom_tools(self.agent)
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
workspace_skills_root = workspace_path / "skills"
workspace_skills_root.mkdir(parents=True, exist_ok=True)
for skill_name in ("nl2sql", "visualization"):
source_dir = builtin_root / skill_name
source_skill_file = source_dir / "SKILL.md"
if not source_skill_file.exists():
continue
target_dir = workspace_skills_root / skill_name
target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
def _register_custom_tools(self, agent: AgentLoop): def _register_custom_tools(self, agent: AgentLoop):
from app.tools.nl2sql import NL2SQLTool from app.tools.nl2sql import NL2SQLTool
from app.tools.visualization import VisualizationTool from app.tools.visualization import VisualizationTool
+2 -9
View File
@@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel):
chart_type: Literal[ chart_type: Literal[
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "" "line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
] = Field(..., description="The type of chart generated, or empty string if none") ] = Field(..., description="The type of chart generated, or empty string if none")
chart_spec: Optional[Union[ # Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields
LineChartSchema, chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification")
MultiLineChartSchema,
BarChartSchema,
PieChartSchema,
GroupedBarChartSchema,
StackedBarChartSchema,
AreaChartSchema
]] = Field(None, description="The generated Vega-Lite chart specification")
can_visualize: bool = Field(..., description="Whether the data can be visualized") can_visualize: bool = Field(..., description="Whether the data can be visualized")
@@ -0,0 +1,24 @@
---
description: Data Analysis and SQL Generation
metadata:
nanobot:
always: true
---
# NL2SQL Data Analysis Skill
You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language.
## When to use this skill
- When the user asks to query, analyze, aggregate, or fetch data from the database.
- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?".
## How to use this skill
- Call the `nl2sql` tool with the user's natural language query.
- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool.
- If the user ONLY asks to query data, set `generate_chart=False` (default).
## After using the tool
- The tool will return a summary of the executed query and a sample of the results.
- Use this information to provide a clear, concise, and helpful response to the user.
- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel.
@@ -0,0 +1,23 @@
---
description: Data Visualization and Chart Generation
metadata:
nanobot:
always: true
---
# Data Visualization Skill
You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data.
## When to use this skill
- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context.
- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions".
- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool.
## How to use this skill
- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart").
- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context.
## After using the tool
- The tool will return a success message and the reasoning for the chosen chart type.
- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful.
+2 -2
View File
@@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Any, Dict from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
@@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
payload = { payload = {
"sql": nl2sql_result.sql, "sql": nl2sql_result.sql,
"result": nl2sql_result.result, "result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None, "chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
"error": nl2sql_result.error, "error": nl2sql_result.error,
} }
return jsonable_encoder(payload) return jsonable_encoder(payload)
+1 -1
View File
@@ -60,7 +60,7 @@ class VisualizationTool(Tool):
viz_payload = { viz_payload = {
"sql": existing_viz.get("sql", ""), "sql": existing_viz.get("sql", ""),
"result": data, "result": data,
"chart": chart_response.model_dump(), "chart": chart_response.model_dump(by_alias=True, exclude_none=True),
"error": None, "error": None,
} }
encoded_viz = jsonable_encoder(viz_payload) encoded_viz = jsonable_encoder(viz_payload)
+6 -10
View File
@@ -313,8 +313,8 @@ export function ChatInterface() {
}): MessageViz => { }): MessageViz => {
const rows = Array.isArray(payload.result) ? payload.result : []; const rows = Array.isArray(payload.result) ? payload.result : [];
const chart = payload.chart ?? undefined; const chart = payload.chart ?? undefined;
const canVisualize = Boolean(chart?.can_visualize); const canVisualize = chart?.can_visualize ?? Boolean(chart?.chart_spec);
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null; const chartSpec = chart?.chart_spec ?? null;
return { return {
sql: typeof payload.sql === "string" ? payload.sql : "", sql: typeof payload.sql === "string" ? payload.sql : "",
rows, rows,
@@ -553,12 +553,12 @@ export function ChatInterface() {
let renderedText = ""; let renderedText = "";
const flushAssistant = (force = false) => { const flushAssistant = (force = false) => {
if (streamedText === renderedText) return; if (streamedText === renderedText && !force) return;
if (force) { if (force) {
renderedText = streamedText; renderedText = streamedText;
setMessagesForSession(targetSessionKey, (prev) => setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
) )
); );
return; return;
@@ -571,7 +571,7 @@ export function ChatInterface() {
renderedText = streamedText; renderedText = streamedText;
setMessagesForSession(targetSessionKey, (prev) => setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
) )
); );
}); });
@@ -645,11 +645,7 @@ export function ChatInterface() {
if (payload.type === "viz") { if (payload.type === "viz") {
pushProgressLog("可视化结果已生成"); pushProgressLog("可视化结果已生成");
streamedViz = buildMessageViz(payload); streamedViz = buildMessageViz(payload);
setMessagesForSession(targetSessionKey, (prev) => flushAssistant(true); // 立即把 viz 状态刷入 messages
prev.map((msg) =>
msg.id === assistantId ? { ...msg, viz: streamedViz || undefined } : msg
)
);
} }
} }
} }
@@ -171,8 +171,8 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) {
</div> </div>
{view === "chart" ? ( {view === "chart" ? (
viz.canVisualize && viz.chartSpec && objectRows.length > 0 ? ( viz.chartSpec && objectRows.length > 0 ? (
<div className="w-full h-80 rounded-xl border border-zinc-100 p-2"> <div className="w-full h-80 min-h-[320px] rounded-xl border border-zinc-100 p-2">
<VegaChart data={objectRows} spec={viz.chartSpec} /> <VegaChart data={objectRows} spec={viz.chartSpec} />
</div> </div>
) : ( ) : (
+20 -11
View File
@@ -30,23 +30,27 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
}, []); }, []);
const vegaSpec: any = useMemo(() => { const vegaSpec: any = useMemo(() => {
// Clone spec and ensure tooltip is enabled in mark if not already specified // Deep clone spec to avoid mutating React state/props
const baseSpec = { ...spec }; const baseSpec = JSON.parse(JSON.stringify(spec));
// Ensure tooltip is enabled in mark if not already specified
if (typeof baseSpec.mark === 'string') { if (typeof baseSpec.mark === 'string') {
baseSpec.mark = { type: baseSpec.mark, tooltip: true }; baseSpec.mark = { type: baseSpec.mark, tooltip: true };
} else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) { } else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
baseSpec.mark = { ...baseSpec.mark, tooltip: true }; baseSpec.mark.tooltip = true;
} }
// Add highlight effect: hover over an element makes others transparent // Add highlight effect: hover over an element makes others transparent
// 1. Define hover param // 1. Define hover param
if (!baseSpec.params) { if (!baseSpec.params) {
baseSpec.params = [ baseSpec.params = [];
{ }
name: "highlight", const hasHighlight = baseSpec.params.some((p: any) => p.name === "highlight");
select: { type: "point", on: "mouseover", clear: "mouseout" } if (!hasHighlight) {
} baseSpec.params.push({
]; name: "highlight",
select: { type: "point", on: "mouseover", clear: "mouseout" }
});
} }
// 2. Add conditional opacity to encoding // 2. Add conditional opacity to encoding
@@ -64,7 +68,7 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
// Also add cursor: pointer for marks // Also add cursor: pointer for marks
if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) { if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
(baseSpec.mark as any).cursor = "pointer"; baseSpec.mark.cursor = "pointer";
} }
return { return {
@@ -77,12 +81,17 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
}; };
}, [data, size.height, size.width, spec]); }, [data, size.height, size.width, spec]);
const handleError = (error: any) => {
console.error("VegaEmbed rendering error:", error, "Spec:", vegaSpec);
};
return ( return (
<div className="w-full h-full" ref={containerRef}> <div className="w-full h-full min-h-[300px]" ref={containerRef}>
<VegaEmbed <VegaEmbed
spec={vegaSpec} spec={vegaSpec}
options={{ actions: false }} options={{ actions: false }}
style={{width: '100%', height: '100%'}} style={{width: '100%', height: '100%'}}
onError={handleError}
/> />
</div> </div>
); );
@@ -214,6 +214,8 @@ class LiteLLMProvider(LLMProvider):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
request_timeout: float | None = None,
num_retries: int | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request via LiteLLM. Send a chat completion request via LiteLLM.
@@ -268,6 +270,12 @@ class LiteLLMProvider(LLMProvider):
if tools: if tools:
kwargs["tools"] = tools kwargs["tools"] = tools
kwargs["tool_choice"] = "auto" kwargs["tool_choice"] = "auto"
if request_timeout is not None:
kwargs["timeout"] = request_timeout
if num_retries is not None:
kwargs["num_retries"] = max(0, int(num_retries))
try: try:
response = await acompletion(**kwargs) response = await acompletion(**kwargs)