import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch

from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse


def _chat_messages(recorded_call):
    """Return the ``messages`` argument of a recorded ``provider.chat`` call.

    Accepts the list whether it was passed as the ``messages`` keyword or as
    the first positional argument; returns ``None`` when neither is present,
    so the caller can fail with a clear assertion instead of a ``KeyError``.
    """
    args, kwargs = recorded_call
    if "messages" in kwargs:
        return kwargs["messages"]
    return args[0] if args else None


@pytest.mark.asyncio
async def test_nl2sql_optimized_flow():
    """Drive ``process_nl2sql`` through table selection and SQL generation.

    The LLM provider is mocked to answer twice: first with the list of
    relevant tables, then with a JSON payload carrying the SQL. The test
    checks the returned SQL/result and that the provider was consulted
    exactly twice, the first time with a system message.
    """
    query = "Show me the top 5 sales"
    source = "ds:1"

    # Connector double. The schema deliberately holds several unrelated
    # tables, presumably so the table-selection step has something to prune.
    mock_connector = MagicMock()
    mock_connector.get_schema.return_value = {
        "sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
        "users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
        "logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
        "products": {"columns": [{"name": "id", "type": "INT"}]},
        "categories": {"columns": [{"name": "id", "type": "INT"}]},
        "inventory": {"columns": [{"name": "id", "type": "INT"}]},
    }
    mock_connector.test_connection.return_value = True
    mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]

    # LLM provider double; chat() is awaited, hence AsyncMock.
    mock_provider = AsyncMock()

    # First response: table selector picks only "sales".
    mock_resp_tables = MagicMock()
    mock_resp_tables.content = '["sales"]'
    mock_resp_tables.finish_reason = "stop"

    # Second response: SQL generation returns reasoning + SQL as JSON.
    mock_resp_sql = MagicMock()
    mock_resp_sql.content = '{"reasoning": "Plan...", "sql": "SELECT * FROM sales LIMIT 5"}'
    mock_resp_sql.finish_reason = "stop"

    # Order matters: responses are consumed one per chat() call.
    mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]

    with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
            patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
            patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
            patch("app.agent.nl2sql.SessionLocal"), \
            patch("app.agent.nl2sql.DataSource"), \
            patch("app.agent.nl2sql.postgres_connector", mock_connector), \
            patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
        request = NL2SQLRequest(query=query, source=source)
        response = await process_nl2sql(request)

    # Assertion messages replace the former debug prints so failures are
    # self-explanatory without polluting normal test output.
    assert response.error is None, f"unexpected error: {response.error!r}"
    assert response.sql == "SELECT * FROM sales LIMIT 5", f"unexpected SQL: {response.sql!r}"
    assert len(response.result) == 1

    # One call for table selection, one for SQL generation.
    assert mock_provider.chat.call_count == 2

    # The first (table-selection) call must carry a system message. Messages
    # are assumed to be dicts with a "role" key, as in the original check.
    first_messages = _chat_messages(mock_provider.chat.call_args_list[0])
    assert first_messages, "table-selection chat() call received no messages"
    assert any(m["role"] == "system" for m in first_messages), (
        "table-selection chat() call has no system message"
    )


if __name__ == "__main__":
    import asyncio

    asyncio.run(test_nl2sql_optimized_flow())