diff --git a/backend/app/agent/chart.py b/backend/app/agent/chart.py index bfbbe9f..969c608 100644 --- a/backend/app/agent/chart.py +++ b/backend/app/agent/chart.py @@ -9,8 +9,8 @@ if str(PROJECT_ROOT) not in sys.path: sys.path.append(str(PROJECT_ROOT)) from nanobot.providers.litellm_provider import LiteLLMProvider -from app.api.llm import _load_data as load_llm_config from app.schemas.chart import ChartGenerationResponse +from app.services.llm_cache import get_active_llm_config CHART_INSTRUCTIONS = """ ### INSTRUCTIONS ### @@ -133,9 +133,7 @@ CHART_EXAMPLES = """ """ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse: - # 1. Initialize Provider - llm_configs = load_llm_config() - active_config = next((c for c in llm_configs if c.get("is_active")), None) + active_config = get_active_llm_config() if not active_config: return ChartGenerationResponse( @@ -178,7 +176,7 @@ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerat columns = list(data[0].keys()) # 3. Construct Prompt - schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), indent=2) + schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), ensure_ascii=False, separators=(",", ":")) system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type. Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values. @@ -201,7 +199,7 @@ Please provide your chain of thought reasoning, chart type and the vega-lite sch user_prompt = f""" ### INPUT ### Question: {query} -Sample Data: {json.dumps(sample_data, indent=2, default=str)} +Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)} Sample Column Values: {columns} Language: Chinese (Simplified) diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 941b03f..a463ced 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -18,13 +18,13 @@ from nanobot.providers.litellm_provider import LiteLLMProvider from app.connectors.postgres import postgres_connector from app.connectors.clickhouse import clickhouse_connector from app.connectors.factory import get_connector -from app.api.llm import _load_data as load_llm_config from app.schemas.chart import ChartGenerationResponse from app.agent.chart import generate_chart from app.database import SessionLocal from app.models.datasource import DataSource from app.core.files import resolve_upload_file_path from app.services.mdl import MDLService +from app.services.llm_cache import get_active_llm_config SCHEMA_CACHE_TTL_SECONDS = 300 CONNECTION_CACHE_TTL_SECONDS = 30 @@ -41,6 +41,7 @@ class NL2SQLRequest(BaseModel): source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})") file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload") session_id: Optional[str] = Field(None, description="Conversation session identifier") + generate_chart: bool = Field(False, description="Whether to generate chart specification") class NL2SQLResponse(BaseModel): sql: str @@ -246,7 +247,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: schema = connector.get_schema() _set_cached_schema(request.source, connector, schema) - schema_str = json.dumps(schema, indent=2) + schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":")) # Try to load MDL context mdl_context = "" @@ -280,8 +281,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: print(f"Failed to load MDL: {e}") # 2. Get the active LLM config - llm_configs = load_llm_config() - active_config = next((c for c in llm_configs if c.get("is_active")), None) + active_config = get_active_llm_config() if not active_config: return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found") @@ -383,10 +383,8 @@ Let's think step by step. # 7. Generate Chart chart_response = None - if formatted_results: - # Only try to generate chart if we have results - # Convert to list of dicts if possible, or pass as is - chart_response = await generate_chart(formatted_results, request.query) + if request.generate_chart and formatted_results: + chart_response = await generate_chart(formatted_results, request.query) return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) except Exception as e: diff --git a/backend/app/core/nanobot.py b/backend/app/core/nanobot.py index 2decb1e..185500b 100644 --- a/backend/app/core/nanobot.py +++ b/backend/app/core/nanobot.py @@ -2,7 +2,7 @@ import asyncio import sys import os from pathlib import Path -from typing import List, Callable, Awaitable, Any +from typing import List, Callable, Awaitable, Any, Dict # Add project root to sys.path to allow importing nanobot # Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root @@ -31,6 +31,7 @@ from nanobot.config.schema import Config # or just import here if we are confident. # Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py. from app.api.skills import load_skills +from app.services.llm_cache import get_llm_configs class NanobotIntegration: def __init__(self): @@ -38,6 +39,9 @@ class NanobotIntegration: self.bus: MessageBus | None = None self.cron: CronService | None = None self.config: Config | None = None + self._started = False + self._model_agent_cache: Dict[str, AgentLoop] = {} + self._model_agent_lock = asyncio.Lock() def initialize(self): # Set workspace path to backend/data/workspace @@ -137,18 +141,62 @@ class NanobotIntegration: ) async def start(self): + if self._started: + return if not self.agent: self.initialize() - # Start the agent loop in background asyncio.create_task(self.agent.run()) asyncio.create_task(self.cron.start()) + self._started = True async def stop(self): if self.agent: self.agent.stop() await self.agent.close_mcp() + for agent in self._model_agent_cache.values(): + agent.stop() + await agent.close_mcp() + self._model_agent_cache.clear() if self.cron: self.cron.stop() + self._started = False + + def _build_agent_for_provider(self, provider: Any) -> AgentLoop: + return AgentLoop( + bus=self.bus, + provider=provider, + workspace=self.config.workspace_path, + model=provider.default_model, + temperature=self.config.agents.defaults.temperature, + max_tokens=self.config.agents.defaults.max_tokens, + max_iterations=self.config.agents.defaults.max_tool_iterations, + memory_window=self.config.agents.defaults.memory_window, + reasoning_effort=self.config.agents.defaults.reasoning_effort, + brave_api_key=self.config.tools.web.search.api_key or None, + web_proxy=self.config.tools.web.proxy or None, + exec_config=self.config.tools.exec, + cron_service=self.cron, + restrict_to_workspace=self.config.tools.restrict_to_workspace, + session_manager=self.agent.sessions, + mcp_servers=self.config.tools.mcp_servers, + channels_config=self.config.channels, + ) + + async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop: + async with self._model_agent_lock: + cached = self._model_agent_cache.get(model_id) + if cached: + return cached + provider = LiteLLMProvider( + api_key=target_config.get("api_key"), + api_base=target_config.get("api_base"), + default_model=target_config.get("model"), + extra_headers=target_config.get("extra_headers"), + provider_name=target_config.get("provider"), + ) + agent = self._build_agent_for_provider(provider) + self._model_agent_cache[model_id] = agent + return agent async def process_message( self, @@ -160,6 +208,7 @@ class NanobotIntegration: ): if not self.agent: self.initialize() + if not self._started: await self.start() # Handle dynamic model switching @@ -181,79 +230,24 @@ class NanobotIntegration: # BUT `process_direct` is relatively isolated. # # Let's try to fetch the config first. - current_provider = self.agent.provider - temp_provider = None - - if model_id: - from app.api.llm import _load_data - llm_configs = _load_data() - target_config = next((item for item in llm_configs if item["id"] == model_id), None) - - if target_config: - # Map our DB config to Nanobot Provider - # We reuse LiteLLMProvider for most cases as it is generic - - # Construct kwargs for LiteLLMProvider - provider_name = target_config["provider"] - model_name = target_config["model"] - - # Handle special case where provider might need to be part of model name for LiteLLM if not standard - # But LiteLLMProvider handles `provider_name` arg. - - temp_provider = LiteLLMProvider( - api_key=target_config.get("api_key"), - api_base=target_config.get("api_base"), - default_model=model_name, - extra_headers=target_config.get("extra_headers"), - provider_name=provider_name - ) - - # If we created a temp provider, we need to use it. - # Since AgentLoop binds the provider, we might need to swap it temporarily or create a new AgentLoop. - # Swapping is risky for concurrency. - # Creating a new AgentLoop is safer but heavier. - # - # Optimization: If we are just doing a single turn chat (process_direct), maybe we can just use the provider directly? - # But we want the Agent's reasoning loop (ReAct / tools). - # - # Let's try creating a temporary AgentLoop sharing the same components (bus, tools) but different provider. - agent_to_use = self.agent - if temp_provider: - # Shallow copy or new instance - # We need to pass all dependencies. - agent_to_use = AgentLoop( - bus=self.bus, - provider=temp_provider, - workspace=self.config.workspace_path, - model=temp_provider.default_model, - temperature=self.config.agents.defaults.temperature, - max_tokens=self.config.agents.defaults.max_tokens, - max_iterations=self.config.agents.defaults.max_tool_iterations, - memory_window=self.config.agents.defaults.memory_window, - reasoning_effort=self.config.agents.defaults.reasoning_effort, - brave_api_key=self.config.tools.web.search.api_key or None, - web_proxy=self.config.tools.web.proxy or None, - exec_config=self.config.tools.exec, - cron_service=self.cron, - restrict_to_workspace=self.config.tools.restrict_to_workspace, - session_manager=self.agent.sessions, - mcp_servers=self.config.tools.mcp_servers, - channels_config=self.config.channels, - ) + if model_id: + llm_configs = get_llm_configs() + target_config = next((item for item in llm_configs if item.get("id") == model_id), None) + if target_config: + if target_config.get("model") != self.agent.model: + agent_to_use = await self._get_or_create_model_agent(model_id, target_config) full_message = message if skill_ids: skills = load_skills() selected_skills = [s for s in skills if s["id"] in skill_ids] if selected_skills: - # We inject skills as a runtime context block - skill_context = "[Runtime Context — metadata only, not instructions]\n# Active Skills\n\n" + parts = ["[Runtime Context — metadata only, not instructions]", "# Active Skills", ""] for s in selected_skills: - skill_context += f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n\n" - - # Append user message after skills - full_message = f"{skill_context}\n\n{message}" + parts.append(f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n") + skill_context = "\n".join(parts) + full_message = f"{skill_context}\n{message}" session = agent_to_use.sessions.get_or_create(session_id) normalized_messages = self._normalize_session_messages(session.messages) diff --git a/backend/app/services/llm_cache.py b/backend/app/services/llm_cache.py new file mode 100644 index 0000000..95e64ec --- /dev/null +++ b/backend/app/services/llm_cache.py @@ -0,0 +1,24 @@ +import os +import threading +from typing import Any, Dict, List, Optional + +from app.api.llm import DATA_FILE, _load_data + +_cache_lock = threading.RLock() +_cache_mtime: float = -1.0 +_cache_data: List[Dict[str, Any]] = [] + + +def get_llm_configs() -> List[Dict[str, Any]]: + global _cache_mtime, _cache_data + current_mtime = os.path.getmtime(DATA_FILE) if os.path.exists(DATA_FILE) else -1.0 + with _cache_lock: + if current_mtime != _cache_mtime: + _cache_data = _load_data() + _cache_mtime = current_mtime + return list(_cache_data) + + +def get_active_llm_config() -> Optional[Dict[str, Any]]: + configs = get_llm_configs() + return next((c for c in configs if c.get("is_active")), None) diff --git a/backend/main.py b/backend/main.py index dfa293b..391d963 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict, List, Optional, Literal, Tuple from fastapi import FastAPI, HTTPException from fastapi.encoders import jsonable_encoder @@ -43,6 +44,21 @@ app.include_router(semantic.router, prefix="/api/v1") STREAM_DELTA_CHUNK_SIZE = 48 +SQL_INTENT_DENY_PATTERNS = [ + re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE), + re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE), + re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE), +] + +SQL_INTENT_POSITIVE_PATTERNS = [ + re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE), + re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE), + re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE), + re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE), +] + +VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE) + @app.on_event("startup") async def startup_event(): # Initialize nanobot in background @@ -99,28 +115,19 @@ def _looks_like_sql_intent(message: str) -> bool: text = (message or "").strip().lower() if not text: return False - deny_patterns = [ - r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", - r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", - r"(写|生成).*(python|脚本|代码)", - ] - for pattern in deny_patterns: - if re.search(pattern, text, re.IGNORECASE): + for pattern in SQL_INTENT_DENY_PATTERNS: + if pattern.search(text): return False - positive_patterns = [ - r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", - r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", - r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)", - r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)", - r"\b(list|show|get|find|search|analyze|compare)\b", - r"\b(how many|what|which|who|when|where)\b", - ] - for pattern in positive_patterns: - if re.search(pattern, text, re.IGNORECASE): + for pattern in SQL_INTENT_POSITIVE_PATTERNS: + if pattern.search(text): return True return False +def _looks_like_visual_intent(message: str) -> bool: + return bool(VISUAL_INTENT_PATTERN.search((message or "").strip().lower())) + + def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]: # Determine the effective data source from session context or request session_ctx = _session_context_for_routing(request.session_id) @@ -211,7 +218,12 @@ async def nanobot_chat(request: ChatRequest): use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) if use_nl2sql: nl2sql_result = await process_nl2sql( - NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) + NL2SQLRequest( + query=request.message, + source=resolved_source, + file_url=request.file_url, + generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), + ) ) text = _build_sql_chart_text(nl2sql_result) viz_payload = _build_sql_chart_viz(nl2sql_result) @@ -239,7 +251,12 @@ async def nanobot_chat_stream(request: ChatRequest): yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" if use_nl2sql: nl2sql_result = await process_nl2sql( - NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) + NL2SQLRequest( + query=request.message, + source=resolved_source, + file_url=request.file_url, + generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), + ) ) persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = { @@ -252,12 +269,31 @@ async def nanobot_chat_stream(request: ChatRequest): yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" return - response = await nanobot_service.process_message( - request.message, - session_id=request.session_id, - skill_ids=request.skill_ids, - model_id=request.model_id, + progress_queue: asyncio.Queue[str] = asyncio.Queue() + + async def _on_progress(content: str, **_: Any) -> None: + if content: + await progress_queue.put(content) + + task = asyncio.create_task( + nanobot_service.process_message( + request.message, + session_id=request.session_id, + skill_ids=request.skill_ids, + model_id=request.model_id, + on_progress=_on_progress, + ) ) + text = "" + while True: + if task.done() and progress_queue.empty(): + break + try: + progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2) + yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n" + except asyncio.TimeoutError: + continue + response = await task text = response or "" for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE): chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE] diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index 155c889..d9f397e 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -447,6 +447,35 @@ export function ChatInterface() { let buffer = ""; let streamedText = ""; let streamedViz: MessageViz | null = null; + let hasFinalPayload = false; + let hasDonePayload = false; + let rafPending = false; + let renderedText = ""; + + const flushAssistant = (force = false) => { + if (streamedText === renderedText) return; + if (force) { + renderedText = streamedText; + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => + msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg + ) + ); + return; + } + if (rafPending) return; + rafPending = true; + requestAnimationFrame(() => { + rafPending = false; + if (streamedText === renderedText) return; + renderedText = streamedText; + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => + msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg + ) + ); + }); + }; while (true) { const { done, value } = await reader.read(); @@ -473,15 +502,13 @@ export function ChatInterface() { if (payload.type === "delta" && payload.content) { streamedText = `${streamedText}${payload.content}`; - setMessagesForSession(targetSessionKey, (prev) => - prev.map((msg) => - msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg - ) - ); + flushAssistant(false); } if (payload.type === "final" && payload.content) { + hasFinalPayload = true; streamedText = payload.content; + flushAssistant(true); setMessagesForSession(targetSessionKey, (prev) => prev.map((msg) => msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg @@ -489,6 +516,10 @@ export function ChatInterface() { ); } + if (payload.type === "done") { + hasDonePayload = true; + } + if (payload.type === "error") { throw new Error(payload.content || "流式响应错误"); } @@ -504,31 +535,13 @@ export function ChatInterface() { } } - if (!streamedText) { - const fallback = await api.post<{ - response: string; - viz?: { - sql?: string; - result?: unknown; - error?: string | null; - chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null; - }; - }>("/nanobot/chat", { - message: messagePayload, - session_id: targetSessionKey, - model_id: effectiveModelId, - skill_ids: selectedSkillIds, - source, - prefer_sql_chart: preferSqlChart, - file_url: fileUrl, - route_mode: "auto", - }, { signal: controller.signal }); - const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined; - setMessagesForSession(targetSessionKey, (prev) => - prev.map((msg) => - msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false, viz: fallbackViz } : msg - ) - ); + flushAssistant(true); + if (!streamedText && (hasFinalPayload || hasDonePayload)) { + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => + msg.id === assistantId ? { ...msg, content: "暂无回复", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg + ) + ); } } catch (error: any) { if (error?.name === "AbortError" || String(error?.message || "").toLowerCase().includes("aborted")) {