reorg skill folder
This commit is contained in:
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
|
||||
- For daily question, the time unit should be "yearmonthdate".
|
||||
- Default time unit is "yearmonth".
|
||||
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
|
||||
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data.
|
||||
- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
|
||||
|
||||
### GUIDELINES TO PLOT CHART ###
|
||||
|
||||
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
|
||||
|
||||
content = content.strip()
|
||||
result = json.loads(content)
|
||||
|
||||
# Post-process to fix common LLM hallucinations (translating field names)
|
||||
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
|
||||
encoding = result["chart_spec"].get("encoding", {})
|
||||
for channel, enc_def in encoding.items():
|
||||
if isinstance(enc_def, dict) and "field" in enc_def:
|
||||
field = enc_def["field"]
|
||||
# If field is not in columns, try to find a match or let it be (Vega will render empty)
|
||||
# But if we can detect it was translated, we might not be able to fix it perfectly.
|
||||
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
|
||||
# but it's safer to just rely on the stricter prompt.
|
||||
|
||||
return ChartGenerationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
+61
-26
@@ -4,7 +4,6 @@ import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
NL2SQL_MAX_TOKENS = 900
|
||||
NL2SQL_TEMPERATURE = 0.1
|
||||
NL2SQL_REASONING_EFFORT = "low"
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
|
||||
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
|
||||
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
|
||||
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
||||
conn.close()
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _to_number(value: Any) -> Optional[float]:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
text = value.strip().replace(",", "")
|
||||
if not text:
|
||||
return None
|
||||
try:
|
||||
return float(text)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# _build_fallback_chart removed as per user request to not hardcode fallbacks
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
# If source is ds:ID, that's already a good key
|
||||
if source.startswith("ds:"):
|
||||
@@ -270,7 +288,7 @@ async def process_nl2sql(
|
||||
if connector:
|
||||
await emit_progress("正在检测数据源连通性")
|
||||
cached_schema = _get_cached_schema(request.source, connector)
|
||||
if cached_schema:
|
||||
if cached_schema is not None:
|
||||
schema = cached_schema
|
||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||
else:
|
||||
@@ -290,23 +308,6 @@ async def process_nl2sql(
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||
|
||||
if connector and not schema:
|
||||
retry_started = time.perf_counter()
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||
|
||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
# Try to load MDL context
|
||||
@@ -378,18 +379,45 @@ Language: Chinese (Simplified)
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
@@ -406,14 +434,11 @@ Language: Chinese (Simplified)
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
if not chart_response or not chart_response.chart_spec:
|
||||
# Do not fallback automatically if the LLM explicitly decided not to or failed.
|
||||
# Just pass whatever it returned (or lack thereof)
|
||||
pass
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
||||
fallback_chart = ChartGenerationResponse(
|
||||
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
|
||||
chart_type="",
|
||||
can_visualize=False,
|
||||
chart_spec=None,
|
||||
)
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tarfile
|
||||
import re
|
||||
import yaml
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
@@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]):
|
||||
with open(DATA_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def _safe_skill_dir_name(value: str) -> str:
|
||||
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower()
|
||||
return safe or "skill"
|
||||
|
||||
def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str:
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
skill_md_path = os.path.join(skill_dir, "SKILL.md")
|
||||
final_description = description or "No description provided"
|
||||
body = content or ""
|
||||
markdown = (
|
||||
f"---\n"
|
||||
f"name: {skill_name}\n"
|
||||
f"description: {final_description}\n"
|
||||
f"---\n\n"
|
||||
f"{body}\n"
|
||||
)
|
||||
with open(skill_md_path, "w", encoding="utf-8") as f:
|
||||
f.write(markdown)
|
||||
return skill_md_path
|
||||
|
||||
def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
data = _load_data()
|
||||
if project_id is not None:
|
||||
@@ -205,9 +226,7 @@ async def upload_skill(
|
||||
skill_name = os.path.splitext(filename)[0]
|
||||
|
||||
# Create a safe directory name for the skill
|
||||
import re
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower()
|
||||
if not safe_name: safe_name = "skill"
|
||||
safe_name = _safe_skill_dir_name(skill_name)
|
||||
final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id)
|
||||
|
||||
@@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate):
|
||||
new_skill_dict = skill.dict()
|
||||
if not new_skill_dict.get("installation_time"):
|
||||
new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d日")
|
||||
if not new_skill_dict.get("file_path"):
|
||||
skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"]))
|
||||
_write_skill_markdown(
|
||||
skill_dir=skill_dir,
|
||||
skill_name=new_skill_dict["name"],
|
||||
description=new_skill_dict.get("description"),
|
||||
content=new_skill_dict.get("content", ""),
|
||||
)
|
||||
new_skill_dict["file_path"] = skill_dir
|
||||
|
||||
data.append(new_skill_dict)
|
||||
_save_data(data)
|
||||
@@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] =
|
||||
updated_item = item.copy()
|
||||
update_data = skill.dict(exclude_unset=True)
|
||||
updated_item.update(update_data)
|
||||
if updated_item.get("file_path"):
|
||||
_write_skill_markdown(
|
||||
skill_dir=updated_item["file_path"],
|
||||
skill_name=updated_item.get("name") or item.get("name") or "skill",
|
||||
description=updated_item.get("description"),
|
||||
content=updated_item.get("content", ""),
|
||||
)
|
||||
data[i] = updated_item
|
||||
_save_data(data)
|
||||
return Skill(**updated_item)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Callable, Awaitable, Any, Dict
|
||||
|
||||
@@ -47,6 +48,7 @@ class NanobotIntegration:
|
||||
# Set workspace path to backend/data/workspace
|
||||
workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace"))
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
self._sync_builtin_skills_to_workspace(workspace_path)
|
||||
|
||||
# Override config workspace path via environment variable (since config is loaded from env)
|
||||
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
|
||||
@@ -87,6 +89,20 @@ class NanobotIntegration:
|
||||
|
||||
self._register_custom_tools(self.agent)
|
||||
|
||||
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
|
||||
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
|
||||
workspace_skills_root = workspace_path / "skills"
|
||||
workspace_skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for skill_name in ("nl2sql", "visualization"):
|
||||
source_dir = builtin_root / skill_name
|
||||
source_skill_file = source_dir / "SKILL.md"
|
||||
if not source_skill_file.exists():
|
||||
continue
|
||||
target_dir = workspace_skills_root / skill_name
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
|
||||
|
||||
def _register_custom_tools(self, agent: AgentLoop):
|
||||
from app.tools.nl2sql import NL2SQLTool
|
||||
from app.tools.visualization import VisualizationTool
|
||||
|
||||
@@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel):
|
||||
chart_type: Literal[
|
||||
"line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", ""
|
||||
] = Field(..., description="The type of chart generated, or empty string if none")
|
||||
chart_spec: Optional[Union[
|
||||
LineChartSchema,
|
||||
MultiLineChartSchema,
|
||||
BarChartSchema,
|
||||
PieChartSchema,
|
||||
GroupedBarChartSchema,
|
||||
StackedBarChartSchema,
|
||||
AreaChartSchema
|
||||
]] = Field(None, description="The generated Vega-Lite chart specification")
|
||||
# Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields
|
||||
chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification")
|
||||
can_visualize: bool = Field(..., description="Whether the data can be visualized")
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
---
|
||||
description: Data Analysis and SQL Generation
|
||||
metadata:
|
||||
nanobot:
|
||||
always: true
|
||||
---
|
||||
|
||||
# NL2SQL Data Analysis Skill
|
||||
|
||||
You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language.
|
||||
|
||||
## When to use this skill
|
||||
- When the user asks to query, analyze, aggregate, or fetch data from the database.
|
||||
- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?".
|
||||
|
||||
## How to use this skill
|
||||
- Call the `nl2sql` tool with the user's natural language query.
|
||||
- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool.
|
||||
- If the user ONLY asks to query data, set `generate_chart=False` (default).
|
||||
|
||||
## After using the tool
|
||||
- The tool will return a summary of the executed query and a sample of the results.
|
||||
- Use this information to provide a clear, concise, and helpful response to the user.
|
||||
- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel.
|
||||
@@ -0,0 +1,23 @@
|
||||
---
|
||||
description: Data Visualization and Chart Generation
|
||||
metadata:
|
||||
nanobot:
|
||||
always: true
|
||||
---
|
||||
|
||||
# Data Visualization Skill
|
||||
|
||||
You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data.
|
||||
|
||||
## When to use this skill
|
||||
- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context.
|
||||
- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions".
|
||||
- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool.
|
||||
|
||||
## How to use this skill
|
||||
- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart").
|
||||
- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context.
|
||||
|
||||
## After using the tool
|
||||
- The tool will return a success message and the reasoning for the chosen chart type.
|
||||
- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful.
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
@@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
||||
payload = {
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
}
|
||||
return jsonable_encoder(payload)
|
||||
|
||||
@@ -60,7 +60,7 @@ class VisualizationTool(Tool):
|
||||
viz_payload = {
|
||||
"sql": existing_viz.get("sql", ""),
|
||||
"result": data,
|
||||
"chart": chart_response.model_dump(),
|
||||
"chart": chart_response.model_dump(by_alias=True, exclude_none=True),
|
||||
"error": None,
|
||||
}
|
||||
encoded_viz = jsonable_encoder(viz_payload)
|
||||
|
||||
@@ -313,8 +313,8 @@ export function ChatInterface() {
|
||||
}): MessageViz => {
|
||||
const rows = Array.isArray(payload.result) ? payload.result : [];
|
||||
const chart = payload.chart ?? undefined;
|
||||
const canVisualize = Boolean(chart?.can_visualize);
|
||||
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
|
||||
const canVisualize = chart?.can_visualize ?? Boolean(chart?.chart_spec);
|
||||
const chartSpec = chart?.chart_spec ?? null;
|
||||
return {
|
||||
sql: typeof payload.sql === "string" ? payload.sql : "",
|
||||
rows,
|
||||
@@ -553,12 +553,12 @@ export function ChatInterface() {
|
||||
let renderedText = "";
|
||||
|
||||
const flushAssistant = (force = false) => {
|
||||
if (streamedText === renderedText) return;
|
||||
if (streamedText === renderedText && !force) return;
|
||||
if (force) {
|
||||
renderedText = streamedText;
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
|
||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
)
|
||||
);
|
||||
return;
|
||||
@@ -571,7 +571,7 @@ export function ChatInterface() {
|
||||
renderedText = streamedText;
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
|
||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
)
|
||||
);
|
||||
});
|
||||
@@ -645,11 +645,7 @@ export function ChatInterface() {
|
||||
if (payload.type === "viz") {
|
||||
pushProgressLog("可视化结果已生成");
|
||||
streamedViz = buildMessageViz(payload);
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, viz: streamedViz || undefined } : msg
|
||||
)
|
||||
);
|
||||
flushAssistant(true); // 立即把 viz 状态刷入 messages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,8 +171,8 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) {
|
||||
</div>
|
||||
|
||||
{view === "chart" ? (
|
||||
viz.canVisualize && viz.chartSpec && objectRows.length > 0 ? (
|
||||
<div className="w-full h-80 rounded-xl border border-zinc-100 p-2">
|
||||
viz.chartSpec && objectRows.length > 0 ? (
|
||||
<div className="w-full h-80 min-h-[320px] rounded-xl border border-zinc-100 p-2">
|
||||
<VegaChart data={objectRows} spec={viz.chartSpec} />
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -30,23 +30,27 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
||||
}, []);
|
||||
|
||||
const vegaSpec: any = useMemo(() => {
|
||||
// Clone spec and ensure tooltip is enabled in mark if not already specified
|
||||
const baseSpec = { ...spec };
|
||||
// Deep clone spec to avoid mutating React state/props
|
||||
const baseSpec = JSON.parse(JSON.stringify(spec));
|
||||
|
||||
// Ensure tooltip is enabled in mark if not already specified
|
||||
if (typeof baseSpec.mark === 'string') {
|
||||
baseSpec.mark = { type: baseSpec.mark, tooltip: true };
|
||||
} else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
||||
baseSpec.mark = { ...baseSpec.mark, tooltip: true };
|
||||
baseSpec.mark.tooltip = true;
|
||||
}
|
||||
|
||||
// Add highlight effect: hover over an element makes others transparent
|
||||
// 1. Define hover param
|
||||
if (!baseSpec.params) {
|
||||
baseSpec.params = [
|
||||
{
|
||||
baseSpec.params = [];
|
||||
}
|
||||
const hasHighlight = baseSpec.params.some((p: any) => p.name === "highlight");
|
||||
if (!hasHighlight) {
|
||||
baseSpec.params.push({
|
||||
name: "highlight",
|
||||
select: { type: "point", on: "mouseover", clear: "mouseout" }
|
||||
}
|
||||
];
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Add conditional opacity to encoding
|
||||
@@ -64,7 +68,7 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
||||
|
||||
// Also add cursor: pointer for marks
|
||||
if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) {
|
||||
(baseSpec.mark as any).cursor = "pointer";
|
||||
baseSpec.mark.cursor = "pointer";
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -77,12 +81,17 @@ export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
||||
};
|
||||
}, [data, size.height, size.width, spec]);
|
||||
|
||||
const handleError = (error: any) => {
|
||||
console.error("VegaEmbed rendering error:", error, "Spec:", vegaSpec);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full h-full" ref={containerRef}>
|
||||
<div className="w-full h-full min-h-[300px]" ref={containerRef}>
|
||||
<VegaEmbed
|
||||
spec={vegaSpec}
|
||||
options={{ actions: false }}
|
||||
style={{width: '100%', height: '100%'}}
|
||||
onError={handleError}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -214,6 +214,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
request_timeout: float | None = None,
|
||||
num_retries: int | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
@@ -269,6 +271,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
if request_timeout is not None:
|
||||
kwargs["timeout"] = request_timeout
|
||||
|
||||
if num_retries is not None:
|
||||
kwargs["num_retries"] = max(0, int(num_retries))
|
||||
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
|
||||
Reference in New Issue
Block a user