69 lines
2.9 KiB
Python
69 lines
2.9 KiB
Python
|
|
import pytest
|
||
|
|
import json
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||
|
|
|
||
|
|
def _call_messages(call):
    """Return the chat ``messages`` list recorded in a mock ``call`` object.

    Prefers an explicit ``messages=`` keyword and otherwise falls back to the
    first positional ``list`` argument. This keeps the assertions in the test
    independent of which calling convention ``process_nl2sql`` uses for
    ``provider.chat``.

    Raises:
        AssertionError: if the call carried no recognisable messages list, so
            the test fails with a readable message instead of a ``KeyError``.
    """
    args, kwargs = call
    if "messages" in kwargs:
        return kwargs["messages"]
    for arg in args:
        if isinstance(arg, list):
            return arg
    raise AssertionError(f"provider.chat called without a messages list: {call!r}")


@pytest.mark.asyncio
async def test_nl2sql_optimized_flow():
    """End-to-end check of the two-stage NL2SQL flow with all I/O mocked.

    The LLM provider is scripted to answer exactly twice: first with a JSON
    list of table names (table selection), then with a JSON object holding
    the reasoning and the SQL (SQL generation). The test asserts that:

    * no error is reported,
    * the generated SQL is returned verbatim,
    * the mocked query result (one row) is passed through,
    * the provider was called exactly twice, and
    * the first call includes a ``system``-role message.
    """
    query = "Show me the top 5 sales"
    source = "ds:1"

    # Schema with several tables besides "sales" -- presumably so the
    # table-selection step has unrelated tables to filter out.
    mock_connector = MagicMock()
    mock_connector.get_schema.return_value = {
        "sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
        "users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
        "logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
        "products": {"columns": [{"name": "id", "type": "INT"}]},
        "categories": {"columns": [{"name": "id", "type": "INT"}]},
        "inventory": {"columns": [{"name": "id", "type": "INT"}]},
    }
    mock_connector.test_connection.return_value = True
    mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]

    # chat() is awaited, hence AsyncMock.
    mock_provider = AsyncMock()

    # Scripted reply #1: the table selector picks only "sales".
    mock_resp_tables = MagicMock()
    mock_resp_tables.content = '["sales"]'
    mock_resp_tables.finish_reason = "stop"

    # Scripted reply #2: SQL generation returns reasoning plus the SQL string.
    mock_resp_sql = MagicMock()
    mock_resp_sql.content = '{"reasoning": "Plan...", "sql": "SELECT * FROM sales LIMIT 5"}'
    mock_resp_sql.finish_reason = "stop"

    # An iterable side_effect makes each awaited chat() call return the next
    # reply in order; an unexpected third call raises StopAsyncIteration.
    mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]

    # Replace the module-level collaborators of app.agent.nl2sql so no real
    # LLM config, provider, connector or (presumably) DB session is used.
    with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
         patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
         patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
         patch("app.agent.nl2sql.SessionLocal"), \
         patch("app.agent.nl2sql.DataSource"), \
         patch("app.agent.nl2sql.postgres_connector", mock_connector), \
         patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
        request = NL2SQLRequest(query=query, source=source)
        response = await process_nl2sql(request)

    # Check the error first: when the flow fails, it explains the other mismatches.
    assert response.error is None, f"unexpected error: {response.error!r}"
    assert response.sql == "SELECT * FROM sales LIMIT 5", f"unexpected SQL: {response.sql!r}"
    assert len(response.result) == 1, f"unexpected result: {response.result!r}"

    # Exactly two LLM round-trips: table selection, then SQL generation.
    assert mock_provider.chat.call_count == 2, mock_provider.chat.call_args_list

    # The table-selection call must carry a system prompt.
    first_messages = _call_messages(mock_provider.chat.call_args_list[0])
    assert any(m["role"] == "system" for m in first_messages), first_messages
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point only (ignored under pytest): drive the async test
    # coroutine to completion on a fresh event loop.
    from asyncio import run as run_coroutine

    pending = test_nl2sql_optimized_flow()
    run_coroutine(pending)
|