speed optim
This commit is contained in:
@@ -9,8 +9,8 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.append(str(PROJECT_ROOT))
|
||||
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
|
||||
CHART_INSTRUCTIONS = """
|
||||
### INSTRUCTIONS ###
|
||||
@@ -133,9 +133,7 @@ CHART_EXAMPLES = """
|
||||
"""
|
||||
|
||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||
# 1. Initialize Provider
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
return ChartGenerationResponse(
|
||||
@@ -178,7 +176,7 @@ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerat
|
||||
columns = list(data[0].keys())
|
||||
|
||||
# 3. Construct Prompt
|
||||
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), indent=2)
|
||||
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type.
|
||||
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values.
|
||||
@@ -201,7 +199,7 @@ Please provide your chain of thought reasoning, chart type and the vega-lite sch
|
||||
user_prompt = f"""
|
||||
### INPUT ###
|
||||
Question: {query}
|
||||
Sample Data: {json.dumps(sample_data, indent=2, default=str)}
|
||||
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
|
||||
Sample Column Values: {columns}
|
||||
Language: Chinese (Simplified)
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
from app.core.files import resolve_upload_file_path
|
||||
from app.services.mdl import MDLService
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -41,6 +41,7 @@ class NL2SQLRequest(BaseModel):
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
|
||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||
generate_chart: bool = Field(False, description="Whether to generate chart specification")
|
||||
|
||||
class NL2SQLResponse(BaseModel):
|
||||
sql: str
|
||||
@@ -246,7 +247,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
# Try to load MDL context
|
||||
mdl_context = ""
|
||||
@@ -280,8 +281,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
print(f"Failed to load MDL: {e}")
|
||||
|
||||
# 2. Get the active LLM config
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found")
|
||||
@@ -383,10 +383,8 @@ Let's think step by step.
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
# Only try to generate chart if we have results
|
||||
# Convert to list of dicts if possible, or pass as is
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
if request.generate_chart and formatted_results:
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user