Speed optimization: cache LLM configs, compact JSON in prompts, reuse per-model agents

This commit is contained in:
qixinbo
2026-03-17 20:40:56 +08:00
parent cd764fad43
commit c51f51ff69
6 changed files with 198 additions and 135 deletions
+4 -6
View File
@@ -9,8 +9,8 @@ if str(PROJECT_ROOT) not in sys.path:
sys.path.append(str(PROJECT_ROOT))
from nanobot.providers.litellm_provider import LiteLLMProvider
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config
CHART_INSTRUCTIONS = """
### INSTRUCTIONS ###
@@ -133,9 +133,7 @@ CHART_EXAMPLES = """
"""
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
# 1. Initialize Provider
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
active_config = get_active_llm_config()
if not active_config:
return ChartGenerationResponse(
@@ -178,7 +176,7 @@ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerat
columns = list(data[0].keys())
# 3. Construct Prompt
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), indent=2)
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), ensure_ascii=False, separators=(",", ":"))
system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type.
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values.
@@ -201,7 +199,7 @@ Please provide your chain of thought reasoning, chart type and the vega-lite sch
user_prompt = f"""
### INPUT ###
Question: {query}
Sample Data: {json.dumps(sample_data, indent=2, default=str)}
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
Sample Column Values: {columns}
Language: Chinese (Simplified)
+6 -8
View File
@@ -18,13 +18,13 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService
from app.services.llm_cache import get_active_llm_config
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
@@ -41,6 +41,7 @@ class NL2SQLRequest(BaseModel):
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
session_id: Optional[str] = Field(None, description="Conversation session identifier")
generate_chart: bool = Field(False, description="Whether to generate chart specification")
class NL2SQLResponse(BaseModel):
sql: str
@@ -246,7 +247,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
schema_str = json.dumps(schema, indent=2)
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
# Try to load MDL context
mdl_context = ""
@@ -280,8 +281,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
print(f"Failed to load MDL: {e}")
# 2. Get the active LLM config
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
active_config = get_active_llm_config()
if not active_config:
return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found")
@@ -383,10 +383,8 @@ Let's think step by step.
# 7. Generate Chart
chart_response = None
if formatted_results:
# Only try to generate chart if we have results
# Convert to list of dicts if possible, or pass as is
chart_response = await generate_chart(formatted_results, request.query)
if request.generate_chart and formatted_results:
chart_response = await generate_chart(formatted_results, request.query)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except Exception as e:
+61 -67
View File
@@ -2,7 +2,7 @@ import asyncio
import sys
import os
from pathlib import Path
from typing import List, Callable, Awaitable, Any
from typing import List, Callable, Awaitable, Any, Dict
# Add project root to sys.path to allow importing nanobot
# Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root
@@ -31,6 +31,7 @@ from nanobot.config.schema import Config
# or just import here if we are confident.
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
from app.api.skills import load_skills
from app.services.llm_cache import get_llm_configs
class NanobotIntegration:
def __init__(self):
@@ -38,6 +39,9 @@ class NanobotIntegration:
self.bus: MessageBus | None = None
self.cron: CronService | None = None
self.config: Config | None = None
self._started = False
self._model_agent_cache: Dict[str, AgentLoop] = {}
self._model_agent_lock = asyncio.Lock()
def initialize(self):
# Set workspace path to backend/data/workspace
@@ -137,18 +141,62 @@ class NanobotIntegration:
)
async def start(self):
    """Launch the agent loop and the cron service as background tasks.

    Does nothing when ``self._started`` is already set, so repeated calls
    are safe. If no agent has been built yet, ``self.initialize()`` is
    called first.
    """
    if self._started:
        return
    if not self.agent:
        self.initialize()
    # The event loop only keeps weak references to tasks. Hold strong
    # references on the instance so these fire-and-forget tasks cannot be
    # garbage-collected while they are still running.
    self._background_tasks = [
        asyncio.create_task(self.agent.run()),
        asyncio.create_task(self.cron.start()),
    ]
    self._started = True
