diff --git a/backend/app/agent/chart.py b/backend/app/agent/chart.py index b7f941f..d9d348a 100644 --- a/backend/app/agent/chart.py +++ b/backend/app/agent/chart.py @@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """ - For daily question, the time unit should be "yearmonthdate". - Default time unit is "yearmonth". - For each axis, generate the corresponding human-readable title based on the language provided by the user. -- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data. +- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property! ### GUIDELINES TO PLOT CHART ### @@ -233,6 +233,18 @@ Language: Chinese (Simplified) content = content.strip() result = json.loads(content) + + # Post-process to fix common LLM hallucinations (translating field names) + if result.get("chart_spec") and isinstance(result["chart_spec"], dict): + encoding = result["chart_spec"].get("encoding", {}) + for channel, enc_def in encoding.items(): + if isinstance(enc_def, dict) and "field" in enc_def: + field = enc_def["field"] + # If field is not in columns, try to find a match or let it be (Vega will render empty) + # But if we can detect it was translated, we might not be able to fix it perfectly. + # As a simple fallback, if there's only one string column and one numeric column, we could guess, + # but it's safer to just rely on the stricter prompt. + return ChartGenerationResponse(**result) except Exception as e: diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 5eaae3a..20922ae 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -4,7 +4,6 @@ import os import json import time import threading -import re from pathlib import Path from typing import List, Optional, Dict, Any, Callable, Awaitable from pydantic import BaseModel, Field @@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8 NL2SQL_MAX_TOKENS = 900 NL2SQL_TEMPERATURE = 0.1 NL2SQL_REASONING_EFFORT = "low" -NL2SQL_LLM_TIMEOUT_SECONDS = 60*5 +NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90")) +NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45")) +NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0")) NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60 -NL2SQL_CHART_TIMEOUT_SECONDS = 45 +NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20")) _schema_cache: Dict[str, Dict[str, Any]] = {} _connection_cache: Dict[str, Dict[str, Any]] = {} @@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any] conn.close() return result_df.to_dict(orient="records") +def _to_number(value: Any) -> Optional[float]: + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + text = value.strip().replace(",", "") + if not text: + return None + try: + return float(text) + except ValueError: + return None + return None + +# _build_fallback_chart removed as per user request to not hardcode fallbacks + def _build_schema_cache_key(source: str, connector: Any) -> str: # If source is ds:ID, that's already a good key if source.startswith("ds:"): @@ -270,7 +288,7 @@ async def process_nl2sql( if connector: await emit_progress("正在检测数据源连通性") cached_schema = _get_cached_schema(request.source, connector) - if cached_schema: + if cached_schema is not None: schema = cached_schema await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表") else: @@ -289,23 +307,6 @@ async def process_nl2sql( _set_cached_schema(request.source, connector, schema) await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)") - - if connector and not schema: - retry_started = time.perf_counter() - # Double check in case schema was empty but connection is ok (e.g. empty db) - if not await _check_connection_with_cache(request.source, connector): - return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") - - try: - schema = await asyncio.wait_for( - asyncio.to_thread(connector.get_schema), - timeout=30.0 - ) - except Exception as e: - return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}") - - _set_cached_schema(request.source, connector, schema) - await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)") schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":")) @@ -378,42 +379,66 @@ Language: Chinese (Simplified) try: llm_started = time.perf_counter() await emit_progress("正在生成 SQL") - response = await asyncio.wait_for( - provider.chat( - messages=messages, - max_tokens=NL2SQL_MAX_TOKENS, - temperature=NL2SQL_TEMPERATURE, - reasoning_effort=NL2SQL_REASONING_EFFORT, - ), - timeout=NL2SQL_LLM_TIMEOUT_SECONDS, - ) - - if response.finish_reason == "error": - return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error") - + response = None + last_error = "" + + for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1): + try: + response = await asyncio.wait_for( + provider.chat( + messages=messages, + max_tokens=NL2SQL_MAX_TOKENS, + temperature=NL2SQL_TEMPERATURE, + reasoning_effort=NL2SQL_REASONING_EFFORT, + request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS, + num_retries=0, + ), + timeout=NL2SQL_LLM_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql="", result=[], error=last_error) + except Exception as e: + last_error = f"LLM generation failed: {e}" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql="", result=[], error=last_error) + + if response.finish_reason == "error": + last_error = response.content or "LLM Error" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql="", result=[], error=last_error) + break + + if response is None: + return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed") + content = (response.content or "").strip() if not content: return NL2SQLResponse(sql="", result=[], error="LLM returned empty response") - + # Clean up code blocks if "```json" in content: content = content.split("```json")[1].split("```")[0] elif "```" in content: content = content.split("```")[1].split("```")[0] - + content = content.strip() - + try: result_json = json.loads(content) sql_query = result_json.get("sql", "").strip() - reasoning = result_json.get("reasoning", "") # We can log this or return it if needed except json.JSONDecodeError: # Fallback if LLM doesn't return valid JSON despite instructions sql_query = content await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") - - except asyncio.TimeoutError: - return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s") + except Exception as e: return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") @@ -481,13 +506,23 @@ Language: Chinese (Simplified) generate_chart(formatted_results, request.query), timeout=NL2SQL_CHART_TIMEOUT_SECONDS, ) + if not chart_response or not chart_response.chart_spec: + # Do not fallback automatically if the LLM explicitly decided not to or failed. + # Just pass whatever it returned (or lack thereof) + pass await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)") await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) except asyncio.TimeoutError: if timeout_stage == "chart_generation": - return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s") + fallback_chart = ChartGenerationResponse( + reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s", + chart_type="", + can_visualize=False, + chart_spec=None, + ) + return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart) return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s") except Exception as e: return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}") diff --git a/backend/app/api/skills.py b/backend/app/api/skills.py index 5b671c1..98c055c 100644 --- a/backend/app/api/skills.py +++ b/backend/app/api/skills.py @@ -3,6 +3,7 @@ import os import shutil import zipfile import tarfile +import re import yaml from typing import List, Optional, Dict, Any from datetime import datetime @@ -111,6 +112,26 @@ def _save_data(data: List[Dict[str, Any]]): with open(DATA_FILE, "w") as f: json.dump(data, f, indent=2, ensure_ascii=False) +def _safe_skill_dir_name(value: str) -> str: + safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', value or "").lower() + return safe or "skill" + +def _write_skill_markdown(skill_dir: str, skill_name: str, description: Optional[str], content: str) -> str: + os.makedirs(skill_dir, exist_ok=True) + skill_md_path = os.path.join(skill_dir, "SKILL.md") + final_description = description or "No description provided" + body = content or "" + markdown = ( + f"---\n" + f"name: {skill_name}\n" + f"description: {final_description}\n" + f"---\n\n" + f"{body}\n" + ) + with open(skill_md_path, "w", encoding="utf-8") as f: + f.write(markdown) + return skill_md_path + def load_skills(project_id: Optional[int] = None) -> List[Dict[str, Any]]: data = _load_data() if project_id is not None: @@ -205,9 +226,7 @@ async def upload_skill( skill_name = os.path.splitext(filename)[0] # Create a safe directory name for the skill - import re - safe_name = re.sub(r'[^a-zA-Z0-9_\-]', '_', skill_name).lower() - if not safe_name: safe_name = "skill" + safe_name = _safe_skill_dir_name(skill_name) final_skill_id = f"{safe_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}" final_skill_dir = os.path.join(SKILL_HUB_DIR, final_skill_id) @@ -264,6 +283,15 @@ def create_skill(skill: SkillCreate): new_skill_dict = skill.dict() if not new_skill_dict.get("installation_time"): new_skill_dict["installation_time"] = datetime.now().strftime("%Y年%m月%d日") + if not new_skill_dict.get("file_path"): + skill_dir = os.path.join(SKILL_HUB_DIR, _safe_skill_dir_name(new_skill_dict["id"])) + _write_skill_markdown( + skill_dir=skill_dir, + skill_name=new_skill_dict["name"], + description=new_skill_dict.get("description"), + content=new_skill_dict.get("content", ""), + ) + new_skill_dict["file_path"] = skill_dir data.append(new_skill_dict) _save_data(data) @@ -279,6 +307,13 @@ def update_skill(skill_id: str, skill: SkillUpdate, project_id: Optional[int] = updated_item = item.copy() update_data = skill.dict(exclude_unset=True) updated_item.update(update_data) + if updated_item.get("file_path"): + _write_skill_markdown( + skill_dir=updated_item["file_path"], + skill_name=updated_item.get("name") or item.get("name") or "skill", + description=updated_item.get("description"), + content=updated_item.get("content", ""), + ) data[i] = updated_item _save_data(data) return Skill(**updated_item) diff --git a/backend/app/core/nanobot.py b/backend/app/core/nanobot.py index df092b3..2cea561 100644 --- a/backend/app/core/nanobot.py +++ b/backend/app/core/nanobot.py @@ -1,6 +1,7 @@ import asyncio import sys import os +import shutil from pathlib import Path from typing import List, Callable, Awaitable, Any, Dict @@ -47,6 +48,7 @@ class NanobotIntegration: # Set workspace path to backend/data/workspace workspace_path = Path(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "workspace")) workspace_path.mkdir(parents=True, exist_ok=True) + self._sync_builtin_skills_to_workspace(workspace_path) # Override config workspace path via environment variable (since config is loaded from env) os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path) @@ -87,6 +89,20 @@ class NanobotIntegration: self._register_custom_tools(self.agent) + def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None: + builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin" + workspace_skills_root = workspace_path / "skills" + workspace_skills_root.mkdir(parents=True, exist_ok=True) + + for skill_name in ("nl2sql", "visualization"): + source_dir = builtin_root / skill_name + source_skill_file = source_dir / "SKILL.md" + if not source_skill_file.exists(): + continue + target_dir = workspace_skills_root / skill_name + target_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(source_skill_file, target_dir / "SKILL.md") + def _register_custom_tools(self, agent: AgentLoop): from app.tools.nl2sql import NL2SQLTool from app.tools.visualization import VisualizationTool diff --git a/backend/app/schemas/chart.py b/backend/app/schemas/chart.py index c85e884..1464acc 100644 --- a/backend/app/schemas/chart.py +++ b/backend/app/schemas/chart.py @@ -123,13 +123,6 @@ class ChartGenerationResponse(BaseModel): chart_type: Literal[ "line", "multi_line", "bar", "pie", "grouped_bar", "stacked_bar", "area", "" ] = Field(..., description="The type of chart generated, or empty string if none") - chart_spec: Optional[Union[ - LineChartSchema, - MultiLineChartSchema, - BarChartSchema, - PieChartSchema, - GroupedBarChartSchema, - StackedBarChartSchema, - AreaChartSchema - ]] = Field(None, description="The generated Vega-Lite chart specification") + # Using Dict[str, Any] allows LLM to output valid Vega-Lite spec directly, avoiding Pydantic strict model serialization issues with dynamic fields + chart_spec: Optional[Dict[str, Any]] = Field(None, description="The generated Vega-Lite chart specification") can_visualize: bool = Field(..., description="Whether the data can be visualized") diff --git a/backend/app/skills_builtin/nl2sql/SKILL.md b/backend/app/skills_builtin/nl2sql/SKILL.md new file mode 100644 index 0000000..f102cce --- /dev/null +++ b/backend/app/skills_builtin/nl2sql/SKILL.md @@ -0,0 +1,24 @@ +--- +description: Data Analysis and SQL Generation +metadata: + nanobot: + always: true +--- + +# NL2SQL Data Analysis Skill + +You are an expert data analyst. You have access to a powerful `nl2sql` tool that can query the connected database using natural language. + +## When to use this skill +- When the user asks to query, analyze, aggregate, or fetch data from the database. +- Examples: "Show me the top 10 sales", "What is the average revenue by month?", "How many users registered yesterday?". + +## How to use this skill +- Call the `nl2sql` tool with the user's natural language query. +- If the user explicitly asks to "visualize" or "plot" the data in the SAME message as the query (e.g., "Show me sales by region and plot it as a pie chart"), you can set `generate_chart=True` in the `nl2sql` tool. +- If the user ONLY asks to query data, set `generate_chart=False` (default). + +## After using the tool +- The tool will return a summary of the executed query and a sample of the results. +- Use this information to provide a clear, concise, and helpful response to the user. +- If a chart was successfully generated by the tool, inform the user that the chart is available in the visualization panel. diff --git a/backend/app/skills_builtin/visualization/SKILL.md b/backend/app/skills_builtin/visualization/SKILL.md new file mode 100644 index 0000000..f317a4c --- /dev/null +++ b/backend/app/skills_builtin/visualization/SKILL.md @@ -0,0 +1,23 @@ +--- +description: Data Visualization and Chart Generation +metadata: + nanobot: + always: true +--- + +# Data Visualization Skill + +You are an expert data visualization specialist. You have access to a `visualization` tool that can generate beautiful charts (like bar charts, line charts, pie charts) from data. + +## When to use this skill +- When the user asks to visualize, plot, or draw a chart based on data that has ALREADY been queried or is currently in context. +- Examples: "Visualize it as a bar chart", "Plot the trend over time", "Draw a pie chart of the regions". +- DO NOT use this tool if the data hasn't been queried yet. If the user asks a new question and wants it visualized (e.g., "Show me sales and plot it"), use the `nl2sql` tool with `generate_chart=True` instead, or call `nl2sql` first and then this tool. + +## How to use this skill +- Call the `visualization` tool with the user's specific visualization request (e.g., "plot as a pie chart"). +- The tool relies on the data from the most recent SQL query. It will automatically read this data from the context. + +## After using the tool +- The tool will return a success message and the reasoning for the chosen chart type. +- Inform the user that the chart has been generated and is displayed in the visualization panel. Explain briefly what the chart shows if helpful. diff --git a/backend/app/tools/nl2sql.py b/backend/app/tools/nl2sql.py index c5cd8e5..4036721 100644 --- a/backend/app/tools/nl2sql.py +++ b/backend/app/tools/nl2sql.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict +from typing import Any from nanobot.agent.tools.base import Tool from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse @@ -14,7 +14,7 @@ def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: payload = { "sql": nl2sql_result.sql, "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, + "chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None, "error": nl2sql_result.error, } return jsonable_encoder(payload) diff --git a/backend/app/tools/visualization.py b/backend/app/tools/visualization.py index 4e071de..7752254 100644 --- a/backend/app/tools/visualization.py +++ b/backend/app/tools/visualization.py @@ -60,7 +60,7 @@ class VisualizationTool(Tool): viz_payload = { "sql": existing_viz.get("sql", ""), "result": data, - "chart": chart_response.model_dump(), + "chart": chart_response.model_dump(by_alias=True, exclude_none=True), "error": None, } encoded_viz = jsonable_encoder(viz_payload) diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index abf5c94..7c3d8d7 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -313,8 +313,8 @@ export function ChatInterface() { }): MessageViz => { const rows = Array.isArray(payload.result) ? payload.result : []; const chart = payload.chart ?? undefined; - const canVisualize = Boolean(chart?.can_visualize); - const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null; + const canVisualize = chart?.can_visualize ?? Boolean(chart?.chart_spec); + const chartSpec = chart?.chart_spec ?? null; return { sql: typeof payload.sql === "string" ? payload.sql : "", rows, @@ -553,12 +553,12 @@ export function ChatInterface() { let renderedText = ""; const flushAssistant = (force = false) => { - if (streamedText === renderedText) return; + if (streamedText === renderedText && !force) return; if (force) { renderedText = streamedText; setMessagesForSession(targetSessionKey, (prev) => prev.map((msg) => - msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg + msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg ) ); return; @@ -571,7 +571,7 @@ export function ChatInterface() { renderedText = streamedText; setMessagesForSession(targetSessionKey, (prev) => prev.map((msg) => - msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg + msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg ) ); }); @@ -645,11 +645,7 @@ export function ChatInterface() { if (payload.type === "viz") { pushProgressLog("可视化结果已生成"); streamedViz = buildMessageViz(payload); - setMessagesForSession(targetSessionKey, (prev) => - prev.map((msg) => - msg.id === assistantId ? { ...msg, viz: streamedViz || undefined } : msg - ) - ); + flushAssistant(true); // 立即把 viz 状态刷入 messages } } } diff --git a/frontend/src/components/InlineVisualizationCard.tsx b/frontend/src/components/InlineVisualizationCard.tsx index 3533a6b..757f49c 100644 --- a/frontend/src/components/InlineVisualizationCard.tsx +++ b/frontend/src/components/InlineVisualizationCard.tsx @@ -171,8 +171,8 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) { {view === "chart" ? ( - viz.canVisualize && viz.chartSpec && objectRows.length > 0 ? ( -
+ viz.chartSpec && objectRows.length > 0 ? ( +
) : ( diff --git a/frontend/src/components/VegaChart.tsx b/frontend/src/components/VegaChart.tsx index dff51da..4230cfa 100644 --- a/frontend/src/components/VegaChart.tsx +++ b/frontend/src/components/VegaChart.tsx @@ -30,23 +30,27 @@ export const VegaChart: React.FC = ({ data, spec }) => { }, []); const vegaSpec: any = useMemo(() => { - // Clone spec and ensure tooltip is enabled in mark if not already specified - const baseSpec = { ...spec }; + // Deep clone spec to avoid mutating React state/props + const baseSpec = JSON.parse(JSON.stringify(spec)); + + // Ensure tooltip is enabled in mark if not already specified if (typeof baseSpec.mark === 'string') { baseSpec.mark = { type: baseSpec.mark, tooltip: true }; } else if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) { - baseSpec.mark = { ...baseSpec.mark, tooltip: true }; + baseSpec.mark.tooltip = true; } // Add highlight effect: hover over an element makes others transparent // 1. Define hover param if (!baseSpec.params) { - baseSpec.params = [ - { - name: "highlight", - select: { type: "point", on: "mouseover", clear: "mouseout" } - } - ]; + baseSpec.params = []; + } + const hasHighlight = baseSpec.params.some((p: any) => p.name === "highlight"); + if (!hasHighlight) { + baseSpec.params.push({ + name: "highlight", + select: { type: "point", on: "mouseover", clear: "mouseout" } + }); } // 2. Add conditional opacity to encoding @@ -64,7 +68,7 @@ export const VegaChart: React.FC = ({ data, spec }) => { // Also add cursor: pointer for marks if (typeof baseSpec.mark === 'object' && baseSpec.mark !== null) { - (baseSpec.mark as any).cursor = "pointer"; + baseSpec.mark.cursor = "pointer"; } return { @@ -77,12 +81,17 @@ export const VegaChart: React.FC = ({ data, spec }) => { }; }, [data, size.height, size.width, spec]); + const handleError = (error: any) => { + console.error("VegaEmbed rendering error:", error, "Spec:", vegaSpec); + }; + return ( -
+
); diff --git a/nanobot/nanobot/providers/litellm_provider.py b/nanobot/nanobot/providers/litellm_provider.py index cb67635..2eddd8c 100644 --- a/nanobot/nanobot/providers/litellm_provider.py +++ b/nanobot/nanobot/providers/litellm_provider.py @@ -214,6 +214,8 @@ class LiteLLMProvider(LLMProvider): max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, + request_timeout: float | None = None, + num_retries: int | None = None, ) -> LLMResponse: """ Send a chat completion request via LiteLLM. @@ -268,6 +270,12 @@ class LiteLLMProvider(LLMProvider): if tools: kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + + if request_timeout is not None: + kwargs["timeout"] = request_timeout + + if num_retries is not None: + kwargs["num_retries"] = max(0, int(num_retries)) try: response = await acompletion(**kwargs)