Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nl2sql_optimized_flow():
|
||||
# Mock parameters
|
||||
query = "Show me the top 5 sales"
|
||||
source = "ds:1"
|
||||
|
||||
# Mock connector and schema
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.get_schema.return_value = {
|
||||
"sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
|
||||
"users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
|
||||
"logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
|
||||
"products": {"columns": [{"name": "id", "type": "INT"}]},
|
||||
"categories": {"columns": [{"name": "id", "type": "INT"}]},
|
||||
"inventory": {"columns": [{"name": "id", "type": "INT"}]}
|
||||
}
|
||||
mock_connector.test_connection.return_value = True
|
||||
mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]
|
||||
|
||||
# Mock LLM provider
|
||||
mock_provider = AsyncMock()
|
||||
|
||||
# First response for Table Selector
|
||||
mock_resp_tables = MagicMock()
|
||||
mock_resp_tables.content = '["sales"]'
|
||||
mock_resp_tables.finish_reason = "stop"
|
||||
|
||||
# Second response for SQL Generation
|
||||
mock_resp_sql = MagicMock()
|
||||
mock_resp_sql.content = '{"reasoning": "Plan...", "sql": "SELECT * FROM sales LIMIT 5"}'
|
||||
mock_resp_sql.finish_reason = "stop"
|
||||
|
||||
mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]
|
||||
|
||||
# Patch dependencies
|
||||
with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
|
||||
patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
|
||||
patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
|
||||
patch("app.agent.nl2sql.SessionLocal"), \
|
||||
patch("app.agent.nl2sql.DataSource"), \
|
||||
patch("app.agent.nl2sql.postgres_connector", mock_connector), \
|
||||
patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
|
||||
|
||||
request = NL2SQLRequest(query=query, source=source)
|
||||
response = await process_nl2sql(request)
|
||||
|
||||
print(f"DEBUG: Response SQL: '{response.sql}'")
|
||||
print(f"DEBUG: Response Error: '{response.error}'")
|
||||
|
||||
assert response.sql == "SELECT * FROM sales LIMIT 5"
|
||||
assert len(response.result) == 1
|
||||
assert response.error is None
|
||||
|
||||
# Verify provider was called twice
|
||||
assert mock_provider.chat.call_count == 2
|
||||
|
||||
# Verify first call was for table selection
|
||||
args, kwargs = mock_provider.chat.call_args_list[0]
|
||||
assert "TABLE_SELECTOR_SYSTEM_PROMPT" in str(args) or "Identifying relevant tables" in str(args) or any("system" in m["role"] for m in kwargs["messages"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(test_nl2sql_optimized_flow())
|
||||
Reference in New Issue
Block a user