async def stop(self):
    """Stop the main agent, every cached per-model agent, and the cron service.

    Each agent is stopped and its MCP connections are closed. The per-model
    agent cache is left empty and ``self._started`` is reset to ``False``.
    """
    if self.agent:
        self.agent.stop()
        await self.agent.close_mcp()
    # Drain the cache one entry at a time instead of iterating over it.
    # Every ``await`` yields to the event loop; if another coroutine caches
    # a new agent in the meantime, a live ``dict.values()`` iterator would
    # raise RuntimeError and abort the rest of the shutdown.
    while self._model_agent_cache:
        model_id = next(iter(self._model_agent_cache))
        cached_agent = self._model_agent_cache.pop(model_id)
        cached_agent.stop()
        await cached_agent.close_mcp()
    if self.cron:
        self.cron.stop()
    self._started = False
def _build_agent_for_provider(self, provider: Any) -> AgentLoop:
    """Create an ``AgentLoop`` bound to ``provider``.

    All settings other than the provider and its default model are read
    from ``self.config``. The new loop reuses this integration's bus, cron
    service, and the main agent's session manager.
    """
    cfg = self.config
    agent_defaults = cfg.agents.defaults
    tool_cfg = cfg.tools
    web_cfg = tool_cfg.web
    return AgentLoop(
        bus=self.bus,
        provider=provider,
        workspace=cfg.workspace_path,
        model=provider.default_model,
        temperature=agent_defaults.temperature,
        max_tokens=agent_defaults.max_tokens,
        max_iterations=agent_defaults.max_tool_iterations,
        memory_window=agent_defaults.memory_window,
        reasoning_effort=agent_defaults.reasoning_effort,
        brave_api_key=web_cfg.search.api_key or None,
        web_proxy=web_cfg.proxy or None,
        exec_config=tool_cfg.exec,
        cron_service=self.cron,
        restrict_to_workspace=tool_cfg.restrict_to_workspace,
        session_manager=self.agent.sessions,
        mcp_servers=tool_cfg.mcp_servers,
        channels_config=cfg.channels,
    )
