Files

69 lines
2.9 KiB
Python
Raw Permalink Normal View History

2026-05-13 16:43:53 +08:00
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
@pytest.mark.asyncio
async def test_nl2sql_optimized_flow():
    """Exercise the two-stage NL2SQL flow end to end with all I/O mocked.

    The LLM provider is scripted to answer exactly twice: first with a JSON
    list of table names (the table-selection step), then with a JSON object
    carrying the generated SQL (the SQL-generation step). The test checks that
    ``process_nl2sql`` surfaces that SQL, returns the single row produced by
    the mocked connector, reports no error, and awaited the provider twice.
    """
    query = "Show me the top 5 sales"
    source = "ds:1"
    expected_sql = "SELECT * FROM sales LIMIT 5"

    # Connector mock: a multi-table schema so table selection has something
    # to narrow down, plus a canned one-row query result.
    mock_connector = MagicMock()
    mock_connector.get_schema.return_value = {
        "sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
        "users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
        "logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
        "products": {"columns": [{"name": "id", "type": "INT"}]},
        "categories": {"columns": [{"name": "id", "type": "INT"}]},
        "inventory": {"columns": [{"name": "id", "type": "INT"}]},
    }
    mock_connector.test_connection.return_value = True
    mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]

    # LLM provider mock: side_effect responses are consumed in call order.
    mock_provider = AsyncMock()

    # First response: table selector picks only "sales".
    mock_resp_tables = MagicMock()
    mock_resp_tables.content = '["sales"]'
    mock_resp_tables.finish_reason = "stop"

    # Second response: SQL generation payload (same text as a hand-written
    # JSON literal, but guaranteed to stay in sync with expected_sql).
    mock_resp_sql = MagicMock()
    mock_resp_sql.content = json.dumps({"reasoning": "Plan...", "sql": expected_sql})
    mock_resp_sql.finish_reason = "stop"

    mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]

    # Patch every external dependency process_nl2sql reaches for: LLM config
    # and provider factory, connector lookup, DB session/model, and the
    # connection check.
    with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
            patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
            patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
            patch("app.agent.nl2sql.SessionLocal"), \
            patch("app.agent.nl2sql.DataSource"), \
            patch("app.agent.nl2sql.postgres_connector", mock_connector), \
            patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
        request = NL2SQLRequest(query=query, source=source)
        response = await process_nl2sql(request)

    # Check the error first so a failure explains *why* the pipeline broke
    # (this replaces the former DEBUG prints).
    assert response.error is None, f"unexpected error: {response.error!r}"
    assert response.sql == expected_sql, f"unexpected SQL: {response.sql!r}"
    assert len(response.result) == 1

    # Exactly two LLM round-trips (table selection, then SQL generation),
    # and both coroutines were actually awaited, not merely created.
    assert mock_provider.chat.call_count == 2
    assert mock_provider.chat.await_count == 2

    # The first call should look like the table-selection request.
    first_args, first_kwargs = mock_provider.chat.call_args_list[0]
    # Messages may be passed by keyword or positionally; accept either so the
    # check fails with an assertion, not a KeyError, on a call-style change.
    first_messages = first_kwargs.get("messages")
    if first_messages is None:
        first_messages = first_args[0] if first_args else []
    call_text = f"{first_args}{first_kwargs}"
    # NOTE(review): the role check passes for any call carrying a system
    # message, so it does not truly pin the table-selector prompt; consider
    # asserting on selector-specific prompt text instead — TODO confirm text.
    assert (
        "TABLE_SELECTOR_SYSTEM_PROMPT" in call_text
        or "Identifying relevant tables" in call_text
        or any("system" in m["role"] for m in first_messages)
    ), f"first LLM call does not look like table selection: {first_args!r} {first_kwargs!r}"
if __name__ == "__main__":
    # Direct-execution entry point: drive the async test on a fresh event
    # loop without pytest (and therefore without the pytest-asyncio marker).
    from asyncio import run as run_coroutine

    run_coroutine(test_nl2sql_optimized_flow())