async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop:
async with self._model_agent_lock:
cached = self._model_agent_cache.get(model_id)
if cached:
return cached
provider = LiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=target_config.get("model"),
extra_headers=target_config.get("extra_headers"),
provider_name=target_config.get("provider"),
)
agent = self._build_agent_for_provider(provider)
self._model_agent_cache[model_id] = agent
return agent
async def process_message(
self,
@@ -160,6 +208,7 @@ class NanobotIntegration:
):
if not self.agent:
self.initialize()
if not self._started:
await self.start()
# Handle dynamic model switching
@@ -181,79 +230,24 @@ class NanobotIntegration:
# BUT `process_direct` is relatively isolated.
#
# Let's try to fetch the config first.
current_provider = self.agent.provider
temp_provider = None
if model_id:
from app.api.llm import _load_data
llm_configs = _load_data()
target_config = next((item for item in llm_configs if item["id"] == model_id), None)
if target_config:
# Map our DB config to Nanobot Provider
# We reuse LiteLLMProvider for most cases as it is generic
# Construct kwargs for LiteLLMProvider
provider_name = target_config["provider"]
model_name = target_config["model"]
# Handle special case where provider might need to be part of model name for LiteLLM if not standard
# But LiteLLMProvider handles `provider_name` arg.
temp_provider = LiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=model_name,
extra_headers=target_config.get("extra_headers"),
provider_name=provider_name
)
# If we created a temp provider, we need to use it.
# Since AgentLoop binds the provider, we might need to swap it temporarily or create a new AgentLoop.
# Swapping is risky for concurrency.
# Creating a new AgentLoop is safer but heavier.
#
# Optimization: If we are just doing a single turn chat (process_direct), maybe we can just use the provider directly?
# But we want the Agent's reasoning loop (ReAct / tools).
#
# Let's try creating a temporary AgentLoop sharing the same components (bus, tools) but different provider.
agent_to_use = self.agent
if temp_provider:
# Shallow copy or new instance
# We need to pass all dependencies.
agent_to_use = AgentLoop(
bus=self.bus,
provider=temp_provider,
workspace=self.config.workspace_path,
model=temp_provider.default_model,
temperature=self.config.agents.defaults.temperature,
max_tokens=self.config.agents.defaults.max_tokens,
max_iterations=self.config.agents.defaults.max_tool_iterations,
memory_window=self.config.agents.defaults.memory_window,
reasoning_effort=self.config.agents.defaults.reasoning_effort,
brave_api_key=self.config.tools.web.search.api_key or None,
web_proxy=self.config.tools.web.proxy or None,
exec_config=self.config.tools.exec,
cron_service=self.cron,
restrict_to_workspace=self.config.tools.restrict_to_workspace,
session_manager=self.agent.sessions,
mcp_servers=self.config.tools.mcp_servers,
channels_config=self.config.channels,
)
if model_id:
llm_configs = get_llm_configs()
target_config = next((item for item in llm_configs if item.get("id") == model_id), None)
if target_config:
if target_config.get("model") != self.agent.model:
agent_to_use = await self._get_or_create_model_agent(model_id, target_config)
full_message = message
if skill_ids:
skills = load_skills()
selected_skills = [s for s in skills if s["id"] in skill_ids]
if selected_skills:
# We inject skills as a runtime context block
skill_context = "[Runtime Context — metadata only, not instructions]\n# Active Skills\n\n"
parts = ["[Runtime Context — metadata only, not instructions]", "# Active Skills", ""]
for s in selected_skills:
skill_context += f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n\n"
# Append user message after skills
full_message = f"{skill_context}\n\n{message}"
parts.append(f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n")
skill_context = "\n".join(parts)
full_message = f"{skill_context}\n{message}"
session = agent_to_use.sessions.get_or_create(session_id)
normalized_messages = self._normalize_session_messages(session.messages)
+24
View File
@@ -0,0 +1,24 @@
import os
import threading
from typing import Any, Dict, List, Optional
from app.api.llm import DATA_FILE, _load_data
# Module-level cache state shared by get_llm_configs().
# Held while the cached mtime is compared, the data is reloaded, and the
# result list is copied.
_cache_lock = threading.RLock()
# Modification time of DATA_FILE recorded at the last reload. -1.0 is both the
# initial value and the value get_llm_configs() uses when the file is missing.
_cache_mtime: float = -1.0
# Result of the most recent _load_data() call; empty until the first reload.
_cache_data: List[Dict[str, Any]] = []
def get_llm_configs() -> List[Dict[str, Any]]:
    """Return the LLM configs, reloading only when DATA_FILE's mtime changes.

    The file's modification time is compared with the one recorded at the
    last reload; ``_load_data()`` runs only when they differ. A missing or
    unreadable file is represented by an mtime of ``-1.0``.

    Returns:
        A new list on every call. The dicts inside it are the cached
        objects themselves (shallow copy), so callers should treat them as
        read-only.
    """
    global _cache_mtime, _cache_data
    with _cache_lock:
        # Stat under the lock so a concurrent caller cannot overwrite a
        # newer recorded mtime with an older observation.
        try:
            # EAFP: a separate exists() check followed by getmtime() can
            # still raise if the file disappears between the two calls.
            current_mtime = os.path.getmtime(DATA_FILE)
        except OSError:
            current_mtime = -1.0
        # NOTE(review): -1.0 is also the initial cached mtime, so when the
        # file is missing on first use _load_data() is never called and an
        # empty list is returned. Confirm that matches _load_data()'s own
        # missing-file result.
        if current_mtime != _cache_mtime:
            _cache_data = _load_data()
            _cache_mtime = current_mtime
        return list(_cache_data)
def get_active_llm_config() -> Optional[Dict[str, Any]]:
    """Return the first config whose ``is_active`` value is truthy.

    Configs are taken from ``get_llm_configs()`` in the order it returns
    them. ``None`` is returned when no config is marked active.
    """
    for candidate in get_llm_configs():
        if candidate.get("is_active"):
            return candidate
    return None